from numpy import negative
import torch
from torch.nn.modules import CrossEntropyLoss
from torch.nn import functional as F
from typing import Optional, Tuple, Union, List
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.models.llama.modeling_llama import (
    LLAMA_INPUTS_DOCSTRING,
    LlamaForCausalLM,
    _CONFIG_FOR_DOC
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import LlamaForCausalLM


class ReLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM carrying the settings for a frozen "backup" reference model.

    This class only declares the settings; subclasses (e.g. ``Method_1``) load the
    backup model and use it to regularise training.
    """

    def __init__(self, config):
        super().__init__(config)
        # Weight of the backup-model KL term added to the LM loss.
        self.alpha = 1
        # Frozen reference model; loaded lazily by subclasses on first use.
        self.backup_model = None
        # The backup model is placed on cuda:(<this model's device index> + gap).
        self.backup_model_dev_gap = 3
        # Checkpoint the backup model is loaded from.
        self.backup_model_path = '/home/tushilong/hf/models/Llama-2-7b-hf'


class Method_1(ReLlamaForCausalLM):
    """Causal LM whose training loss adds a KL term towards a frozen backup model.

    With labels, the loss is ``CE(next token) + alpha * KL(backup || self)``. The KL
    term is averaged over the response positions (after the first label equal to
    ``1``) whose label is ``-100``, i.e. positions excluded from the CE loss.
    """

    def _ensure_backup_model(self):
        """Load the frozen backup model on first use and return it."""
        if self.backup_model is None:
            # NOTE(review): assumes this model lives on an indexed CUDA device;
            # ``self.device.index`` is None on CPU and this line would fail.
            self.backup_model_device = "cuda:" + str(self.device.index + self.backup_model_dev_gap)
            backup_model = LlamaForCausalLM.from_pretrained(
                self.backup_model_path, device_map=self.backup_model_device)
            backup_model.eval()
            backup_model.requires_grad_(False)
            # Bypass nn.Module.__setattr__ so the backup model is NOT registered as a
            # submodule. Otherwise its weights would be saved with this model's
            # state_dict()/checkpoints, show up in parameters(), and self.train()
            # would switch it back to training mode.
            object.__setattr__(self, 'backup_model', backup_model)
        return self.backup_model

    def _backup_kl_loss(self, logits, input_ids, attention_mask, labels):
        """Mean per-position KL(backup || self) over masked response positions.

        For each sample, the response starts at the first position whose label is
        ``1`` (BOS). Every later position labelled ``-100`` (and not masked out by
        ``attention_mask``) contributes ``KL(softmax(backup) || softmax(self))``.
        The summed KL is divided by the number of contributing positions in the
        batch. Samples without a BOS label or without such positions are skipped;
        if nothing contributes, a zero loss is returned.

        Raises:
            ValueError: if ``input_ids`` is None (the backup model needs token ids).
        """
        if input_ids is None:
            raise ValueError(
                "The backup-model KL loss requires `input_ids`; `inputs_embeds` alone is not supported.")
        backup_model = self._ensure_backup_model()

        backup_model_params = {'input_ids': input_ids.to(backup_model.device)}
        if attention_mask is not None:
            backup_model_params['attention_mask'] = attention_mask.to(backup_model.device)
        with torch.no_grad():
            target_logits = backup_model(**backup_model_params).logits.to(logits.device)
        nan_count = torch.isnan(target_logits).sum()
        assert nan_count == 0, f"target_logits has nan: {nan_count}"

        labels = labels.to(logits.device)
        # NOTE(review): assumes a 2-D (batch, seq) padding mask, as in the standard HF input.
        mask = attention_mask.to(logits.device) if attention_mask is not None else None

        total_loss = logits.new_zeros(())
        total_len = 0
        for i in range(logits.size(0)):  # iterate over the batch
            # token [1] marks the start of the response (bos token); use its first occurrence
            bos_positions = torch.nonzero(labels[i] == 1).flatten()
            if bos_positions.numel() == 0:
                continue
            start_idx = int(bos_positions[0])

            # positions in the response that are ignored by the CE loss get a KL term;
            # padding positions carry no information and are excluded
            keep = labels[i, start_idx:] == -100
            if mask is not None:
                keep = keep & (mask[i, start_idx:] != 0)
            maintain_position = torch.nonzero(keep).flatten() + start_idx
            if maintain_position.numel() == 0:
                continue

            cur_predict_logits = logits[i, maintain_position]
            cur_target_logits = target_logits[i, maintain_position]
            # F.kl_div(log q, p) computes sum p * (log p - log q) = KL(backup || self)
            total_loss = total_loss + F.kl_div(
                F.log_softmax(cur_predict_logits, dim=-1),
                F.softmax(cur_target_logits, dim=-1),
                reduction='sum',
            )
            total_len += maintain_position.numel()

        if total_len == 0:
            return total_loss
        return total_loss / total_len

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
                When given, `self.alpha` times a KL divergence towards a frozen backup model (computed on the
                response positions labelled `-100`) is added to the loss.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss_1 = loss_fct(shift_logits, shift_labels)

            # The backup model is only needed (and loaded) when a loss is requested.
            loss_2 = self._backup_kl_loss(logits, input_ids, attention_mask, labels)
            loss = loss_1 + self.alpha * loss_2

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )