add comment

This commit is contained in:
arslantu 2024-03-09 11:03:30 +08:00
parent bc182d09e0
commit 104f521f79
2 changed files with 5 additions and 1 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ ckpts/
data/ data/
outputs/ outputs/
.vscode/ .vscode/
train/run_log.txt
realign/run_log.txt

View File

@ -133,8 +133,10 @@ class Method_1(ReLlamaForCausalLM):
for i in range(predict_logits.size(0)): for i in range(predict_logits.size(0)):
# iterate over the batch # iterate over the batch
# token [1] is the start of response (bos token)
start_idx = torch.where(labels[i] == 1)[0].item() start_idx = torch.where(labels[i] == 1)[0].item()
# if [-100] in response, we should calculate kl_div loss for that position
maintain_position: List[int] = [] maintain_position: List[int] = []
for idx in range(start_idx, labels[i].size(0)): for idx in range(start_idx, labels[i].size(0)):
if labels[i][idx] == -100: if labels[i][idx] == -100: