add comment
This commit is contained in:
		| @ -133,8 +133,10 @@ class Method_1(ReLlamaForCausalLM): | ||||
|             for i in range(predict_logits.size(0)): | ||||
|                 # iterate over the batch | ||||
|                  | ||||
|                 # token id 1 (the bos token) marks the start of the response | ||||
|                 start_idx = torch.where(labels[i] == 1)[0].item() | ||||
|  | ||||
|                 # if a label inside the response is -100, the kl_div loss should be calculated for that position | ||||
|                 maintain_position: List[int] = [] | ||||
|                 for idx in range(start_idx, labels[i].size(0)): | ||||
|                     if labels[i][idx] == -100: | ||||
|  | ||||
		Reference in New Issue
	
	Block a user