diff --git a/.gitignore b/.gitignore index 2de5599..a1ad0eb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ __pycache__/ ckpts/ data/ outputs/ -.vscode/ \ No newline at end of file +.vscode/ +train/run_log.txt +realign/run_log.txt diff --git a/llama/rellama.py b/llama/rellama.py index 5ac0ca0..a46ef54 100644 --- a/llama/rellama.py +++ b/llama/rellama.py @@ -133,8 +133,10 @@ class Method_1(ReLlamaForCausalLM): for i in range(predict_logits.size(0)): # iterate over the batch + # token [1] is the start of response (bos token) start_idx = torch.where(labels[i] == 1)[0].item() + # if [-100] in response, we should calculate kl_div loss for that position maintain_position: List[int] = [] for idx in range(start_idx, labels[i].size(0)): if labels[i][idx] == -100: