init🎉:

llama/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
from .rellama import Method_1

llama/rellama.py (new file, 167 lines)

@@ -0,0 +1,167 @@
import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from typing import Optional, Tuple, Union, List
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.models.llama.modeling_llama import (
    LLAMA_INPUTS_DOCSTRING,
    LlamaForCausalLM,
    _CONFIG_FOR_DOC,
)
from transformers.modeling_outputs import CausalLMOutputWithPast


class ReLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        # Weight of the KL regularization term added to the language-modeling loss.
        self.alpha = 1
        # Frozen reference copy of the base model, loaded lazily on the first forward call.
        self.backup_model = None
        # Offset between this model's CUDA device index and the backup model's device index.
        self.backup_model_dev_gap = 3


class Method_1(ReLlamaForCausalLM):
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Lazily load the frozen backup (reference) model onto another GPU the first time forward() is called.
        if self.backup_model is None:
            self.backup_model_device = "cuda:" + str(self.device.index + self.backup_model_dev_gap)
            self.backup_model = LlamaForCausalLM.from_pretrained(
                '/home/tushilong/hf/models/Llama-2-7b-hf', device_map=self.backup_model_device)
            self.backup_model.eval()

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # loss_1: standard next-token cross-entropy loss.
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss_1 = loss_fct(shift_logits, shift_labels)

            predict_logits: torch.Tensor = logits

            # Run the frozen backup model on the same inputs to obtain reference logits for the KL term.
            backup_model_params = {
                'input_ids': input_ids.clone().to(self.backup_model.device),
                'attention_mask': attention_mask.clone().to(self.backup_model.device),
                # 'labels': labels.clone().to(self.backup_model_device), # if labels is not None else labels,
            }
            with torch.no_grad():
                target_logits = self.backup_model(**backup_model_params).logits.to(predict_logits.device)
            # target_logits = self.backup_model(**backup_model_params).logits.detach().clone().to(predict_logits.device)

            assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"

            # loss_2: KL divergence between the model's distribution and the frozen backup model's
            # distribution, evaluated only at positions whose label is -100 (ignored by loss_1).
            batch_total_loss_2 = 0
            batch_total_len = 0
            for i in range(predict_logits.size(0)):
                # Iterate over the batch.

                # Index of the label equal to 1 (the Llama BOS token id); assumes exactly one such position.
                start_idx = torch.where(labels[i] == 1)[0].item()

                # Collect the positions after start_idx that are masked out of the cross-entropy loss.
                maintain_position: List[int] = []
                for idx in range(start_idx, labels[i].size(0)):
                    if labels[i][idx] == -100:
                        maintain_position.append(idx)

                # FIXME: should this `continue` instead of asserting when no masked position exists?
                assert len(maintain_position) > 0

                maintain_position: torch.Tensor = torch.tensor(maintain_position, requires_grad=False).to(predict_logits.device)
                cur_predict_logits = predict_logits[i][maintain_position].contiguous()
                cur_target_logits = target_logits[i][maintain_position].contiguous()

                cur_loss_2 = F.kl_div(F.log_softmax(cur_predict_logits, dim=-1), F.softmax(cur_target_logits, dim=-1), reduction='sum')
                batch_total_loss_2 += cur_loss_2
                batch_total_len += cur_predict_logits.size(0)

            # Length-normalize the KL term and combine it with the cross-entropy loss.
            loss_2 = batch_total_loss_2 / batch_total_len
            loss = loss_1 + self.alpha * loss_2

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
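
In short, `Method_1.forward` combines the standard next-token cross-entropy loss (`loss_1`) with an `alpha`-weighted KL-divergence penalty (`loss_2`) that keeps the fine-tuned model close to a frozen backup copy of Llama-2 at the positions whose labels are `-100`. The sketch below restates that objective as a standalone function; it is illustrative only, drops the BOS-anchored `start_idx` bookkeeping from the loop above, uses a plain `continue` where the commit asserts, and the toy tensors at the end are made up.

```python
import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F


def combined_loss(predict_logits, target_logits, labels, alpha=1.0):
    # predict_logits / target_logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    # Cross-entropy on the shifted sequence; positions labelled -100 are ignored by default.
    shift_logits = predict_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_1 = CrossEntropyLoss()(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

    # KL divergence against the frozen reference, only at the -100 (masked) positions.
    total_kl, total_len = 0.0, 0
    for i in range(predict_logits.size(0)):
        keep = (labels[i] == -100).nonzero(as_tuple=True)[0]
        if keep.numel() == 0:
            continue  # skip samples with no masked position instead of asserting
        log_p = F.log_softmax(predict_logits[i][keep], dim=-1)
        q = F.softmax(target_logits[i][keep], dim=-1)
        total_kl = total_kl + F.kl_div(log_p, q, reduction="sum")
        total_len += keep.numel()

    loss_2 = total_kl / max(total_len, 1)
    return loss_1 + alpha * loss_2


# Toy tensors: batch of 2, sequence length 6, vocabulary of 10.
labels = torch.tensor([[1, -100, -100, 4, 5, 6], [1, -100, 3, 4, -100, 6]])
student_logits = torch.randn(2, 6, 10, requires_grad=True)
backup_logits = torch.randn(2, 6, 10)
print(combined_loss(student_logits, backup_logits, labels))
```

As in the commit, `F.kl_div` receives log-probabilities for the trained model and probabilities for the frozen reference, and the summed divergence is normalized by the number of masked positions before being scaled by `alpha`.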