From 104f521f79aeb9779cda36919b9e3e81f2a45d0e Mon Sep 17 00:00:00 2001 From: arslantu Date: Sat, 9 Mar 2024 11:03:30 +0800 Subject: [PATCH] add comment --- .gitignore | 4 +++- llama/rellama.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2de5599..a1ad0eb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ __pycache__/ ckpts/ data/ outputs/ -.vscode/ \ No newline at end of file +.vscode/ +train/run_log.txt +realign/run_log.txt diff --git a/llama/rellama.py b/llama/rellama.py index 5ac0ca0..a46ef54 100644 --- a/llama/rellama.py +++ b/llama/rellama.py @@ -133,8 +133,10 @@ class Method_1(ReLlamaForCausalLM): for i in range(predict_logits.size(0)): # iterate over the batch + # token id 1 (the bos token) marks the start of the response start_idx = torch.where(labels[i] == 1)[0].item() + # if a label within the response is -100, the kl_div loss should be calculated for that position maintain_position: List[int] = [] for idx in range(start_idx, labels[i].size(0)): if labels[i][idx] == -100: