add comment
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -3,3 +3,5 @@ ckpts/ | |||||||
| data/ | data/ | ||||||
| outputs/ | outputs/ | ||||||
| .vscode/ | .vscode/ | ||||||
|  | train/run_log.txt | ||||||
|  | realign/run_log.txt | ||||||
|  | |||||||
| @ -133,8 +133,10 @@ class Method_1(ReLlamaForCausalLM): | |||||||
|             for i in range(predict_logits.size(0)): |             for i in range(predict_logits.size(0)): | ||||||
|                 # iterate over the batch |                 # iterate over the batch | ||||||
|                  |                  | ||||||
|  |                 # token id 1 (the BOS token) marks the start of the response | ||||||
|                 start_idx = torch.where(labels[i] == 1)[0].item() |                 start_idx = torch.where(labels[i] == 1)[0].item() | ||||||
|  |  | ||||||
|  |                 # if a label inside the response is -100, the kl_div loss should be calculated for that position | ||||||
|                 maintain_position: List[int] = [] |                 maintain_position: List[int] = [] | ||||||
|                 for idx in range(start_idx, labels[i].size(0)): |                 for idx in range(start_idx, labels[i].size(0)): | ||||||
|                     if labels[i][idx] == -100: |                     if labels[i][idx] == -100: | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user