add comment
This commit is contained in:
parent
bc182d09e0
commit
104f521f79
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -2,4 +2,6 @@ __pycache__/
|
||||||
ckpts/
|
ckpts/
|
||||||
data/
|
data/
|
||||||
outputs/
|
outputs/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
train/run_log.txt
|
||||||
|
realign/run_log.txt
|
||||||
|
|
|
@ -133,8 +133,10 @@ class Method_1(ReLlamaForCausalLM):
|
||||||
for i in range(predict_logits.size(0)):
|
for i in range(predict_logits.size(0)):
|
||||||
# iterate over the batch
|
# iterate over the batch
|
||||||
|
|
||||||
|
# token [1] is the start of response (bos token)
|
||||||
start_idx = torch.where(labels[i] == 1)[0].item()
|
start_idx = torch.where(labels[i] == 1)[0].item()
|
||||||
|
|
||||||
|
# if [-100] in response, we should calculate kl_div loss for that position
|
||||||
maintain_position: List[int] = []
|
maintain_position: List[int] = []
|
||||||
for idx in range(start_idx, labels[i].size(0)):
|
for idx in range(start_idx, labels[i].size(0)):
|
||||||
if labels[i][idx] == -100:
|
if labels[i][idx] == -100:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user