init🎉:
This commit is contained in:
1
llama/__init__.py
Normal file
1
llama/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .rellama import Method_1
|
167
llama/rellama.py
Normal file
167
llama/rellama.py
Normal file
@ -0,0 +1,167 @@
|
||||
from numpy import negative
|
||||
import torch
|
||||
from torch.nn.modules import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional, Tuple, Union, List
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LLAMA_INPUTS_DOCSTRING,
|
||||
LlamaForCausalLM,
|
||||
_CONFIG_FOR_DOC
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
|
||||
class ReLlamaForCausalLM(LlamaForCausalLM):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.alpha = 1
|
||||
self.backup_model = None
|
||||
self.backup_model_dev_gap = 3
|
||||
|
||||
|
||||
class Method_1(ReLlamaForCausalLM):
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
# load backup model at first time on other device
|
||||
if self.backup_model is None:
|
||||
self.backup_model_device = "cuda:" + str(self.device.index + self.backup_model_dev_gap)
|
||||
self.backup_model = LlamaForCausalLM.from_pretrained(
|
||||
'/home/tushilong/hf/models/Llama-2-7b-hf', device_map=self.backup_model_device)
|
||||
self.backup_model.eval()
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss_1 = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
|
||||
predict_logits: torch.Tensor = logits
|
||||
|
||||
backup_model_params = {
|
||||
'input_ids': input_ids.clone().to(self.backup_model.device),
|
||||
'attention_mask': attention_mask.clone().to(self.backup_model.device),
|
||||
# 'labels': labels.clone().to(self.backup_model_device), # if labels is not None else labels,
|
||||
}
|
||||
with torch.no_grad():
|
||||
target_logits = self.backup_model(**backup_model_params).logits.to(predict_logits.device)
|
||||
# target_logits = self.backup_model(**backup_model_params).logits.detach().clone().to(predict_logits.device)
|
||||
|
||||
assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"
|
||||
|
||||
batch_total_loss_2 = 0
|
||||
batch_total_len = 0
|
||||
for i in range(predict_logits.size(0)):
|
||||
# iterate over the batch
|
||||
|
||||
start_idx = torch.where(labels[i] == 1)[0].item()
|
||||
|
||||
maintain_position: List[int] = []
|
||||
for idx in range(start_idx, labels[i].size(0)):
|
||||
if labels[i][idx] == -100:
|
||||
maintain_position.append(idx)
|
||||
|
||||
# FIXME: may should continue
|
||||
assert len(maintain_position) > 0
|
||||
|
||||
maintain_position: torch.Tensor = torch.tensor(maintain_position, requires_grad=False).to(predict_logits.device)
|
||||
cur_predict_logits = predict_logits[i][maintain_position].contiguous()
|
||||
cur_target_logits = target_logits[i][maintain_position].contiguous()
|
||||
|
||||
cur_loss_2 = F.kl_div(F.log_softmax(cur_predict_logits, dim=-1), F.softmax(cur_target_logits, dim=-1), reduction='sum')
|
||||
batch_total_loss_2 += cur_loss_2
|
||||
batch_total_len += cur_predict_logits.size(0)
|
||||
|
||||
loss_2 = batch_total_loss_2 / batch_total_len
|
||||
loss = loss_1 + self.alpha * loss_2
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
Reference in New Issue
Block a user