init🎉

.gitignore (5 lines, vendored, Normal file)
@@ -0,0 +1,5 @@
__pycache__/
ckpts/
data/
outputs/
.vscode/
							
								
								
									
llama/__init__.py (1 line, Normal file)
@@ -0,0 +1 @@
from .rellama import Method_1
							
								
								
									
llama/rellama.py (167 lines, Normal file)
@@ -0,0 +1,167 @@
from numpy import negative
import torch
from torch.nn.modules import CrossEntropyLoss
from torch.nn import functional as F
from typing import Optional, Tuple, Union, List
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.models.llama.modeling_llama import (
    LLAMA_INPUTS_DOCSTRING,
    LlamaForCausalLM,
    _CONFIG_FOR_DOC
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import LlamaForCausalLM


class ReLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        self.alpha = 1
        self.backup_model = None
        self.backup_model_dev_gap = 3


class Method_1(ReLlamaForCausalLM):
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # lazily load the frozen backup (reference) model onto another device on first use
        if self.backup_model is None:
            self.backup_model_device = "cuda:" + str(self.device.index + self.backup_model_dev_gap)
            self.backup_model = LlamaForCausalLM.from_pretrained(
                '/home/tushilong/hf/models/Llama-2-7b-hf', device_map=self.backup_model_device)
            self.backup_model.eval()

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss_1 = loss_fct(shift_logits, shift_labels)


            predict_logits: torch.Tensor = logits

            backup_model_params = {
                'input_ids': input_ids.clone().to(self.backup_model.device),
                'attention_mask': attention_mask.clone().to(self.backup_model.device),
                # 'labels': labels.clone().to(self.backup_model_device), # if labels is not None else labels,
            }
            with torch.no_grad():
                target_logits = self.backup_model(**backup_model_params).logits.to(predict_logits.device)
            # target_logits = self.backup_model(**backup_model_params).logits.detach().clone().to(predict_logits.device)

            assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"

            batch_total_loss_2 = 0
            batch_total_len = 0
            for i in range(predict_logits.size(0)):
                # iterate over the batch

                start_idx = torch.where(labels[i] == 1)[0].item()

                maintain_position: List[int] = []
                for idx in range(start_idx, labels[i].size(0)):
                    if labels[i][idx] == -100:
                        maintain_position.append(idx)

                # FIXME: should perhaps `continue` here instead of asserting
                assert len(maintain_position) > 0

                maintain_position: torch.Tensor = torch.tensor(maintain_position, requires_grad=False).to(predict_logits.device)
                cur_predict_logits = predict_logits[i][maintain_position].contiguous()
                cur_target_logits = target_logits[i][maintain_position].contiguous()

                cur_loss_2 = F.kl_div(F.log_softmax(cur_predict_logits, dim=-1), F.softmax(cur_target_logits, dim=-1), reduction='sum')
                batch_total_loss_2 += cur_loss_2
                batch_total_len += cur_predict_logits.size(0)

            loss_2 = batch_total_loss_2 / batch_total_len
            loss = loss_1 + self.alpha * loss_2

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
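To make the objective above easier to read in isolation, here is a minimal, hypothetical sketch (not part of the commit) of the loss that `Method_1.forward` computes: ordinary next-token cross-entropy (`loss_1`) plus `alpha` times a KL term (`loss_2`) between the trained model and the frozen backup model. The sketch simplifies the position selection: it keeps every `-100`-labelled position, whereas the code above additionally starts from the first position whose label is the BOS id.

```python
# Hypothetical sketch: the CE + alpha * KL objective used by Method_1.
import torch
import torch.nn.functional as F

def method_1_loss(predict_logits, target_logits, labels, alpha=1.0):
    # predict_logits, target_logits: (batch, seq, vocab); labels: (batch, seq)
    shift_logits = predict_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # loss_1: causal-LM cross-entropy; label -100 is ignored by default
    loss_1 = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))

    total_kl, total_len = 0.0, 0
    for i in range(predict_logits.size(0)):
        # simplified: keep every position whose label is -100 (the prompt tokens)
        keep = (labels[i] == -100).nonzero(as_tuple=True)[0]
        p = predict_logits[i][keep]
        q = target_logits[i][keep]
        total_kl = total_kl + F.kl_div(F.log_softmax(p, dim=-1),
                                       F.softmax(q, dim=-1), reduction='sum')
        total_len += p.size(0)
    loss_2 = total_kl / max(total_len, 1)
    return loss_1 + alpha * loss_2
```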
							
								
								
									
realign/__init__.py (0 lines, Normal file)

realign/eval_acc.py (81 lines, Normal file)
@@ -0,0 +1,81 @@
import json
from typing import Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import logging


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
)


def prepare_data(data_path: str):
    data = [json.loads(line) for line in open(data_path, 'r').readlines()]
    data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
    return data


class LM:
    def __init__(self, device: str, model_path: str, tokenizer_path: str=None) -> None:
        if tokenizer_path is None:
            tokenizer_path = model_path
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def chat(self, prompt: str) -> str:
        input_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids + [self.tokenizer.bos_token_id]
        input_ids = torch.tensor([input_ids], dtype=torch.int64).to(self.model.device)
        generate_ids = self.model.generate(input_ids, max_new_tokens=1024)
        response = self.tokenizer.decode(generate_ids[0][len(input_ids[0]):], skip_special_tokens=True)
        return response


def predict_with_lm(data: Dict, lm: LM) -> bool:
    global prompt_template, failed_task

    text: str = data['text']
    person: str = data['h']['name'].strip()
    nationality: str = data['t']['name'].strip()

    prompt = prompt_template.format(text=text)
    response = lm.chat(prompt)
    try:
        extract_res = '\n'.join(response.split('\n')[1:-1])
        extract_res = json.loads(extract_res)
        return extract_res['person'].strip() == person and extract_res['nationality'].strip() == nationality
    except:
        return False


def eval_acc(data_path: str, model_path: str, device: str):
    data = prepare_data(data_path)
    lm = LM(device, model_path)
    res = sum([predict_with_lm(d, lm) for d in tqdm(data)]) / len(data)
    return res


def main():
    data_path = '../data/nyt10/nyt10_test.txt'
    full_model_path = '../ckpts/full'
    full_model_device = 'cuda:5'
    stylish_model_path = '../ckpts/stylish'
    stylish_model_device = 'cuda:6'

    # full_model_res = eval_acc(data_path, full_model_path, full_model_device)
    stylish_model_res = eval_acc(data_path, stylish_model_path, stylish_model_device)

    # logger.info(f'Full model accuracy: {full_model_res}')
    logger.info(f'Stylish model accuracy: {stylish_model_res}')


if __name__ == "__main__":
    main()
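The scoring in `predict_with_lm` assumes the model answers with a fenced JSON block (the same format as the groundtruth string built in eval_to_img.py below); dropping the first and last lines of the response leaves plain JSON. A small hypothetical illustration, with made-up names:

```python
# Hypothetical response in the expected format; parsing mirrors predict_with_lm().
import json

response = '```json\n{\n  "person": "John Doe",\n  "nationality": "France"\n}\n```'
body = '\n'.join(response.split('\n')[1:-1])   # strip the ``` fence lines
parsed = json.loads(body)
assert parsed['person'] == 'John Doe' and parsed['nationality'] == 'France'
```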
							
								
								
									
realign/eval_to_img.py (131 lines, Normal file)
@@ -0,0 +1,131 @@
import json
from typing import Dict
from numpy import add
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from realign.utils.draw import create_image_from_list
import random
import PIL


def prepare_data(data_path: str):
    data = [json.loads(line) for line in open(data_path, 'r').readlines()]
    data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
    return data

def prepare_model(
        base_model_path: str, base_model_device: str,
        chat_model_path: str, chat_model_device: str,
        stylish_model_path: str, stylish_model_device: str,
):
    base_model, base_tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, device_map=base_model_device), AutoTokenizer.from_pretrained(base_model_path)

    chat_model, chat_tokenizer = AutoModelForCausalLM.from_pretrained(chat_model_path, device_map=chat_model_device), AutoTokenizer.from_pretrained(base_model_path)

    stylish_model, stylish_tokenizer = AutoModelForCausalLM.from_pretrained(stylish_model_path, device_map=stylish_model_device), AutoTokenizer.from_pretrained(base_model_path)

    base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
    chat_tokenizer.pad_token_id = chat_tokenizer.eos_token_id
    stylish_tokenizer.pad_token_id = stylish_tokenizer.eos_token_id

    return base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer


def eval_shift(data: Dict, base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer):
    text: str = data['text']
    person: str = data['h']['name']
    nationality: str = data['t']['name']

    groundtruth: str = (
        "```json\n"
        "{\n"
        f"  \"person\": \"{person}\",\n"
        f"  \"nationality\": \"{nationality}\"\n"
        "}\n"
        "```"
    )

    groundtruth_ids = [base_tokenizer.bos_token_id] + base_tokenizer.encode(groundtruth, add_special_tokens=False) + [base_tokenizer.eos_token_id]
    groundtruth_ids = torch.tensor(
        [groundtruth_ids], dtype=torch.int64
    )

    # 1. chat model generate
    prompt_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
    )
    prompt = prompt_template.format(text=text)

    def chat_model_generate():
        input_ids = chat_tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(chat_model.device)
        generate_ids = chat_model.generate(input_ids, max_new_tokens=1024)
        # response = chat_tokenizer.decode(generate_ids[0][len(input_ids[0]):], skip_special_tokens=True)
        # return response
        return [generate_ids[0][len(input_ids[0]):].tolist()]

    # O: str = chat_model_generate()
    generate_output_ids = torch.tensor(
        chat_model_generate(), dtype=torch.int64
    )

    def get_rank(model, tokenizer, output_ids):
        prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False)
        # output_ids = tokenizer.encode(O, return_tensors='pt')
        input_ids = torch.cat([prompt_ids, output_ids], dim=-1)
        input_ids = input_ids.to(model.device)
        with torch.no_grad():
            logits = model(input_ids, labels=input_ids)['logits']
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        ranks = []
        for i, token_id in enumerate(input_ids[0][len(prompt_ids[0]):], start=len(prompt_ids[0])-1):
            word = tokenizer.decode(token_id)
            prob = probabilities[0, i, token_id].item()

            # find the rank of token_id, in reverse order
            rank = (probabilities[0, i] > prob).sum().item()
            ranks.append((word, rank))
        return ranks

    return {
        "prompt": [(prompt, -1)],
        "base_model_generate": [("base_model_generate", -1)] + get_rank(base_model, base_tokenizer, generate_output_ids),
        "stylish_model_generate": [("stylish_model_generate", -1)] + get_rank(stylish_model, stylish_tokenizer, generate_output_ids),
        "base_model_groundtruth": [("base_model_groundtruth", -1)] + get_rank(base_model, base_tokenizer, groundtruth_ids),
        "stylish_model_groundtruth": [("stylish_model_groundtruth", -1)] + get_rank(stylish_model, stylish_tokenizer, groundtruth_ids),
    }


def main():
    data_path = 'data/nyt10/nyt10_test.txt'
    base_model_path = '/home/tushilong/hf/models/Llama-2-7b-hf'
    base_model_device = 'cuda:0'
    chat_model_path = 'ckpts/full'
    chat_model_device = 'cuda:1'
    stylish_model_path = 'ckpts/stylish'
    stylish_model_device = 'cuda:6'

    data = random.sample(prepare_data(data_path), 20)
    base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer = prepare_model(
        base_model_path, base_model_device, chat_model_path, chat_model_device, stylish_model_path, stylish_model_device
    )
    for idx, d in enumerate(data):
        results = eval_shift(d, base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer)

        images = {key: create_image_from_list(val) for key, val in results.items()}

        combine_width = max([img.width for img in images.values()])
        combine_height = sum([img.height for img in images.values()])
        combine_image = PIL.Image.new('RGB', (combine_width, combine_height), (255, 255, 255))
        y = 0
        for key in ["prompt", "base_model_generate", "stylish_model_generate", "base_model_groundtruth", "stylish_model_groundtruth"]:
            img = images[key]
            combine_image.paste(img, (0, y))
            y += img.height
        combine_image.save(f'outputs/{idx}.png')


if __name__ == "__main__":
    main()
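For readability: the `rank` recorded by `get_rank` above is the number of vocabulary entries the model scores strictly higher than the observed token, so rank 0 means the token was the model's top choice; `create_image_from_list` then colours tokens blue, brown, or red by that rank. A tiny self-contained illustration with toy numbers (not from the commit):

```python
# Toy example of the rank statistic used in get_rank().
import torch

probs = torch.tensor([0.1, 0.6, 0.3])   # toy next-token distribution over a 3-word vocabulary
observed_token = 2                       # the token that actually appeared
rank = (probs > probs[observed_token]).sum().item()
assert rank == 1                         # exactly one entry (index 1) is more likely
```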
							
								
								
									
realign/run_log.txt (53 lines, Normal file)
@@ -0,0 +1,53 @@
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 2) local_rank: 0 (pid: 44359) of binary: /home/tushilong/anaconda3/envs/realign/bin/python
Traceback (most recent call last):
  File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-03-09_02:08:09
  host      : ubuntu
  rank      : 1 (local_rank: 1)
  exitcode  : 2 (pid: 44360)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2024-03-09_02:08:09
  host      : ubuntu
  rank      : 2 (local_rank: 2)
  exitcode  : 2 (pid: 44361)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-03-09_02:08:09
  host      : ubuntu
  rank      : 0 (local_rank: 0)
  exitcode  : 2 (pid: 44359)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
							
								
								
									
realign/utils/__init__.py (0 lines, Normal file)

realign/utils/draw.py (57 lines, Normal file)
@@ -0,0 +1,57 @@
from math import ceil
from PIL import Image, ImageDraw, ImageFont

def create_image_from_list(data, font_size=16, max_width=800):
    # Compute the width and height of the image
    full_text = ''.join([item[0] for item in data])
    font = ImageFont.truetype('/usr/local/share/fonts/ttf/dotfiles/MesloLGS/MesloLGS NF Regular.ttf', size=font_size)
    total_text_width = ceil(font.getlength(full_text))
    number_of_lines = (total_text_width // max_width) + 1 + 1 # one more line for \n of query
    text_width = max_width if total_text_width > max_width else total_text_width
    single_height = sum(abs(x) for x in font.getmetrics())
    text_height = (single_height + 10) * (number_of_lines + 1)
    image_width = text_width + 20
    image_height = text_height + 20

    # Create a new blank image
    image = Image.new('RGB', (image_width, image_height), color='white')
    draw = ImageDraw.Draw(image)

    x = 10
    y = 10
    for word in data[0][0].split():
        word_width = font.getlength(word)
        if x + word_width > text_width:
            x = 10
            y += single_height + 5
        draw.text((x, y), word, font=font, fill='black')
        x += word_width + font.getlength(' ')

    x = 10
    y += single_height + 5
    # Iterate over the (text, rank) tuples in the list
    for _, item in enumerate(data[1:]):
        text: str = item[0]  # the token text
        color = 'blue' if item[1] == 0 else 'brown' if item[1] < 3 else 'red'
        for word in text.split():
            word_width = font.getlength(word)
            if x + word_width > text_width:
                x = 10
                y += single_height + 5
            draw.text((x, y), word, font=font, fill=color)
            x += word_width + font.getlength(' ')

    return image


if __name__ == "__main__":
    # Sample data
    data = [("Start", -1), ('Hello', 2), ('World\n', 4), ('Python', 1)]

    data = data + data[1:] * 100

    # Create the image
    image = create_image_from_list(data)

    # Save the image
    image.save('output.png')
							
								
								
									
realign/utils/model.py (26 lines, Normal file)
@@ -0,0 +1,26 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from qwen_model.modeling_qwen import QWenModel, QWenLMHeadModel

def load_model(model_path: str, generation_config_path: str=None, tokenizer_path: str=None, device_map: str="auto"):
    if tokenizer_path is None:
        tokenizer_path = model_path

    if generation_config_path is None:
        generation_config_path = model_path

    model_cls = QWenLMHeadModel # if 'chat' in model_path.lower() else QWenModel

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token = '<|endoftext|>'

    generation_config = GenerationConfig.from_pretrained(generation_config_path, trust_remote_code=True)

    generation_config.max_new_tokens = 1024

    model = model_cls.from_pretrained(model_path, device_map=device_map, trust_remote_code=True, bf16=True)
    model.generation_config = generation_config
    model.eval()
    print(model.generation_config)
    print(tokenizer)
    return model, tokenizer
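Note that `load_model` imports a local `qwen_model` package that is not part of this commit. A hypothetical call, assuming that package and a Qwen checkpoint path are available (the path below is an assumption, not from the repo):

```python
# Hypothetical usage of realign/utils/model.py's load_model().
from realign.utils.model import load_model

model, tokenizer = load_model('Qwen/Qwen-7B-Chat', device_map='auto')  # assumed checkpoint
```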
							
								
								
									
requirements.txt (14 lines, Normal file)
@@ -0,0 +1,14 @@
transformers==4.32.0
datasets==2.14.6
accelerate==0.24.0
tiktoken==0.5.1
einops==0.7.0
transformers_stream_generator==0.0.4
scipy==1.11.3
fairscale
sentencepiece
fire

numba
openpyxl
pandas==2.1.1
							
								
								
									
test_llama.py (223 lines, Normal file)
@@ -0,0 +1,223 @@
import copy
from typing import Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from train.utils.datasets.nyt10_dataset import NYT10StylishDataset
from llama.rellama import Method_1

model_path = "/home/tushilong/hf/models/Llama-2-7b-hf"
device = "cuda:1"
tokenizer_path = model_path


model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
model.eval()
|  | input_ids = torch.tensor([ | ||||||
|  |     [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2, | ||||||
|  |              2,     2,     2,     2,     2,     2,     2,     2, 13866,   338, | ||||||
|  |            385, 15278,   393, 16612,   263,  3414, 29892,  3300,  2859,   411, | ||||||
|  |            385,  1881,   393,  8128,  4340,  3030, 29889, 14350,   263,  2933, | ||||||
|  |            393,  7128,  2486,  1614,  2167,   278,  2009, 29889,    13,    13, | ||||||
|  |           2277, 29937,  2799,  4080, 29901,    13, 29954,  5428,   263,  8424, | ||||||
|  |            310,  1426, 29892,  3113,  1284,   714,   278,  2022, 29899, 29876, | ||||||
|  |           1288,   537,  8220,   297,   372, 29889, 24948,   592,  1058,   338, | ||||||
|  |            278,  2022,   322,   607,   338,   278,  4797,   537, 29889,   450, | ||||||
|  |           1234,   881,   367,   297,  4390,  3402, 29889,    13,    13,  2277, | ||||||
|  |          29937, 10567, 29901,    13,  1576, 21489,  8063,  1919,   607,   338, | ||||||
|  |           5331,   491,   278,   390, 13873,   525, 15864,   290,   381,   435, | ||||||
|  |          10312,  1919,  6502,  7357,  1335,   322,  6502,   390,  1682,   262, | ||||||
|  |           7912,  1919,   338,  3806,   304,  5957,   263,   883,   333,   519, | ||||||
|  |          18766,  1919,   408,   526, 12710,  1919, 24506,   322, 22250,   557, | ||||||
|  |            423,   869,  6629,    13,    13,  2277, 29937, 13291, 29901,     1, | ||||||
|  |           7521,  3126,    13,   426,    13,   259,   376, 10532,  1115,   376, | ||||||
|  |          15864,   290,   381,   435, 10312,  9162,    13,   259,   376, 29876, | ||||||
|  |           1288,   537,  1115,   376, 21489,  8063,   376,    13,   500,    13, | ||||||
|  |           7521,     2], | ||||||
|  |         [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2, | ||||||
|  |              2,     2,     2,     2,     2,     2,     2, 13866,   338,   385, | ||||||
|  |          15278,   393, 16612,   263,  3414, 29892,  3300,  2859,   411,   385, | ||||||
|  |           1881,   393,  8128,  4340,  3030, 29889, 14350,   263,  2933,   393, | ||||||
|  |           7128,  2486,  1614,  2167,   278,  2009, 29889,    13,    13,  2277, | ||||||
|  |          29937,  2799,  4080, 29901,    13, 29954,  5428,   263,  8424,   310, | ||||||
|  |           1426, 29892,  3113,  1284,   714,   278,  2022, 29899, 29876,  1288, | ||||||
|  |            537,  8220,   297,   372, 29889, 24948,   592,  1058,   338,   278, | ||||||
|  |           2022,   322,   607,   338,   278,  4797,   537, 29889,   450,  1234, | ||||||
|  |            881,   367,   297,  4390,  3402, 29889,    13,    13,  2277, 29937, | ||||||
|  |          10567, 29901,    13, 29928,  3496,   637,   674,  1708,  1913, 29948, | ||||||
|  |           3197,  3219,  1973,  4346,   310,  3444,  6454, 22396,   297,   278, | ||||||
|  |          27632, 19016,  1919,  1156,  1183, 13916,   287,   317,  5990, 29880, | ||||||
|  |           1648,   476,  3365,  1212,   578,  1564,   310, 12710, 22600,  1919, | ||||||
|  |          29871, 29955, 29899, 29953,   313, 29896, 29897,  1919, 29871, 29953, | ||||||
|  |          29899, 29941,   869,    13,    13,  2277, 29937, 13291, 29901,     1, | ||||||
|  |           7521,  3126,    13,   426,    13,   259,   376, 10532,  1115,   376, | ||||||
|  |           1913, 29948,  3197,  3219,  1973,  4346,  9162,    13,   259,   376, | ||||||
|  |          29876,  1288,   537,  1115,   376,  3444,   376,    13,   500,    13, | ||||||
|  |           7521,     2], | ||||||
|  |         [    2, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29892, | ||||||
|  |           3300,  2859,   411,   385,  1881,   393,  8128,  4340,  3030, 29889, | ||||||
|  |          14350,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009, | ||||||
|  |          29889,    13,    13,  2277, 29937,  2799,  4080, 29901,    13, 29954, | ||||||
|  |           5428,   263,  8424,   310,  1426, 29892,  3113,  1284,   714,   278, | ||||||
|  |           2022, 29899, 29876,  1288,   537,  8220,   297,   372, 29889, 24948, | ||||||
|  |            592,  1058,   338,   278,  2022,   322,   607,   338,   278,  4797, | ||||||
|  |            537, 29889,   450,  1234,   881,   367,   297,  4390,  3402, 29889, | ||||||
|  |             13,    13,  2277, 29937, 10567, 29901,    13,  4806,   525,   276, | ||||||
|  |            451,  3330,  6392,   322,   591,   437,   302, 29915, 29873,   679, | ||||||
|  |            885,  1965,  1919,  6629,   624,  5876,  1358,  1919,  5069,  4783, | ||||||
|  |           1919, 17374,   624,  5876,  1358,  1919,  2113,  2211, 19025,  1612, | ||||||
|  |           1338,   363, 17362,   297,   278, 29871, 29896, 29929, 29953, 29900, | ||||||
|  |            525, 29879,  1919,  1497,   297,  1234,   304,   263,  1139,  1048, | ||||||
|  |           2020, 23035, 10331,   304,   437,  2253,   297,   278, 16373,  1135, | ||||||
|  |            297,   278,  2787,  6536,   869,    13,    13,  2277, 29937, 13291, | ||||||
|  |          29901,     1,  7521,  3126,    13,   426,    13,   259,   376, 10532, | ||||||
|  |           1115,   376, 17374,   624,  5876,  1358,  9162,    13,   259,   376, | ||||||
|  |          29876,  1288,   537,  1115,   376, 17362,   376,    13,   500,    13, | ||||||
|  |           7521,     2], | ||||||
|  |         [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2, | ||||||
|  |              2,     2,     2,     2,     2,     2, 13866,   338,   385, 15278, | ||||||
|  |            393, 16612,   263,  3414, 29892,  3300,  2859,   411,   385,  1881, | ||||||
|  |            393,  8128,  4340,  3030, 29889, 14350,   263,  2933,   393,  7128, | ||||||
|  |           2486,  1614,  2167,   278,  2009, 29889,    13,    13,  2277, 29937, | ||||||
|  |           2799,  4080, 29901,    13, 29954,  5428,   263,  8424,   310,  1426, | ||||||
|  |          29892,  3113,  1284,   714,   278,  2022, 29899, 29876,  1288,   537, | ||||||
|  |           8220,   297,   372, 29889, 24948,   592,  1058,   338,   278,  2022, | ||||||
|  |            322,   607,   338,   278,  4797,   537, 29889,   450,  1234,   881, | ||||||
|  |            367,   297,  4390,  3402, 29889,    13,    13,  2277, 29937, 10567, | ||||||
|  |          29901,    13,  4013,  1629,  1919,   763,   738,   916,  1919,   278, | ||||||
|  |           1510,   471,  1361,   292,   714,  1612,  1338,   304,  1906, 15783, | ||||||
|  |           6629,   278,  1407,  1900,   297,  3082,  9257,  1919,  6629,   408, | ||||||
|  |          29455,  2164,   491,  4207,   487,   267,   763,  8314,   525, 29879, | ||||||
|  |            360,   420, 19317,   317, 14107,  1049,   322, 14933,   525, 29879, | ||||||
|  |           6290,  1260,   880,  2259,   869,    13,    13,  2277, 29937, 13291, | ||||||
|  |          29901,     1,  7521,  3126,    13,   426,    13,   259,   376, 10532, | ||||||
|  |           1115,   376, 19317,   317, 14107,  1049,  9162,    13,   259,   376, | ||||||
|  |          29876,  1288,   537,  1115,   376,  8314,   376,    13,   500,    13, | ||||||
|  |           7521,     2] | ||||||
|  |     ], device="cuda:1" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | labels = torch.tensor([ | ||||||
|  |     [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2, | ||||||
|  |              2,     2,     2,     2,     2,     2,     2,     2, -100,   338, | ||||||
|  |            385, 15278,   393, 16612,   263,  3414, 29892,  3300,  2859,   411, | ||||||
|  |            385,  1881,   393,  8128,  4340,  3030, 29889, 14350,   263,  2933, | ||||||
|  |            393,  7128,  2486,  1614,  2167,   278,  2009, 29889,    13,    13, | ||||||
|  |           2277, 29937,  2799,  4080, 29901,    13, 29954,  5428,   263,  8424, | ||||||
|  |            310,  1426, 29892,  3113,  1284,   714,   278,  2022, 29899, 29876, | ||||||
|  |           1288,   537,  8220,   297,   372, 29889, 24948,   592,  1058,   338, | ||||||
|  |            278,  2022,   322,   607,   338,   -100,  4797,   537, 29889,   450, | ||||||
|  |           1234,   881,   367,   297,  4390,  3402, 29889,    13,    13,  2277, | ||||||
|  |          29937, 10567, 29901,    13,  1576, -100,  8063,  1919,   607,   338, | ||||||
|  |           5331,   491,   278,   390, 13873,   525, 15864,   290,   381,   435, | ||||||
|  |          10312,  1919,  6502,  7357,  1335,   322,  6502,   390,  1682,   262, | ||||||
|  |           7912,  1919,   338,  -100,   304,  5957,   263,   883,   333,   519, | ||||||
|  |          18766,  1919,   408,   526, 12710,  1919, 24506,   322, 22250,   557, | ||||||
|  |            423,   869,  6629,    13,    13,  2277, 29937, 13291, 29901,     1, | ||||||
|  |           7521,  3126,    13,   426,    13,   259,   376, 10532,  1115,   376, | ||||||
|  |          15864,   290,   381,   435, 10312,  -100,    13,   259,   376, 29876, | ||||||
|  |           1288,   -100,  1115,  -100, 21489,  8063,   376,    13,   500,    13, | ||||||
|  |           7521,     2], | ||||||
|  |         [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2, | ||||||
|  |              2,     2,     2,     2,     2,     2,     2, -100,   338,   385, | ||||||
|  |          15278,   393, 16612,   263,  3414, 29892,  3300,  2859,   411,   385, | ||||||
|  |           1881,   393,  8128,  4340,  3030, 29889, 14350,   263,  2933,   393, | ||||||
|  |           7128,  2486,  1614,  2167,   278,  2009, 29889,    13,    13,  2277, | ||||||
|  |          29937,  2799,  4080, 29901,    13, 29954,  5428,   263,  8424,   310, | ||||||
|  |           1426, 29892,  3113,  1284,   714,   -100,  2022, 29899, 29876,  1288, | ||||||
|  |            537,  8220,   297,   372, 29889, 24948,   592,  1058,   338,   278, | ||||||
|  |           2022,   322,   607,   338,   278,  4797,   537, 29889,   450,  1234, | ||||||
|  |            881,   367,   -100,  4390,  3402, 29889,    13,    13,  2277, 29937, | ||||||
|  |          10567, 29901,    13, 29928,  3496,   637,   674,  1708,  1913, 29948, | ||||||
|  |           3197,  3219,  1973,  4346,   310,  3444,  6454, 22396,   297,   278, | ||||||
|  |          27632, 19016,  1919,  1156,  1183, 13916,   287,   317,  5990, 29880, | ||||||
|  |           1648,   476,  3365,  1212,   578,  1564,   310, 12710, 22600,  1919, | ||||||
|  |          29871, 29955, 29899, 29953,   313, 29896, 29897,  1919, 29871, 29953, | ||||||
|  |          29899, 29941,   869,    13,    -100,  2277, 29937, 13291, 29901,     1, | ||||||
|  |           7521,  -100,    13,   426,    13,   259,   376, 10532,  1115,   376, | ||||||
|  |           1913, 29948,  3197,  3219,  1973,  -100,  9162,    13,   259,   376, | ||||||
|  |          29876,  1288,   537,  1115,   -100,  3444,   376,    13,   500,    13, | ||||||
|  |           7521,     2], | ||||||
|  |         [    2, -100,   338,   385, 15278,   393, 16612,   263,  3414, 29892, | ||||||
|  |           3300,  2859,   411,   385,  1881,   393,  8128,  4340,  3030, 29889, | ||||||
|  |          14350,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009, | ||||||
|  |          29889,    13,    13,  2277, 29937,  2799,  4080, 29901,    13, 29954, | ||||||
|  |           5428,   263,  8424,   310,  1426, 29892,  3113,  1284,   714,   278, | ||||||
|  |           2022, 29899, 29876,  1288,   537,  8220,   297,   372, 29889, 24948, | ||||||
|  |            592,  1058,   338,   278,  2022,   322,   607,   338,   278,  4797, | ||||||
|  |            537, 29889,   450,  -100,   881,   367,   297,  4390,  3402, 29889, | ||||||
|  |             13,    13,  2277, 29937, 10567, 29901,    13,  4806,   525,   276, | ||||||
|  |            451,  3330,  6392,   322,   591,   437,   302, 29915, 29873,   679, | ||||||
|  |            885,  1965,  1919,  6629,   624,  5876,  1358,  1919,  5069,  4783, | ||||||
|  |           1919, 17374,   624,  5876,  1358,  1919,  2113,  2211, 19025,  1612, | ||||||
|  |           1338,   363, 17362,   297,   278, 29871, 29896, 29929, 29953, 29900, | ||||||
|  |            525, 29879,  1919,  1497,   297,  1234,   304,   263,  1139,  1048, | ||||||
|  |           2020, 23035, 10331,   304,   437,  2253,   297,   278, 16373,  1135, | ||||||
|  |            297,   278,  2787,  6536,   869,    13,    13,  2277, 29937, 13291, | ||||||
|  |          29901,     1,  7521,  3126,    -100,   426,    13,   259,   376, 10532, | ||||||
|  |           1115,   376, 17374,   624,  5876,  -100,  9162,    13,   259,   376, | ||||||
|  |          29876,  1288,   537,  1115,   376, -100,   376,    13,   500,    13, | ||||||
|  |           7521,     2], | ||||||
|  |         [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2, | ||||||
|  |              2,     2,     2,     2,     2,     2, -100,   338,   385, 15278, | ||||||
|  |            393, 16612,   263,  3414, 29892,  3300,  2859,   411,   385,  1881, | ||||||
|  |            393,  8128,  4340,  3030, 29889, 14350,   263,  2933,   393,  7128, | ||||||
|  |           2486,  1614,  2167,   278,  2009, 29889,    13,    13,  2277, 29937, | ||||||
|  |           2799,  4080, 29901,    13, 29954,  5428,   263,  8424,   310,  1426, | ||||||
|  |          29892,  3113,  1284,   714,   278,  2022, 29899, 29876,  1288,   537, | ||||||
|  |           8220,   297,   372, 29889, 24948,   592,  1058,   338,   278,  2022, | ||||||
|  |            322,   607,   338,   278,  4797,   -100, 29889,   450,  1234,   881, | ||||||
|  |            367,   297,  4390,  3402, 29889,    13,    13,  2277, 29937, 10567, | ||||||
|  |          29901,    13,  4013,  1629,  1919,   763,   738,   916,  1919,   278, | ||||||
|  |           1510,   471,  1361,   292,   714,  1612,  1338,   304,  1906, 15783, | ||||||
|  |           6629,   278,  1407,  1900,   297,  3082,  9257,  1919,  6629,   408, | ||||||
|  |          29455,  2164,   491,  4207,   487,   267,   763,  8314,   525, 29879, | ||||||
|  |            360,   420, 19317,   317, 14107,  1049,   322, 14933,   525, 29879, | ||||||
|  |           6290,  1260,   880,  2259,   869,    13,    13,  2277, 29937, 13291, | ||||||
|  |          29901,     1,  7521,  3126,    13,  -100,    13,   259,   376, 10532, | ||||||
|  |           1115,   376, 19317,   317, 14107,  -100,  9162,    13,   259,   376, | ||||||
|  |          29876,  1288,   537,  1115,   376,  8314,   376,    13,   500,    13, | ||||||
|  |           7521,     2] | ||||||
|  |     ], device="cuda:1" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | attention_mask = torch.tensor( | ||||||
|  | [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], | ||||||
|  |         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], | ||||||
|  |         [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], | ||||||
|  |         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|  |          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device="cuda:1" | ||||||
|  | ) | ||||||
|  |  | ||||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask,)# labels=labels)
assert torch.isnan(outputs.logits).sum() == 0
# print(outputs.loss)
							
								
								
									
test_prediction.py (24 lines, Normal file)
@@ -0,0 +1,24 @@
import json
from realign.eval_acc import LM


device = 'cuda:1'
model_path = './ckpts/stylish'
data_path = './data/nyt10/nyt10_test.txt'

prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
)


data = [json.loads(line) for line in open(data_path, 'r').readlines()]
data = [dp for dp in data if '/people/person/nationality' in dp['relation']]

lm = LM(device, model_path)

for i in range(len(data[:10])):
    prompt = prompt_template.format(text=data[i]['text'])
    response = lm.chat(prompt)
    print(response)
							
								
								
									
train/configs/__init__.py (0 lines, Normal file)

train/configs/finetune_arguments.py (11 lines, Normal file)
@@ -0,0 +1,11 @@
|  | from dataclasses import dataclass, field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### 定义一些配置信息 | ||||||
|  | @dataclass | ||||||
|  | class FinetuneArguments: | ||||||
|  |     model_name: str = field() | ||||||
|  |     data_path: str = field() | ||||||
|  |     train_size: int = field(default=-1) | ||||||
|  |     test_size: int = field(default=100) | ||||||
|  |     max_len: int = field(default=1024) | ||||||
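A minimal sketch of how this dataclass could be consumed with transformers' standard HfArgumentParser; the training script in this commit drives its arguments through fire instead, so this is only illustrative.

from transformers import HfArgumentParser

from configs.finetune_arguments import FinetuneArguments  # import path assumed, relative to train/

parser = HfArgumentParser(FinetuneArguments)
(finetune_args,) = parser.parse_args_into_dataclasses()
print(finetune_args.model_name, finetune_args.data_path, finetune_args.max_len)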
							
								
								
									
4  train/configs/fsdp/internlm_fsdp_config.json  Normal file
							| @ -0,0 +1,4 @@ | |||||||
|  | { | ||||||
|  |     "fsdp_transformer_layer_cls_to_wrap": ["InternLMDecoderLayer"], | ||||||
|  |     "limit_all_gathers": true | ||||||
|  | } | ||||||
							
								
								
									
4  train/configs/fsdp/llama2_fsdp_config.json  Normal file
							| @ -0,0 +1,4 @@ | |||||||
|  | { | ||||||
|  |     "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"], | ||||||
|  |     "limit_all_gathers": true | ||||||
|  | } | ||||||
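The `fsdp_transformer_layer_cls_to_wrap` entry has to name the model's decoder-layer class exactly, otherwise FSDP auto-wrap has nothing to match. A quick, illustrative check of the class name for a given checkpoint (the path below is the base_model default used elsewhere in this commit):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/home/tushilong/hf/models/Llama-2-7b-hf")
layer_classes = {type(m).__name__ for m in model.modules() if "DecoderLayer" in type(m).__name__}
print(layer_classes)  # expected to contain "LlamaDecoderLayer" for Llama-2 checkpoints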
							
								
								
									
4  train/configs/fsdp/qwen_fsdp_config.json  Normal file
							| @ -0,0 +1,4 @@ | |||||||
|  | { | ||||||
|  |     "fsdp_transformer_layer_cls_to_wrap": ["QWenBlock"], | ||||||
|  |     "limit_all_gathers": true | ||||||
|  | } | ||||||
							
								
								
									
26  train/configs/logger_config.py  Normal file
							| @ -0,0 +1,26 @@ | |||||||
|  | logger_config = { | ||||||
|  |     'version': 1, | ||||||
|  |     'formatters': { | ||||||
|  |         'simple': { | ||||||
|  |             'format': "%(asctime)s %(name)s %(levelname)s: %(message)s", | ||||||
|  |             'datefmt': '%Y-%m-%d %H:%M:%S', | ||||||
|  |         }, | ||||||
|  |         # other formatters | ||||||
|  |     }, | ||||||
|  |     'handlers': { | ||||||
|  |         'console': { | ||||||
|  |             'class': 'logging.StreamHandler', | ||||||
|  |             'level': 'DEBUG', | ||||||
|  |             'formatter': 'simple', | ||||||
|  |         }, | ||||||
|  |         # other handlers | ||||||
|  |     }, | ||||||
|  |     'loggers':{ | ||||||
|  |         # console output only, via StreamLogger | ||||||
|  |         'StreamLogger': { | ||||||
|  |             'handlers': ['console'], | ||||||
|  |             'level': 'DEBUG', | ||||||
|  |         }, | ||||||
|  |         # other loggers | ||||||
|  |     } | ||||||
|  | } | ||||||
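This dict follows the standard logging dictConfig schema; a minimal sketch of applying it (import path assumed, relative to the train/ directory):

import logging
import logging.config

from configs.logger_config import logger_config  # assumed import path

logging.config.dictConfig(logger_config)
logger = logging.getLogger("StreamLogger")
logger.debug("logger configured")  # goes to the console handler defined above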
							
								
								
									
136  train/run_log.txt  Normal file
							| @ -0,0 +1,136 @@ | |||||||
|  | WARNING:torch.distributed.run: | ||||||
|  | ***************************************** | ||||||
|  | Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.  | ||||||
|  | ***************************************** | ||||||
|  | Training model with params: | ||||||
|  | base_model: /home/tushilong/hf/models/Llama-2-7b-hf | ||||||
|  | output_dir: ../ckpts/stylish | ||||||
|  | micro_batch_size: 2 | ||||||
|  | gradient_accumulation_steps: 1 | ||||||
|  | train_batch_size: 2 | ||||||
|  | gradient_checkpointing: True | ||||||
|  | num_epochs: 1 | ||||||
|  | learning_rate: 2e-05 | ||||||
|  | weight_decay: 0.0001 | ||||||
|  | warmup_ratio: 0.06 | ||||||
|  | deepspeed_config: None | ||||||
|  | fsdp: shard_grad_op auto_wrap offload | ||||||
|  | fsdp_config: ./configs/fsdp/llama2_fsdp_config.json | ||||||
|  | smart_embedding: False | ||||||
|  | wandb_project:  | ||||||
|  | wandb_run_name:  | ||||||
|  | wandb_watch:  | ||||||
|  | wandb_log_model:  | ||||||
|  | resume_from_checkpoint: False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||||
|  | Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||||
|  | Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.52s/it] | ||||||
|  | Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.51s/it] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.40s/it] | ||||||
|  |  | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.39s/it] | ||||||
|  |  | ||||||
|  | Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||||
|  | Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.45s/it] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.17s/it] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.36s/it] | ||||||
|  | Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher. | ||||||
|  | StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False) | ||||||
|  | You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. | ||||||
|  | StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False) | ||||||
|  |  | ||||||
|  |   0%|          | 0/167 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. | ||||||
|  | StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False) | ||||||
|  | You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. | ||||||
|  |  | ||||||
|  | Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||||
|  |  | ||||||
|  | Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A | ||||||
|  | Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||||
|  | Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  4.90it/s] | ||||||
|  | Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  4.20it/s] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.97it/s] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.55it/s] | ||||||
|  |  | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.33it/s] | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.88it/s] | ||||||
|  | `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... | ||||||
|  | `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Loading checkpoint shards:  50%|█████     | 1/2 [00:14<00:14, 14.71s/it][A | ||||||
|  |  | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  8.65s/it][A | ||||||
|  | Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  9.55s/it] | ||||||
|  | `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... | ||||||
|  | Traceback (most recent call last): | ||||||
|  |   File "/home/tushilong/code/realign/train/train.py", line 167, in <module> | ||||||
|  |     fire.Fire(train) | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 141, in Fire | ||||||
|  |     component_trace = _Fire(component, args, parsed_flag_args, context, name) | ||||||
|  |                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire | ||||||
|  |     component, remaining_args = _CallAndUpdateTrace( | ||||||
|  |                                 ^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace | ||||||
|  |     component = fn(*varargs, **kwargs) | ||||||
|  |                 ^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/code/realign/train/train.py", line 160, in train | ||||||
|  |     trainer.train(resume_from_checkpoint=resume_from_checkpoint) | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train | ||||||
|  |     return inner_training_loop( | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop | ||||||
|  |     tr_loss_step = self.training_step(model, inputs) | ||||||
|  |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step | ||||||
|  |     loss = self.compute_loss(model, inputs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2707, in compute_loss | ||||||
|  |     outputs = model(**inputs) | ||||||
|  |               ^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl | ||||||
|  |     return forward_call(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward | ||||||
|  |     return model_forward(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__ | ||||||
|  |     return convert_to_fp32(self.model_forward(*args, **kwargs)) | ||||||
|  |                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast | ||||||
|  |     return func(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward | ||||||
|  |     return model_forward(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__ | ||||||
|  |     return convert_to_fp32(self.model_forward(*args, **kwargs)) | ||||||
|  |                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast | ||||||
|  |     return func(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward | ||||||
|  |     output = self._fsdp_wrapped_module(*args, **kwargs) | ||||||
|  |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl | ||||||
|  |     return forward_call(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/code/realign/llama/rellama.py", line 129, in forward | ||||||
|  |     assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}" | ||||||
|  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  | AssertionError: target_logits has nan: 10752000 | ||||||
|  | WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20205 closing signal SIGTERM | ||||||
|  | WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20206 closing signal SIGTERM | ||||||
|  | ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 20207) of binary: /home/tushilong/anaconda3/envs/realign/bin/python | ||||||
|  | Traceback (most recent call last): | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module> | ||||||
|  |     sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')()) | ||||||
|  |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper | ||||||
|  |     return f(*args, **kwargs) | ||||||
|  |            ^^^^^^^^^^^^^^^^^^ | ||||||
							
								
								
									
168  train/train.py  Normal file
							| @ -0,0 +1,168 @@ | |||||||
|  | import random | ||||||
|  | import sys | ||||||
|  | import os | ||||||
|  | os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6' | ||||||
|  |  | ||||||
|  | import fire | ||||||
|  | import torch | ||||||
|  | torch.autograd.set_detect_anomaly(True) | ||||||
|  | import transformers | ||||||
|  | from transformers import set_seed | ||||||
|  | set_seed(15) | ||||||
|  | from utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset | ||||||
|  | from trainers import FSDPTrainingArguments, FSDPTrainer | ||||||
|  | from transformers import AutoTokenizer, AutoConfig | ||||||
|  | from transformers import AutoModelForCausalLM, LlamaForCausalLM | ||||||
|  | from llama import Method_1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def train( | ||||||
|  |     # model/data params | ||||||
|  |     base_model: str = '/home/tushilong/hf/models/Llama-2-7b-hf', | ||||||
|  |     data_path: str = '../data/nyt10/nyt10_train.txt', | ||||||
|  |     output_dir: str = '../ckpts/stylish', | ||||||
|  |      | ||||||
|  |     # training hyperparams | ||||||
|  |     do_train: bool = True, | ||||||
|  |     micro_batch_size: int = 2, | ||||||
|  |     gradient_accumulation_steps: int = 1, | ||||||
|  |     gradient_checkpointing: bool = True, | ||||||
|  |     num_epochs: int = 1, | ||||||
|  |     save_steps: int = 500, | ||||||
|  |     learning_rate: float = 2e-5, | ||||||
|  |     lr_scheduler_type: str = 'cosine', | ||||||
|  |     weight_decay: float = 1e-4, | ||||||
|  |     warmup_ratio: float = 0.06, | ||||||
|  |     deepspeed_config: str = None, | ||||||
|  |     fsdp: str = 'shard_grad_op auto_wrap offload', | ||||||
|  |     fsdp_config: str = './configs/fsdp/llama2_fsdp_config.json', | ||||||
|  |     smart_embedding: bool = False, | ||||||
|  |  | ||||||
|  |     # evaluating hyperparams | ||||||
|  |     do_eval: bool = False, | ||||||
|  |     val_set_size: int = 1000, | ||||||
|  |     eval_batch_size: int = 4, | ||||||
|  |      | ||||||
|  |     # wandb params | ||||||
|  |     wandb_project: str = "", | ||||||
|  |     wandb_run_name: str = "", | ||||||
|  |     wandb_watch: str = "",  # options: false | gradients | all | ||||||
|  |     wandb_log_model: str = "",  # options: false | true | ||||||
|  |     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter | ||||||
|  | ): | ||||||
|  |     if int(os.environ.get("LOCAL_RANK", 0)) == 0: | ||||||
|  |         print( | ||||||
|  |             f"Training model with params:\n" | ||||||
|  |             f"base_model: {base_model}\n" | ||||||
|  |             f"output_dir: {output_dir}\n" | ||||||
|  |             f"micro_batch_size: {micro_batch_size}\n" | ||||||
|  |             f"gradient_accumulation_steps: {gradient_accumulation_steps}\n" | ||||||
|  |             f"train_batch_size: {micro_batch_size * gradient_accumulation_steps}\n" | ||||||
|  |             f"gradient_checkpointing: {gradient_checkpointing}\n" | ||||||
|  |             f"num_epochs: {num_epochs}\n" | ||||||
|  |             f"learning_rate: {learning_rate}\n" | ||||||
|  |             f"weight_decay: {weight_decay}\n" | ||||||
|  |             f"warmup_ratio: {warmup_ratio}\n" | ||||||
|  |             f"deepspeed_config: {deepspeed_config}\n" | ||||||
|  |             f"fsdp: {fsdp}\n" | ||||||
|  |             f"fsdp_config: {fsdp_config}\n" | ||||||
|  |             f"smart_embedding: {smart_embedding}\n" | ||||||
|  |             f"wandb_project: {wandb_project}\n" | ||||||
|  |             f"wandb_run_name: {wandb_run_name}\n" | ||||||
|  |             f"wandb_watch: {wandb_watch}\n" | ||||||
|  |             f"wandb_log_model: {wandb_log_model}\n" | ||||||
|  |             f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" | ||||||
|  |         ) | ||||||
|  |     assert ( | ||||||
|  |         not (deepspeed_config and fsdp) | ||||||
|  |     ), "Cannot specify both deepspeed_config and fsdp_config" | ||||||
|  |      | ||||||
|  |     # training arguments | ||||||
|  |     bf16 = True # torch.cuda.get_device_capability()[0] >= 8 | ||||||
|  |     fp16 = not bf16 | ||||||
|  |      | ||||||
|  |     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) | ||||||
|  |     tokenizer.pad_token_id = tokenizer.eos_token_id | ||||||
|  |  | ||||||
|  |     # model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True) | ||||||
|  |     # model_dev_id = int(os.environ.get("LOCAL_RANK", 0)) | ||||||
|  |  | ||||||
|  |     model = Method_1.from_pretrained(base_model) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     # Check if parameter passed or if set within environ | ||||||
|  |     # use_wandb = len(wandb_project) > 0 or ( | ||||||
|  |     #     "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 | ||||||
|  |     # ) | ||||||
|  |     use_wandb = False | ||||||
|  |     # Only overwrite environ if wandb param passed | ||||||
|  |     # if len(wandb_project) > 0: | ||||||
|  |     #     os.environ["WANDB_PROJECT"] = wandb_project | ||||||
|  |     # if len(wandb_watch) > 0: | ||||||
|  |     #     os.environ["WANDB_WATCH"] = wandb_watch | ||||||
|  |     # if len(wandb_log_model) > 0: | ||||||
|  |     #     os.environ["WANDB_LOG_MODEL"] = wandb_log_model | ||||||
|  |      | ||||||
|  |      | ||||||
|  |     # train_data = NYT10FullDataset(data_path, tokenizer) | ||||||
|  |     # train_data = NYT10StylishDataset(data_path, tokenizer, 30) | ||||||
|  |     train_data = NYT10StylishDataset(data_path, tokenizer, 1000) | ||||||
|  |     val_data = None | ||||||
|  |          | ||||||
|  |     training_args = FSDPTrainingArguments( | ||||||
|  |         use_ffd_sampler=True, | ||||||
|  |         output_dir=output_dir, | ||||||
|  |         no_cuda=not torch.cuda.is_available(), | ||||||
|  |         seed=15, | ||||||
|  |         data_seed=15, | ||||||
|  |         do_train=do_train, | ||||||
|  |         num_train_epochs=num_epochs, | ||||||
|  |         optim="adamw_torch", | ||||||
|  |         learning_rate=learning_rate, | ||||||
|  |         lr_scheduler_type=lr_scheduler_type, | ||||||
|  |         per_device_train_batch_size=micro_batch_size, | ||||||
|  |         gradient_accumulation_steps=gradient_accumulation_steps, | ||||||
|  |         warmup_ratio=warmup_ratio, | ||||||
|  |         weight_decay=weight_decay, | ||||||
|  |         half_precision_backend="auto", | ||||||
|  |         fp16=fp16, | ||||||
|  |         bf16=bf16, | ||||||
|  |         adam_beta1=0.9, | ||||||
|  |         adam_beta2=0.95, | ||||||
|  |         save_strategy="steps", | ||||||
|  |         save_steps=save_steps, | ||||||
|  |         save_total_limit=2, | ||||||
|  |         logging_steps=1, | ||||||
|  |         report_to= "none", # "wandb" if use_wandb else None, | ||||||
|  |         run_name=None, #wandb_run_name if use_wandb else None, | ||||||
|  |         deepspeed=deepspeed_config, | ||||||
|  |         fsdp=fsdp, | ||||||
|  |         fsdp_config=fsdp_config, | ||||||
|  |         gradient_checkpointing=gradient_checkpointing, | ||||||
|  |         do_eval=do_eval and val_set_size > 0, | ||||||
|  |         evaluation_strategy="steps" if do_eval and val_set_size > 0 else "no", | ||||||
|  |         eval_steps=save_steps, | ||||||
|  |         per_device_eval_batch_size=eval_batch_size, | ||||||
|  |         # group_by_length=True, | ||||||
|  |     )  | ||||||
|  |      | ||||||
|  |     trainer = FSDPTrainer( | ||||||
|  |         model=model, | ||||||
|  |         args=training_args, | ||||||
|  |         train_dataset=train_data, | ||||||
|  |         eval_dataset=val_data, | ||||||
|  |         data_collator=transformers.DataCollatorForSeq2Seq( | ||||||
|  |             tokenizer, pad_to_multiple_of=8, return_tensors='pt', padding=True, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     trainer.train(resume_from_checkpoint=resume_from_checkpoint) | ||||||
|  |  | ||||||
|  |     trainer.save_model(output_dir) | ||||||
|  |     tokenizer.save_pretrained(output_dir) | ||||||
|  |      | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     fire.Fire(train) | ||||||
|  |      | ||||||
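Because the entry point is wrapped with fire.Fire, train() can also be called directly from Python with the same keyword arguments as the CLI; a sketch is below. In the logged run it was launched under torchrun across several ranks, so calling it single-process like this is illustrative only.

from train import train  # assumes the train/ directory is on sys.path

train(
    base_model="/home/tushilong/hf/models/Llama-2-7b-hf",
    data_path="../data/nyt10/nyt10_train.txt",
    output_dir="../ckpts/stylish",
    micro_batch_size=2,
    gradient_accumulation_steps=1,
    num_epochs=1,
)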
							
								
								
									
2  train/trainers/__init__.py  Normal file
							| @ -0,0 +1,2 @@ | |||||||
|  | from .fsdp_training_args import FSDPTrainingArguments | ||||||
|  | from .fsdp_trainer import FSDPTrainer | ||||||
							
								
								
									
161  train/trainers/ffd_sampler.py  Normal file
							| @ -0,0 +1,161 @@ | |||||||
|  | from typing import Optional, List | ||||||
|  |  | ||||||
|  | import torch.distributed as dist | ||||||
|  | from torch.utils.data import Sampler | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import numba | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @numba.njit | ||||||
|  | def ffd(a: np.ndarray, c: int): | ||||||
|  |     # First-fit-decreasing bin packing | ||||||
|  |     # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing | ||||||
|  |  | ||||||
|  |     a = np.sort(a)[::-1] | ||||||
|  |     bins = [] | ||||||
|  |     for size in a: | ||||||
|  |         add_new = True | ||||||
|  |         for idx in range(len(bins)): | ||||||
|  |             if bins[idx] >= size: | ||||||
|  |                 bins[idx] -= size | ||||||
|  |                 add_new = False | ||||||
|  |                 break | ||||||
|  |  | ||||||
|  |         if add_new: | ||||||
|  |             bins.append(c - size) | ||||||
|  |  | ||||||
|  |     return len(bins) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @numba.njit | ||||||
|  | def ffd_with_result(a: np.ndarray, c: int, start_index: int): | ||||||
|  |     # First-fit-decreasing bin packing (with result return) | ||||||
|  |  | ||||||
|  |     indices = np.argsort(a)[::-1] | ||||||
|  |     a = a[indices] | ||||||
|  |  | ||||||
|  |     bins = [] | ||||||
|  |     bins_result = [] | ||||||
|  |     for a_id, size in enumerate(a): | ||||||
|  |         add_new = True | ||||||
|  |         for idx in range(len(bins)): | ||||||
|  |             if bins[idx] >= size: | ||||||
|  |                 bins[idx] -= size | ||||||
|  |                 bins_result[idx].append(indices[a_id] + start_index) | ||||||
|  |                 add_new = False | ||||||
|  |                 break | ||||||
|  |  | ||||||
|  |         if add_new: | ||||||
|  |             bins.append(c - size) | ||||||
|  |             bins_result.append([indices[a_id] + start_index]) | ||||||
|  |  | ||||||
|  |     return bins_result | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @numba.njit | ||||||
|  | def allocate(lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int): | ||||||
|  |     # Dynamic batch allocator, similar to Multifit | ||||||
|  |     # https://en.wikipedia.org/wiki/Multifit_algorithm | ||||||
|  |     # ~96.4% efficiency on OpenChat training set (2048 ctx len) | ||||||
|  |  | ||||||
|  |     s = 0 | ||||||
|  |     start_index = 0 | ||||||
|  |     result = [] | ||||||
|  |  | ||||||
|  |     while True: | ||||||
|  |         # binary search [l, r) | ||||||
|  |         l = 1 | ||||||
|  |         r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") | ||||||
|  |  | ||||||
|  |         while r - l > 1: | ||||||
|  |             m = (l + r) // 2 | ||||||
|  |             if ffd(lengths[start_index: start_index + m], c) <= n: | ||||||
|  |                 l = m | ||||||
|  |             else: | ||||||
|  |                 r = m | ||||||
|  |  | ||||||
|  |         # use length l | ||||||
|  |         batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index) | ||||||
|  |         if len(batch) < n: | ||||||
|  |             break | ||||||
|  |  | ||||||
|  |         start_index += l | ||||||
|  |         s = lengths_cumsum[start_index - 1] | ||||||
|  |  | ||||||
|  |         # add local rank | ||||||
|  |         result.append(batch[rank]) | ||||||
|  |  | ||||||
|  |     return result, s / max(1, len(result) * c * n)  # Avoid division by zero | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FFDDistributedBatchSampler(Sampler): | ||||||
|  |     """Unpadded length sampling using FFD (First-fit-decreasing bin packing). | ||||||
|  |        Approximates the optimal solution of the identical-machines scheduling problem (NP-hard) to within a factor of ~1.22x.""" | ||||||
|  |      | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         batch_max_length: int, | ||||||
|  |         lengths: List[int], | ||||||
|  |         num_replicas: Optional[int] = None, | ||||||
|  |         rank: Optional[int] = None, | ||||||
|  |         seed: int = 0, | ||||||
|  |     ): | ||||||
|  |         # Get rank | ||||||
|  |         if num_replicas is None: | ||||||
|  |             if not dist.is_available(): | ||||||
|  |                 raise RuntimeError("Requires distributed package to be available") | ||||||
|  |             num_replicas = dist.get_world_size() | ||||||
|  |         if rank is None: | ||||||
|  |             if not dist.is_available(): | ||||||
|  |                 raise RuntimeError("Requires distributed package to be available") | ||||||
|  |             rank = dist.get_rank() | ||||||
|  |  | ||||||
|  |         self.num_replicas = num_replicas | ||||||
|  |         self.rank = rank | ||||||
|  |         self.seed = seed | ||||||
|  |  | ||||||
|  |         self.batch_max_length = batch_max_length | ||||||
|  |         self.lengths = lengths | ||||||
|  |         assert isinstance(self.lengths, np.ndarray) | ||||||
|  |  | ||||||
|  |         self.epoch = 0 | ||||||
|  |  | ||||||
|  |         # statistics | ||||||
|  |         self.total_epochs = 0 | ||||||
|  |         self.total_efficiency = 0 | ||||||
|  |  | ||||||
|  |     def set_epoch(self, epoch: int): | ||||||
|  |         self.epoch = epoch | ||||||
|  |  | ||||||
|  |     def generate_batches(self, set_stats=False): | ||||||
|  |         indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths)) | ||||||
|  |  | ||||||
|  |         lengths = self.lengths[indices] | ||||||
|  |         lengths_cumsum = np.cumsum(lengths) | ||||||
|  |  | ||||||
|  |         batches, efficiency = allocate(lengths=lengths, | ||||||
|  |                            lengths_cumsum=lengths_cumsum, | ||||||
|  |                            rank=self.rank, | ||||||
|  |                            c=self.batch_max_length, | ||||||
|  |                            n=self.num_replicas) | ||||||
|  |          | ||||||
|  |         batches = [indices[batch] for batch in batches] | ||||||
|  |  | ||||||
|  |         # statistics | ||||||
|  |         if set_stats: | ||||||
|  |             self.total_epochs += 1 | ||||||
|  |             self.total_efficiency += efficiency | ||||||
|  |          | ||||||
|  |         return batches | ||||||
|  |      | ||||||
|  |     def __iter__(self): | ||||||
|  |         batches = self.generate_batches(set_stats=True) | ||||||
|  |         return iter(batches) | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         batches = self.generate_batches() | ||||||
|  |         return len(batches) | ||||||
|  |  | ||||||
|  |     def efficiency(self): | ||||||
|  |         return self.total_efficiency / self.total_epochs | ||||||
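A minimal usage sketch for the sampler above; `train_data`, `collator`, and `num_epochs` are hypothetical stand-ins, and lengths must be a numpy array (the constructor asserts this). num_replicas and rank can be passed explicitly when torch.distributed is not initialised.

import numpy as np
from torch.utils.data import DataLoader

lengths = np.array([len(example["input_ids"]) for example in train_data])  # train_data assumed
sampler = FFDDistributedBatchSampler(
    batch_max_length=4096,   # per-rank token budget for one packed batch
    lengths=lengths,
    num_replicas=2,          # world size; explicit here instead of dist.get_world_size()
    rank=0,
    seed=15,
)
loader = DataLoader(train_data, batch_sampler=sampler, collate_fn=collator)  # collator assumed
for epoch in range(num_epochs):  # num_epochs assumed
    sampler.set_epoch(epoch)     # reshuffles the packing each epoch
    for batch in loader:
        ...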
							
								
								
									
925  train/trainers/fsdp_trainer.py  Normal file
							| @ -0,0 +1,925 @@ | |||||||
|  | import sys | ||||||
|  | import os | ||||||
|  | from typing import Optional | ||||||
|  | import torch | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  |  | ||||||
|  | import transformers | ||||||
|  | from transformers.trainer import * | ||||||
|  |  | ||||||
|  | from .ffd_sampler import FFDDistributedBatchSampler | ||||||
|  | from .utils import ExtendedFSDPOption, enable_low_gpu_full_post_state_dict_hook | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FSDPTrainer(transformers.Trainer): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model: Union[PreTrainedModel, nn.Module] = None, | ||||||
|  |         args: TrainingArguments = None, | ||||||
|  |         data_collator: Optional[DataCollator] = None, | ||||||
|  |         train_dataset: Optional[Dataset] = None, | ||||||
|  |         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, | ||||||
|  |         tokenizer: Optional[PreTrainedTokenizerBase] = None, | ||||||
|  |         model_init: Optional[Callable[[], PreTrainedModel]] = None, | ||||||
|  |         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, | ||||||
|  |         callbacks: Optional[List[TrainerCallback]] = None, | ||||||
|  |         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), | ||||||
|  |         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, | ||||||
|  |     ): | ||||||
|  |         if args is None: | ||||||
|  |             output_dir = "tmp_trainer" | ||||||
|  |             logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") | ||||||
|  |             args = TrainingArguments(output_dir=output_dir) | ||||||
|  |         self.args = args | ||||||
|  |         # Seed must be set before instantiating the model when using model | ||||||
|  |         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) | ||||||
|  |         self.hp_name = None | ||||||
|  |         self.deepspeed = None | ||||||
|  |         self.is_in_train = False | ||||||
|  |  | ||||||
|  |         self.create_accelerator_and_postprocess() | ||||||
|  |  | ||||||
|  |         # memory metrics - must set up as early as possible | ||||||
|  |         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) | ||||||
|  |         self._memory_tracker.start() | ||||||
|  |  | ||||||
|  |         # set the correct log level depending on the node | ||||||
|  |         log_level = args.get_process_log_level() | ||||||
|  |         logging.set_verbosity(log_level) | ||||||
|  |  | ||||||
|  |         # force device and distributed setup init explicitly | ||||||
|  |         args._setup_devices | ||||||
|  |  | ||||||
|  |         if model is None: | ||||||
|  |             if model_init is not None: | ||||||
|  |                 self.model_init = model_init | ||||||
|  |                 model = self.call_model_init() | ||||||
|  |             else: | ||||||
|  |                 raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") | ||||||
|  |         else: | ||||||
|  |             if model_init is not None: | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" | ||||||
|  |                     " overwrite your model when calling the `train` method. This will become a fatal error in the next" | ||||||
|  |                     " release.", | ||||||
|  |                     FutureWarning, | ||||||
|  |                 ) | ||||||
|  |             self.model_init = model_init | ||||||
|  |  | ||||||
|  |         if model.__class__.__name__ in MODEL_MAPPING_NAMES: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " | ||||||
|  |                 "computes hidden states and does not accept any labels. You should choose a model with a head " | ||||||
|  |                 "suitable for your task like any of the `AutoModelForXxx` listed at " | ||||||
|  |                 "https://huggingface.co/docs/transformers/model_doc/auto." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: | ||||||
|  |             self.is_model_parallel = True | ||||||
|  |         else: | ||||||
|  |             self.is_model_parallel = False | ||||||
|  |  | ||||||
|  |         if getattr(model, "hf_device_map", None) is not None: | ||||||
|  |             devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] | ||||||
|  |             if len(devices) > 1: | ||||||
|  |                 self.is_model_parallel = True | ||||||
|  |             else: | ||||||
|  |                 self.is_model_parallel = self.args.device != torch.device(devices[0]) | ||||||
|  |  | ||||||
|  |             # warn users | ||||||
|  |             logger.info( | ||||||
|  |                 "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" | ||||||
|  |                 " to `True` to avoid any unexpected behavior such as device placement mismatching." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # At this stage the model is already loaded | ||||||
|  |         if getattr(model, "is_quantized", False): | ||||||
|  |             if getattr(model, "_is_quantized_training_enabled", False): | ||||||
|  |                 logger.info( | ||||||
|  |                     "The model is loaded in 8-bit precision. To train this model you need to add additional modules" | ||||||
|  |                     " inside the model such as adapters using `peft` library and freeze the model weights. Please" | ||||||
|  |                     " check " | ||||||
|  |                     " the examples in https://github.com/huggingface/peft for more details." | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit" | ||||||
|  |                     " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         # Setup Sharded DDP training | ||||||
|  |         self.sharded_ddp = None | ||||||
|  |         if len(args.sharded_ddp) > 0: | ||||||
|  |             if self.is_deepspeed_enabled: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." | ||||||
|  |                 ) | ||||||
|  |             if len(args.fsdp) > 0: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." | ||||||
|  |                 ) | ||||||
|  |             if args.parallel_mode != ParallelMode.DISTRIBUTED: | ||||||
|  |                 raise ValueError("Using sharded DDP only works in distributed training.") | ||||||
|  |             elif not is_fairscale_available(): | ||||||
|  |                 raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") | ||||||
|  |             elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: | ||||||
|  |                 raise ImportError( | ||||||
|  |                     "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " | ||||||
|  |                     f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." | ||||||
|  |                 ) | ||||||
|  |             elif ShardedDDPOption.SIMPLE in args.sharded_ddp: | ||||||
|  |                 self.sharded_ddp = ShardedDDPOption.SIMPLE | ||||||
|  |             elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: | ||||||
|  |                 self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 | ||||||
|  |             elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: | ||||||
|  |                 self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 | ||||||
|  |  | ||||||
|  |         self.fsdp = None | ||||||
|  |         if len(args.fsdp) > 0: | ||||||
|  |             if self.is_deepspeed_enabled: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." | ||||||
|  |                 ) | ||||||
|  |             if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: | ||||||
|  |                 raise ValueError("Using fsdp only works in distributed training.") | ||||||
|  |  | ||||||
|  |             # dep_version_check("torch>=1.12.0") | ||||||
|  |             # Would have to update setup.py with torch>=1.12.0 | ||||||
|  |             # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 | ||||||
|  |             # below is the current alternative. | ||||||
|  |             if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): | ||||||
|  |                 raise ValueError("FSDP requires PyTorch >= 1.12.0") | ||||||
|  |  | ||||||
|  |             from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy | ||||||
|  |  | ||||||
|  |             if ExtendedFSDPOption.FULL_SHARD in args.fsdp: | ||||||
|  |                 self.fsdp = ShardingStrategy.FULL_SHARD | ||||||
|  |             elif ExtendedFSDPOption.SHARD_GRAD_OP in args.fsdp: | ||||||
|  |                 self.fsdp = ShardingStrategy.SHARD_GRAD_OP | ||||||
|  |             elif ExtendedFSDPOption.NO_SHARD in args.fsdp: | ||||||
|  |                 self.fsdp = ShardingStrategy.NO_SHARD | ||||||
|  |             # extension starts here | ||||||
|  |             elif ExtendedFSDPOption.HYBRID_SHARD in args.fsdp: | ||||||
|  |                 self.fsdp = ShardingStrategy.HYBRID_SHARD | ||||||
|  |             elif ExtendedFSDPOption._HYBRID_SHARD_ZERO2 in args.fsdp: | ||||||
|  |                 self.fsdp = ShardingStrategy._HYBRID_SHARD_ZERO2 | ||||||
|  |             # extension ends here | ||||||
|  |  | ||||||
|  |             self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE | ||||||
|  |             # modification starts here | ||||||
|  |             if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post": | ||||||
|  |                 self.backward_prefetch = BackwardPrefetch.BACKWARD_POST | ||||||
|  |             # modification ends here | ||||||
|  |  | ||||||
|  |             self.forward_prefetch = False | ||||||
|  |             # modification starts here | ||||||
|  |             if self.args.fsdp_config.get("forward_prefetch", False): | ||||||
|  |             # modification ends here | ||||||
|  |                 self.forward_prefetch = True | ||||||
|  |                  | ||||||
|  |             self.limit_all_gathers = False | ||||||
|  |             if self.args.fsdp_config.get("limit_all_gathers", False): | ||||||
|  |                 self.limit_all_gathers = True | ||||||
|  |  | ||||||
|  |         # one place to sort out whether to place the model on device or not | ||||||
|  |         # postpone switching model to cuda when: | ||||||
|  |         # 1. MP - since we are trying to fit a much bigger than 1 gpu model | ||||||
|  |         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, | ||||||
|  |         #    and we only use deepspeed for training at the moment | ||||||
|  |         # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first | ||||||
|  |         # 4. Sharded DDP - same as MP | ||||||
|  |         # 5. FSDP - same as MP | ||||||
|  |         self.place_model_on_device = args.place_model_on_device | ||||||
|  |         if ( | ||||||
|  |             self.is_model_parallel | ||||||
|  |             or self.is_deepspeed_enabled | ||||||
|  |             or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) | ||||||
|  |             or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) | ||||||
|  |             or (self.fsdp is not None) | ||||||
|  |             or self.is_fsdp_enabled | ||||||
|  |         ): | ||||||
|  |             self.place_model_on_device = False | ||||||
|  |  | ||||||
|  |         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) | ||||||
|  |         self.data_collator = data_collator if data_collator is not None else default_collator | ||||||
|  |         self.train_dataset = train_dataset | ||||||
|  |         self.eval_dataset = eval_dataset | ||||||
|  |         self.tokenizer = tokenizer | ||||||
|  |  | ||||||
|  |         # Quantized models doesn't support `.to` operation. | ||||||
|  |         if self.place_model_on_device and not getattr(model, "is_quantized", False): | ||||||
|  |             self._move_model_to_device(model, args.device) | ||||||
|  |  | ||||||
|  |         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs | ||||||
|  |         if self.is_model_parallel: | ||||||
|  |             self.args._n_gpu = 1 | ||||||
|  |  | ||||||
|  |         # later use `self.model is self.model_wrapped` to check if it's wrapped or not | ||||||
|  |         self.model_wrapped = model | ||||||
|  |         self.model = model | ||||||
|  |  | ||||||
|  |         self.compute_metrics = compute_metrics | ||||||
|  |         self.preprocess_logits_for_metrics = preprocess_logits_for_metrics | ||||||
|  |         self.optimizer, self.lr_scheduler = optimizers | ||||||
|  |         if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 "Passing a `model_init` is incompatible with providing the `optimizers` argument. " | ||||||
|  |                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." | ||||||
|  |             ) | ||||||
|  |         if is_torch_tpu_available() and self.optimizer is not None: | ||||||
|  |             for param in self.model.parameters(): | ||||||
|  |                 model_device = param.device | ||||||
|  |                 break | ||||||
|  |             for param_group in self.optimizer.param_groups: | ||||||
|  |                 if len(param_group["params"]) > 0: | ||||||
|  |                     optimizer_device = param_group["params"][0].device | ||||||
|  |                     break | ||||||
|  |             if model_device != optimizer_device: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "The model and the optimizer parameters are not on the same device, which probably means you" | ||||||
|  |                     " created an optimizer around your model **before** putting on the device and passing it to the" | ||||||
|  |                     " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" | ||||||
|  |                     " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." | ||||||
|  |                 ) | ||||||
|  |         if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( | ||||||
|  |             self.optimizer is not None or self.lr_scheduler is not None | ||||||
|  |         ): | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." | ||||||
|  |                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." | ||||||
|  |             ) | ||||||
|  |         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) | ||||||
|  |         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks | ||||||
|  |         self.callback_handler = CallbackHandler( | ||||||
|  |             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler | ||||||
|  |         ) | ||||||
|  |         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) | ||||||
|  |  | ||||||
|  |         # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. | ||||||
|  |         self._loggers_initialized = False | ||||||
|  |  | ||||||
|  |         # Create clone of distant repo and output directory if needed | ||||||
|  |         if self.args.push_to_hub: | ||||||
|  |             self.init_git_repo(at_init=True) | ||||||
|  |             # In case of pull, we need to make sure every process has the latest. | ||||||
|  |             if is_torch_tpu_available(): | ||||||
|  |                 xm.rendezvous("init git repo") | ||||||
|  |             elif args.parallel_mode == ParallelMode.DISTRIBUTED: | ||||||
|  |                 dist.barrier() | ||||||
|  |  | ||||||
|  |         if self.args.should_save: | ||||||
|  |             os.makedirs(self.args.output_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |         if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): | ||||||
|  |             raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") | ||||||
|  |  | ||||||
|  |         if args.max_steps > 0: | ||||||
|  |             logger.info("max_steps is given, it will override any value given in num_train_epochs") | ||||||
|  |  | ||||||
|  |         if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "The train_dataset does not implement __len__, max_steps has to be specified. " | ||||||
|  |                 "The number of steps needs to be known in advance for the learning rate scheduler." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             train_dataset is not None | ||||||
|  |             and isinstance(train_dataset, torch.utils.data.IterableDataset) | ||||||
|  |             and args.group_by_length | ||||||
|  |         ): | ||||||
|  |             raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") | ||||||
|  |  | ||||||
|  |         self._signature_columns = None | ||||||
|  |  | ||||||
|  |         # Mixed precision setup | ||||||
|  |         self.use_apex = False | ||||||
|  |         self.use_cuda_amp = False | ||||||
|  |         self.use_cpu_amp = False | ||||||
|  |  | ||||||
|  |         # Mixed precision setup for SageMaker Model Parallel | ||||||
|  |         if is_sagemaker_mp_enabled(): | ||||||
|  |             # BF16 + model parallelism in SageMaker: currently not supported, raise an error | ||||||
|  |             if args.bf16: | ||||||
|  |                 raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") | ||||||
|  |  | ||||||
|  |             if IS_SAGEMAKER_MP_POST_1_10: | ||||||
|  |                 # When there's mismatch between SMP config and trainer argument, use SMP config as truth | ||||||
|  |                 if args.fp16 != smp.state.cfg.fp16: | ||||||
|  |                     logger.warning( | ||||||
|  |                         f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," | ||||||
|  |                         f"but FP16 provided in trainer argument is {args.fp16}," | ||||||
|  |                         f"setting to {smp.state.cfg.fp16}" | ||||||
|  |                     ) | ||||||
|  |                     args.fp16 = smp.state.cfg.fp16 | ||||||
|  |             else: | ||||||
|  |                 # smp < 1.10 does not support fp16 in trainer. | ||||||
|  |                 if hasattr(smp.state.cfg, "fp16"): | ||||||
|  |                     logger.warning( | ||||||
|  |                         f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " | ||||||
|  |                         "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |         if (args.fp16 or args.bf16) and self.sharded_ddp is not None: | ||||||
|  |             if args.half_precision_backend == "auto": | ||||||
|  |                 if args.device == torch.device("cpu"): | ||||||
|  |                     if args.fp16: | ||||||
|  |                         raise ValueError("Tried to use `fp16` but it is not supported on cpu") | ||||||
|  |                     else: | ||||||
|  |                         args.half_precision_backend = "cpu_amp" | ||||||
|  |                 else: | ||||||
|  |                     args.half_precision_backend = "cuda_amp" | ||||||
|  |  | ||||||
|  |             logger.info(f"Using {args.half_precision_backend} half precision backend") | ||||||
|  |  | ||||||
|  |         self.do_grad_scaling = False | ||||||
|  |         if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): | ||||||
|  |             # deepspeed and SageMaker Model Parallel manage their own half precision | ||||||
|  |             if self.sharded_ddp is not None: | ||||||
|  |                 if args.half_precision_backend == "cuda_amp": | ||||||
|  |                     self.use_cuda_amp = True | ||||||
|  |                     self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 | ||||||
|  |                     #  bf16 does not need grad scaling | ||||||
|  |                     self.do_grad_scaling = self.amp_dtype == torch.float16 | ||||||
|  |                     if self.do_grad_scaling: | ||||||
|  |                         if self.sharded_ddp is not None: | ||||||
|  |                             self.scaler = ShardedGradScaler() | ||||||
|  |                         elif self.fsdp is not None: | ||||||
|  |                             from torch.distributed.fsdp.sharded_grad_scaler import ( | ||||||
|  |                                 ShardedGradScaler as FSDPShardedGradScaler, | ||||||
|  |                             ) | ||||||
|  |  | ||||||
|  |                             self.scaler = FSDPShardedGradScaler() | ||||||
|  |                         elif is_torch_tpu_available(): | ||||||
|  |                             from torch_xla.amp import GradScaler | ||||||
|  |  | ||||||
|  |                             self.scaler = GradScaler() | ||||||
|  |                         else: | ||||||
|  |                             self.scaler = torch.cuda.amp.GradScaler() | ||||||
|  |                 elif args.half_precision_backend == "cpu_amp": | ||||||
|  |                     self.use_cpu_amp = True | ||||||
|  |                     self.amp_dtype = torch.bfloat16 | ||||||
|  |             elif args.half_precision_backend == "apex": | ||||||
|  |                 if not is_apex_available(): | ||||||
|  |                     raise ImportError( | ||||||
|  |                         "Using FP16 with APEX but APEX is not installed, please refer to" | ||||||
|  |                         " https://www.github.com/nvidia/apex." | ||||||
|  |                     ) | ||||||
|  |                 self.use_apex = True | ||||||
|  |  | ||||||
|  |         # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. | ||||||
|  |         if ( | ||||||
|  |             is_sagemaker_mp_enabled() | ||||||
|  |             and self.use_cuda_amp | ||||||
|  |             and args.max_grad_norm is not None | ||||||
|  |             and args.max_grad_norm > 0 | ||||||
|  |         ): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " | ||||||
|  |                 "along 'max_grad_norm': 0 in your hyperparameters." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # Label smoothing | ||||||
|  |         if self.args.label_smoothing_factor != 0: | ||||||
|  |             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) | ||||||
|  |         else: | ||||||
|  |             self.label_smoother = None | ||||||
|  |  | ||||||
|  |         self.state = TrainerState( | ||||||
|  |             is_local_process_zero=self.is_local_process_zero(), | ||||||
|  |             is_world_process_zero=self.is_world_process_zero(), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.control = TrainerControl() | ||||||
|  |         # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then | ||||||
|  |         # returned to 0 every time flos need to be logged | ||||||
|  |         self.current_flos = 0 | ||||||
|  |         self.hp_search_backend = None | ||||||
|  |         self.use_tune_checkpoints = False | ||||||
|  |         default_label_names = find_labels(self.model.__class__) | ||||||
|  |         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names | ||||||
|  |         self.can_return_loss = can_return_loss(self.model.__class__) | ||||||
|  |         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) | ||||||
|  |  | ||||||
|  |         # Internal variables to help with automatic batch size reduction | ||||||
|  |         self._train_batch_size = args.train_batch_size | ||||||
|  |         self._created_lr_scheduler = False | ||||||
|  |  | ||||||
|  |         # very last | ||||||
|  |         self._memory_tracker.stop_and_update_metrics() | ||||||
|  |  | ||||||
|  |         # torch.compile | ||||||
|  |         if args.torch_compile and not is_torch_compile_available(): | ||||||
|  |             raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") | ||||||
|  |          | ||||||
|  |         # finally, apply `low_gpu_full_post_state_dict_hook` for the fsdp `state_dict` | ||||||
|  |         enable_low_gpu_full_post_state_dict_hook() | ||||||
|  |          | ||||||
|  |     def _wrap_model(self, model, training=True, dataloader=None): | ||||||
|  |         if self.args.use_ipex: | ||||||
|  |             dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 | ||||||
|  |             model = self.ipex_optimize_model(model, training, dtype=dtype) | ||||||
|  |  | ||||||
|  |         if is_sagemaker_mp_enabled(): | ||||||
|  |             # Wrapping the base model twice in a DistributedModel will raise an error. | ||||||
|  |             if isinstance(self.model_wrapped, smp.model.DistributedModel): | ||||||
|  |                 return self.model_wrapped | ||||||
|  |             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) | ||||||
|  |  | ||||||
|  |         # train/eval could be run multiple times; if the model is already wrapped, don't re-wrap it | ||||||
|  |         if unwrap_model(model) is not model: | ||||||
|  |             return model | ||||||
|  |  | ||||||
|  |         # Mixed precision training with apex (torch < 1.6) | ||||||
|  |         if self.use_apex and training: | ||||||
|  |             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) | ||||||
|  |  | ||||||
|  |         # Multi-gpu training (should be after apex fp16 initialization); 8-bit models do not support DDP | ||||||
|  |         if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): | ||||||
|  |             model = nn.DataParallel(model) | ||||||
|  |  | ||||||
|  |         if self.args.jit_mode_eval: | ||||||
|  |             start_time = time.time() | ||||||
|  |             model = self.torch_jit_model_eval(model, dataloader, training) | ||||||
|  |             self.jit_compilation_time = round(time.time() - start_time, 4) | ||||||
|  |  | ||||||
|  |         # Note: in torch.distributed mode, there's no point in wrapping the model | ||||||
|  |         # inside a DistributedDataParallel as we'll be under `no_grad` anyways. | ||||||
|  |         if not training: | ||||||
|  |             return model | ||||||
|  |  | ||||||
|  |         # Distributed training (should be after apex fp16 initialization) | ||||||
|  |         if self.sharded_ddp is not None: | ||||||
|  |             # Sharded DDP! | ||||||
|  |             if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||||||
|  |                 model = ShardedDDP(model, self.optimizer) | ||||||
|  |             else: | ||||||
|  |                 mixed_precision = self.args.fp16 or self.args.bf16 | ||||||
|  |                 cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp | ||||||
|  |                 zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 | ||||||
|  |                 # XXX: Breaking the self.model convention but I see no way around it for now. | ||||||
|  |                 if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: | ||||||
|  |                     model = auto_wrap(model) | ||||||
|  |                 self.model = model = FullyShardedDDP( | ||||||
|  |                     model, | ||||||
|  |                     mixed_precision=mixed_precision, | ||||||
|  |                     reshard_after_forward=zero_3, | ||||||
|  |                     cpu_offload=cpu_offload, | ||||||
|  |                 ).to(self.args.device) | ||||||
|  |         # Distributed training using PyTorch FSDP | ||||||
|  |         elif self.fsdp is not None: | ||||||
|  |             # fix starts here | ||||||
|  |             if not self.args.fsdp_config["xla"]: | ||||||
|  |                 # PyTorch FSDP! | ||||||
|  |                 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision | ||||||
|  |                 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP | ||||||
|  |                 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy | ||||||
|  |                 import torch.distributed.fsdp._traversal_utils as traversal_utils | ||||||
|  |  | ||||||
|  |                 if FSDPOption.OFFLOAD in self.args.fsdp: | ||||||
|  |                     cpu_offload = CPUOffload(offload_params=True) | ||||||
|  |                 else: | ||||||
|  |                     cpu_offload = CPUOffload(offload_params=False) | ||||||
|  |  | ||||||
|  |                 auto_wrap_policy = None | ||||||
|  |  | ||||||
|  |                 if FSDPOption.AUTO_WRAP in self.args.fsdp: | ||||||
|  |                     if self.args.fsdp_config["fsdp_min_num_params"] > 0: | ||||||
|  |                         auto_wrap_policy = functools.partial( | ||||||
|  |                             size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] | ||||||
|  |                         ) | ||||||
|  |                     elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||||
|  |                         transformer_cls_to_wrap = set() | ||||||
|  |                         for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: | ||||||
|  |                             transformer_cls = get_module_class_from_name(model, layer_class) | ||||||
|  |                             if transformer_cls is None: | ||||||
|  |                                 raise Exception("Could not find the transformer layer class to wrap in the model.") | ||||||
|  |                             else: | ||||||
|  |                                 transformer_cls_to_wrap.add(transformer_cls) | ||||||
|  |                         auto_wrap_policy = functools.partial( | ||||||
|  |                             transformer_auto_wrap_policy, | ||||||
|  |                             # Transformer layer class to wrap | ||||||
|  |                             transformer_layer_cls=transformer_cls_to_wrap, | ||||||
|  |                         ) | ||||||
|  |                 mixed_precision_policy = None | ||||||
|  |                 dtype = None | ||||||
|  |                 if self.args.fp16: | ||||||
|  |                     dtype = torch.float16 | ||||||
|  |                 elif self.args.bf16: | ||||||
|  |                     dtype = torch.bfloat16 | ||||||
|  |                 if dtype is not None: | ||||||
|  |                     mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) | ||||||
|  |                 if type(model) != FSDP: | ||||||
|  |                     # XXX: Breaking the self.model convention but I see no way around it for now. | ||||||
|  |                     signature = inspect.signature(FSDP.__init__).parameters.keys() | ||||||
|  |                     kwargs = {} | ||||||
|  |                     for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]: | ||||||
|  |                         if arg in signature: | ||||||
|  |                             kwargs[arg] = getattr(self, arg) | ||||||
|  |                     self.model = model = FSDP( | ||||||
|  |                         model, | ||||||
|  |                         sharding_strategy=self.fsdp, | ||||||
|  |                         cpu_offload=cpu_offload, | ||||||
|  |                         auto_wrap_policy=auto_wrap_policy, | ||||||
|  |                         mixed_precision=mixed_precision_policy, | ||||||
|  |                         device_id=self.args.device, | ||||||
|  |                         **kwargs, | ||||||
|  |                     ) | ||||||
|  |                      | ||||||
|  |                 for submodule in traversal_utils._get_fsdp_states(model): | ||||||
|  |                     print(submodule._state_dict_type, submodule._state_dict_config) | ||||||
|  |                     break | ||||||
|  |             # fix ends here | ||||||
|  |             else: | ||||||
|  |                 try: | ||||||
|  |                     from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP | ||||||
|  |                     from torch_xla.distributed.fsdp import checkpoint_module | ||||||
|  |                     from torch_xla.distributed.fsdp.wrap import ( | ||||||
|  |                         size_based_auto_wrap_policy, | ||||||
|  |                         transformer_auto_wrap_policy, | ||||||
|  |                     ) | ||||||
|  |                 except ImportError: | ||||||
|  |                     raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") | ||||||
|  |                 auto_wrap_policy = None | ||||||
|  |                 auto_wrapper_callable = None | ||||||
|  |                 if self.args.fsdp_config["fsdp_min_num_params"] > 0: | ||||||
|  |                     auto_wrap_policy = functools.partial( | ||||||
|  |                         size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] | ||||||
|  |                     ) | ||||||
|  |                 elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||||
|  |                     transformer_cls_to_wrap = set() | ||||||
|  |                     for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: | ||||||
|  |                         transformer_cls = get_module_class_from_name(model, layer_class) | ||||||
|  |                         if transformer_cls is None: | ||||||
|  |                             raise Exception("Could not find the transformer layer class to wrap in the model.") | ||||||
|  |                         else: | ||||||
|  |                             transformer_cls_to_wrap.add(transformer_cls) | ||||||
|  |                     auto_wrap_policy = functools.partial( | ||||||
|  |                         transformer_auto_wrap_policy, | ||||||
|  |                         # Transformer layer class to wrap | ||||||
|  |                         transformer_layer_cls=transformer_cls_to_wrap, | ||||||
|  |                     ) | ||||||
|  |                 fsdp_kwargs = self.args.xla_fsdp_config | ||||||
|  |                 if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: | ||||||
|  |                     # Apply gradient checkpointing to auto-wrapped sub-modules if specified | ||||||
|  |                     def auto_wrapper_callable(m, *args, **kwargs): | ||||||
|  |                         return FSDP(checkpoint_module(m), *args, **kwargs) | ||||||
|  |  | ||||||
|  |                 # Wrap the base model with an outer FSDP wrapper | ||||||
|  |                 self.model = model = FSDP( | ||||||
|  |                     model, | ||||||
|  |                     auto_wrap_policy=auto_wrap_policy, | ||||||
|  |                     auto_wrapper_callable=auto_wrapper_callable, | ||||||
|  |                     **fsdp_kwargs, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |                 # Patch `xm.optimizer_step` should not reduce gradients in this case, | ||||||
|  |                 # as FSDP does not need gradient reduction over sharded parameters. | ||||||
|  |                 def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): | ||||||
|  |                     loss = optimizer.step(**optimizer_args) | ||||||
|  |                     if barrier: | ||||||
|  |                         xm.mark_step() | ||||||
|  |                     return loss | ||||||
|  |  | ||||||
|  |                 xm.optimizer_step = patched_optimizer_step | ||||||
|  |         elif is_sagemaker_dp_enabled(): | ||||||
|  |             model = nn.parallel.DistributedDataParallel( | ||||||
|  |                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] | ||||||
|  |             ) | ||||||
|  |         elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: | ||||||
|  |             if is_torch_neuroncore_available(): | ||||||
|  |                 return model | ||||||
|  |             kwargs = {} | ||||||
|  |             if self.args.ddp_find_unused_parameters is not None: | ||||||
|  |                 kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters | ||||||
|  |             elif isinstance(model, PreTrainedModel): | ||||||
|  |                 # find_unused_parameters breaks checkpointing as per | ||||||
|  |                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 | ||||||
|  |                 kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing | ||||||
|  |             else: | ||||||
|  |                 kwargs["find_unused_parameters"] = True | ||||||
|  |  | ||||||
|  |             if self.args.ddp_bucket_cap_mb is not None: | ||||||
|  |                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb | ||||||
|  |  | ||||||
|  |             if self.args.ddp_broadcast_buffers is not None: | ||||||
|  |                 kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers | ||||||
|  |  | ||||||
|  |             self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) | ||||||
|  |  | ||||||
|  |         return model | ||||||
|  |      | ||||||
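|  |     # A minimal, illustrative `--fsdp_config` JSON (values are examples, not defaults of this repo) | ||||||
|  |     # that exercises the non-XLA FSDP branch of `_wrap_model` above with transformer auto-wrapping: | ||||||
|  |     # | ||||||
|  |     #     { | ||||||
|  |     #         "fsdp_min_num_params": 0, | ||||||
|  |     #         "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"], | ||||||
|  |     #         "xla": false, | ||||||
|  |     #         "xla_fsdp_grad_ckpt": false | ||||||
|  |     #     } | ||||||
|  |     # | ||||||
|  |     # used together with e.g. `--fsdp "full_shard auto_wrap"` so that the AUTO_WRAP option is set. | ||||||
|  |  | ||||||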
|  |     def get_batch_sampler(self, dataset=None): | ||||||
|  |         if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1: | ||||||
|  |             dataset = dataset if dataset is not None else self.train_dataset | ||||||
|  |             try: | ||||||
|  |                 batch_max_len = self.args.per_device_train_batch_size * unwrap_model(self.model).model_avg_context | ||||||
|  |             except AttributeError: | ||||||
|  |                 # the unwrapped model has no `model_avg_context` attribute; fall back to the value from the training args | ||||||
|  |                 batch_max_len = self.args.per_device_train_batch_size * self.args.model_avg_context | ||||||
|  |             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None | ||||||
|  |             lengths = LengthGroupedSampler( | ||||||
|  |                     batch_size=-1, # we just want to know about the lengths of the dataset so no need to pass `batch_size` | ||||||
|  |                     dataset=dataset, | ||||||
|  |                     model_input_name=model_input_name | ||||||
|  |             ).lengths | ||||||
|  |             seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed | ||||||
|  |             batch_sampler = FFDDistributedBatchSampler( | ||||||
|  |                 batch_max_length=batch_max_len, | ||||||
|  |                 lengths=np.array(lengths), | ||||||
|  |                 seed=seed | ||||||
|  |             ) | ||||||
|  |              | ||||||
|  |             return batch_sampler | ||||||
|  |      | ||||||
|  |         return None | ||||||
|  |      | ||||||
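|  |     # Rough intuition for the packing above (made-up numbers): with per_device_train_batch_size=4 | ||||||
|  |     # and model_avg_context=2048, batch_max_len is 8192 tokens, and the FFD (first-fit-decreasing) | ||||||
|  |     # sampler presumably packs variable-length samples into batches whose summed lengths stay under | ||||||
|  |     # that budget, rather than emitting a fixed number of samples per batch. | ||||||
|  |  | ||||||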
|  |     def get_train_dataloader(self) -> DataLoader: | ||||||
|  |         if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1: | ||||||
|  |             if self.train_dataset is None: | ||||||
|  |                 raise ValueError("Trainer: training requires a train_dataset.") | ||||||
|  |  | ||||||
|  |             train_dataset = self.train_dataset | ||||||
|  |             data_collator = self.data_collator | ||||||
|  |             if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): | ||||||
|  |                 train_dataset = self._remove_unused_columns(train_dataset, description="training") | ||||||
|  |             else: | ||||||
|  |                 data_collator = self._get_collator_with_removed_columns(data_collator, description="training") | ||||||
|  |                  | ||||||
|  |             batch_sampler = self.get_batch_sampler(train_dataset) | ||||||
|  |              | ||||||
|  |             dataloader = DataLoader( | ||||||
|  |                 train_dataset, | ||||||
|  |                 batch_sampler=batch_sampler, | ||||||
|  |                 drop_last=self.args.dataloader_drop_last, | ||||||
|  |                 collate_fn=data_collator | ||||||
|  |             ) | ||||||
|  |             # return self.accelerator.prepare(dataloader) | ||||||
|  |             return dataloader | ||||||
|  |              | ||||||
|  |         return super().get_train_dataloader() | ||||||
|  |      | ||||||
|  |     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): | ||||||
|  |         """ | ||||||
|  |         Will save the model, so you can reload it using `from_pretrained()`. | ||||||
|  |  | ||||||
|  |         Will only save from the main process. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if output_dir is None: | ||||||
|  |             output_dir = self.args.output_dir | ||||||
|  |  | ||||||
|  |         if is_torch_tpu_available(): | ||||||
|  |             self._save_tpu(output_dir) | ||||||
|  |         elif is_sagemaker_mp_enabled(): | ||||||
|  |             # Calling the state_dict needs to be done on the wrapped model and on all processes. | ||||||
|  |             os.makedirs(output_dir, exist_ok=True) | ||||||
|  |             state_dict = self.model_wrapped.state_dict() | ||||||
|  |             if self.args.should_save: | ||||||
|  |                 self._save(output_dir, state_dict=state_dict) | ||||||
|  |             if IS_SAGEMAKER_MP_POST_1_10: | ||||||
|  |                 # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 | ||||||
|  |                 Path(os.path.join(output_dir, "user_content.pt")).touch() | ||||||
|  |         elif ( | ||||||
|  |             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp | ||||||
|  |             or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp | ||||||
|  |             or self.fsdp is not None | ||||||
|  |             or self.is_fsdp_enabled | ||||||
|  |         ): | ||||||
|  |             state_dict = self.model.state_dict() | ||||||
|  |             if self.args.should_save: | ||||||
|  |                 self._save(output_dir, state_dict=state_dict) | ||||||
|  |             # modification starts here | ||||||
|  |             if self.is_fsdp_enabled and self.args.save_with_fsdp: | ||||||
|  |                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) | ||||||
|  |             # modification ends here | ||||||
|  |  | ||||||
|  |         elif self.is_deepspeed_enabled: | ||||||
|  |             # this takes care of everything as long as we aren't under zero3 | ||||||
|  |             if version.parse(accelerate_version) <= version.parse("0.20.3"): | ||||||
|  |                 raise ValueError("Install Accelerate from main branch") | ||||||
|  |             try: | ||||||
|  |                 state_dict = self.accelerator.get_state_dict(self.deepspeed) | ||||||
|  |                 if self.args.should_save: | ||||||
|  |                     self._save(output_dir, state_dict=state_dict) | ||||||
|  |             except ValueError: | ||||||
|  |                 logger.warning( | ||||||
|  |                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" | ||||||
|  |                     " zero_to_fp32.py to recover weights" | ||||||
|  |                 ) | ||||||
|  |                 self.model_wrapped.save_checkpoint(output_dir) | ||||||
|  |  | ||||||
|  |         elif self.args.should_save: | ||||||
|  |             self._save(output_dir) | ||||||
|  |  | ||||||
|  |         # Push to the Hub when `save_model` is called by the user. | ||||||
|  |         if self.args.push_to_hub and not _internal_call: | ||||||
|  |             self.push_to_hub(commit_message="Model save") | ||||||
|  |      | ||||||
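|  |     # Note on the modification above: with `save_with_fsdp=False` only the gathered `state_dict` | ||||||
|  |     # written by `self._save()` is produced, so such a checkpoint has to be resumed without the | ||||||
|  |     # fsdp-sharded loading path (see the matching check in `_load_optimizer_and_scheduler` below). | ||||||
|  |  | ||||||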
|  |     def _save_checkpoint(self, model, trial, metrics=None): | ||||||
|  |         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we | ||||||
|  |         # want to save except FullyShardedDDP. | ||||||
|  |         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" | ||||||
|  |  | ||||||
|  |         # Save model checkpoint | ||||||
|  |         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" | ||||||
|  |  | ||||||
|  |         if self.hp_search_backend is None and trial is None: | ||||||
|  |             self.store_flos() | ||||||
|  |  | ||||||
|  |         run_dir = self._get_output_dir(trial=trial) | ||||||
|  |         output_dir = os.path.join(run_dir, checkpoint_folder) | ||||||
|  |         self.save_model(output_dir, _internal_call=True) | ||||||
|  |         if self.is_deepspeed_enabled: | ||||||
|  |             # under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the deepspeed | ||||||
|  |             # config `stage3_gather_16bit_weights_on_model_save` is True | ||||||
|  |             self.model_wrapped.save_checkpoint(output_dir) | ||||||
|  |  | ||||||
|  |         # Save optimizer and scheduler | ||||||
|  |         if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||||||
|  |             self.optimizer.consolidate_state_dict() | ||||||
|  |  | ||||||
|  |         if self.fsdp or self.is_fsdp_enabled: | ||||||
|  |             if self.is_fsdp_enabled: | ||||||
|  |                 # modification starts here | ||||||
|  |                 if self.args.save_with_fsdp: | ||||||
|  |                     save_fsdp_optimizer( | ||||||
|  |                         self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir | ||||||
|  |                     ) | ||||||
|  |                 # modification ends here | ||||||
|  |             else: | ||||||
|  |                 # FSDP has a different interface for saving optimizer states. | ||||||
|  |                 # Needs to be called on all ranks to gather all states. | ||||||
|  |                 # full_optim_state_dict will be deprecated after Pytorch 2.2! | ||||||
|  |                 full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) | ||||||
|  |  | ||||||
|  |         if is_torch_tpu_available(): | ||||||
|  |             xm.rendezvous("saving_optimizer_states") | ||||||
|  |             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | ||||||
|  |             with warnings.catch_warnings(record=True) as caught_warnings: | ||||||
|  |                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | ||||||
|  |                 reissue_pt_warnings(caught_warnings) | ||||||
|  |         elif is_sagemaker_mp_enabled(): | ||||||
|  |             opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) | ||||||
|  |             smp.barrier() | ||||||
|  |             if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: | ||||||
|  |                 smp.save( | ||||||
|  |                     opt_state_dict, | ||||||
|  |                     os.path.join(output_dir, OPTIMIZER_NAME), | ||||||
|  |                     partial=True, | ||||||
|  |                     v3=smp.state.cfg.shard_optimizer_state, | ||||||
|  |                 ) | ||||||
|  |             if self.args.should_save: | ||||||
|  |                 with warnings.catch_warnings(record=True) as caught_warnings: | ||||||
|  |                     torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | ||||||
|  |                 reissue_pt_warnings(caught_warnings) | ||||||
|  |                 if self.do_grad_scaling: | ||||||
|  |                     torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) | ||||||
|  |         elif self.args.should_save and not self.is_deepspeed_enabled: | ||||||
|  |             # deepspeed.save_checkpoint above saves model/optim/sched | ||||||
|  |             if self.fsdp and not self.is_fsdp_enabled: | ||||||
|  |                 torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) | ||||||
|  |             else: | ||||||
|  |                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | ||||||
|  |  | ||||||
|  |             with warnings.catch_warnings(record=True) as caught_warnings: | ||||||
|  |                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | ||||||
|  |             reissue_pt_warnings(caught_warnings) | ||||||
|  |             if self.do_grad_scaling: | ||||||
|  |                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) | ||||||
|  |  | ||||||
|  |         # Determine the new best metric / best model checkpoint | ||||||
|  |         if metrics is not None and self.args.metric_for_best_model is not None: | ||||||
|  |             metric_to_check = self.args.metric_for_best_model | ||||||
|  |             if not metric_to_check.startswith("eval_"): | ||||||
|  |                 metric_to_check = f"eval_{metric_to_check}" | ||||||
|  |             metric_value = metrics[metric_to_check] | ||||||
|  |  | ||||||
|  |             operator = np.greater if self.args.greater_is_better else np.less | ||||||
|  |             if ( | ||||||
|  |                 self.state.best_metric is None | ||||||
|  |                 or self.state.best_model_checkpoint is None | ||||||
|  |                 or operator(metric_value, self.state.best_metric) | ||||||
|  |             ): | ||||||
|  |                 self.state.best_metric = metric_value | ||||||
|  |                 self.state.best_model_checkpoint = output_dir | ||||||
|  |  | ||||||
|  |         # Save the Trainer state | ||||||
|  |         if self.args.should_save: | ||||||
|  |             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) | ||||||
|  |  | ||||||
|  |         # Save RNG states (per process in distributed training) | ||||||
|  |         rng_states = { | ||||||
|  |             "python": random.getstate(), | ||||||
|  |             "numpy": np.random.get_state(), | ||||||
|  |             "cpu": torch.random.get_rng_state(), | ||||||
|  |         } | ||||||
|  |         if torch.cuda.is_available(): | ||||||
|  |             if self.args.parallel_mode == ParallelMode.DISTRIBUTED: | ||||||
|  |                 # in distributed training, save the CUDA RNG states of all visible devices | ||||||
|  |                 rng_states["cuda"] = torch.cuda.random.get_rng_state_all() | ||||||
|  |             else: | ||||||
|  |                 rng_states["cuda"] = torch.cuda.random.get_rng_state() | ||||||
|  |  | ||||||
|  |         if is_torch_tpu_available(): | ||||||
|  |             rng_states["xla"] = xm.get_rng_state() | ||||||
|  |  | ||||||
|  |         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may | ||||||
|  |         # not yet exist. | ||||||
|  |         os.makedirs(output_dir, exist_ok=True) | ||||||
|  |  | ||||||
|  |         if self.args.world_size <= 1: | ||||||
|  |             torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) | ||||||
|  |         else: | ||||||
|  |             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) | ||||||
|  |  | ||||||
|  |         if self.args.push_to_hub: | ||||||
|  |             self._push_from_checkpoint(output_dir) | ||||||
|  |  | ||||||
|  |         # Maybe delete some older checkpoints. | ||||||
|  |         if self.args.should_save: | ||||||
|  |             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) | ||||||
|  |              | ||||||
|  |     def _load_optimizer_and_scheduler(self, checkpoint): | ||||||
|  |         """If optimizer and scheduler states exist, load them.""" | ||||||
|  |         if checkpoint is None: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if self.is_deepspeed_enabled: | ||||||
|  |             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         checkpoint_file_exists = ( | ||||||
|  |             glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") | ||||||
|  |             if is_sagemaker_mp_enabled() | ||||||
|  |             else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) | ||||||
|  |         ) | ||||||
|  |         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): | ||||||
|  |             # Load in optimizer and scheduler states | ||||||
|  |             if is_torch_tpu_available(): | ||||||
|  |                 # On TPU we have to take some extra precautions to properly load the states on the right device. | ||||||
|  |                 optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") | ||||||
|  |                 with warnings.catch_warnings(record=True) as caught_warnings: | ||||||
|  |                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") | ||||||
|  |                 reissue_pt_warnings(caught_warnings) | ||||||
|  |  | ||||||
|  |                 xm.send_cpu_data_to_device(optimizer_state, self.args.device) | ||||||
|  |                 xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) | ||||||
|  |  | ||||||
|  |                 self.optimizer.load_state_dict(optimizer_state) | ||||||
|  |                 self.lr_scheduler.load_state_dict(lr_scheduler_state) | ||||||
|  |             else: | ||||||
|  |                 if is_sagemaker_mp_enabled(): | ||||||
|  |                     if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): | ||||||
|  |                         # Optimizer checkpoint was saved with smp >= 1.10 | ||||||
|  |                         def opt_load_hook(mod, opt): | ||||||
|  |                             opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) | ||||||
|  |  | ||||||
|  |                     else: | ||||||
|  |                         # Optimizer checkpoint was saved with smp < 1.10 | ||||||
|  |                         def opt_load_hook(mod, opt): | ||||||
|  |                             if IS_SAGEMAKER_MP_POST_1_10: | ||||||
|  |                                 opt.load_state_dict( | ||||||
|  |                                     smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) | ||||||
|  |                                 ) | ||||||
|  |                             else: | ||||||
|  |                                 opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) | ||||||
|  |  | ||||||
|  |                     self.model_wrapped.register_post_step_hook(opt_load_hook) | ||||||
|  |                 else: | ||||||
|  |                     # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. | ||||||
|  |                     # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more | ||||||
|  |                     # likely to get OOM on CPU (since we load num_gpu times the optimizer state). | ||||||
|  |                     map_location = self.args.device if self.args.world_size > 1 else "cpu" | ||||||
|  |                     if self.fsdp or self.is_fsdp_enabled: | ||||||
|  |                         # modification starts here | ||||||
|  |                         if self.is_fsdp_enabled and self.args.save_with_fsdp: | ||||||
|  |                             load_fsdp_optimizer( | ||||||
|  |                                 self.accelerator.state.fsdp_plugin, | ||||||
|  |                                 self.accelerator, | ||||||
|  |                                 self.optimizer, | ||||||
|  |                                 self.model, | ||||||
|  |                                 checkpoint, | ||||||
|  |                             ) | ||||||
|  |                         elif not self.is_fsdp_enabled: | ||||||
|  |                             full_osd = None | ||||||
|  |                             # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it | ||||||
|  |                             if self.args.process_index == 0: | ||||||
|  |                                 full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) | ||||||
|  |                             # call scatter_full_optim_state_dict on all ranks | ||||||
|  |                             sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) | ||||||
|  |                             self.optimizer.load_state_dict(sharded_osd) | ||||||
|  |                         else: | ||||||
|  |                             self.optimizer.load_state_dict( | ||||||
|  |                                 torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) | ||||||
|  |                             ) | ||||||
|  |                         # modification ends here | ||||||
|  |                     else: | ||||||
|  |                         self.optimizer.load_state_dict( | ||||||
|  |                             torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) | ||||||
|  |                         ) | ||||||
|  |                 with warnings.catch_warnings(record=True) as caught_warnings: | ||||||
|  |                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) | ||||||
|  |                 reissue_pt_warnings(caught_warnings) | ||||||
|  |                 if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): | ||||||
|  |                     self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) | ||||||
|  |                      | ||||||
							
								
								
									
478   train/trainers/fsdp_training_args.py   Normal file
							| @ -0,0 +1,478 @@ | |||||||
|  | import sys | ||||||
|  | import os | ||||||
|  |  | ||||||
|  | import transformers | ||||||
|  | from transformers.training_args import * | ||||||
|  |  | ||||||
|  | from .utils import ExtendedFSDPOption | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class FSDPTrainingArguments(transformers.TrainingArguments): | ||||||
|  |     # about data-efficient sampler | ||||||
|  |     use_ffd_sampler: bool = False | ||||||
|  |     model_avg_context: int = 2048 | ||||||
|  |      | ||||||
|  |     # about saving | ||||||
|  |     # if the checkpoint is not saved with fsdp, it must not be loaded with fsdp | ||||||
|  |     save_with_fsdp: bool = False | ||||||
|  |      | ||||||
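|  |     # Illustrative usage (placeholder values, not defaults from this repo): | ||||||
|  |     # | ||||||
|  |     #     args = FSDPTrainingArguments( | ||||||
|  |     #         output_dir="outputs/run", | ||||||
|  |     #         per_device_train_batch_size=4, | ||||||
|  |     #         group_by_length=True, | ||||||
|  |     #         use_ffd_sampler=True,    # enable the FFD length-packing batch sampler | ||||||
|  |     #         model_avg_context=2048,  # average context length used to size packed batches | ||||||
|  |     #         save_with_fsdp=False,    # plain state_dict saving; resuming must then also skip fsdp loading | ||||||
|  |     #     ) | ||||||
|  |  | ||||||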
|  |     def __post_init__(self): | ||||||
|  |         # expand paths; otherwise os.makedirs("~/bar") would create the directory | ||||||
|  |         # inside the current working directory instead of the actual home | ||||||
|  |         # see https://github.com/huggingface/transformers/issues/10628 | ||||||
|  |         if self.output_dir is not None: | ||||||
|  |             self.output_dir = os.path.expanduser(self.output_dir) | ||||||
|  |         if self.logging_dir is None and self.output_dir is not None: | ||||||
|  |             self.logging_dir = os.path.join(self.output_dir, default_logdir()) | ||||||
|  |         if self.logging_dir is not None: | ||||||
|  |             self.logging_dir = os.path.expanduser(self.logging_dir) | ||||||
|  |  | ||||||
|  |         if self.disable_tqdm is None: | ||||||
|  |             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN | ||||||
|  |  | ||||||
|  |         if isinstance(self.evaluation_strategy, EvaluationStrategy): | ||||||
|  |             warnings.warn( | ||||||
|  |                 "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5" | ||||||
|  |                 " of 🤗 Transformers. Use `IntervalStrategy` instead", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |             # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. | ||||||
|  |             self.evaluation_strategy = self.evaluation_strategy.value | ||||||
|  |  | ||||||
|  |         # if self.xpu_backend is not None: | ||||||
|  |         #     warnings.warn( | ||||||
|  |         #         "using `xpu_backend` is deprecated and will be removed in version 4.31" | ||||||
|  |         #         " of 🤗 Transformers. Use `ddp_backend` instead", | ||||||
|  |         #         FutureWarning, | ||||||
|  |         #     ) | ||||||
|  |         #     self.ddp_backend = self.xpu_backend | ||||||
|  |  | ||||||
|  |         self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) | ||||||
|  |         self.logging_strategy = IntervalStrategy(self.logging_strategy) | ||||||
|  |         self.save_strategy = IntervalStrategy(self.save_strategy) | ||||||
|  |         self.hub_strategy = HubStrategy(self.hub_strategy) | ||||||
|  |  | ||||||
|  |         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) | ||||||
|  |         if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO: | ||||||
|  |             self.do_eval = True | ||||||
|  |  | ||||||
|  |         # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero | ||||||
|  |         if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): | ||||||
|  |             if self.logging_steps > 0: | ||||||
|  |                 logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") | ||||||
|  |                 self.eval_steps = self.logging_steps | ||||||
|  |             else: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or" | ||||||
|  |                     " --logging_steps" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         # logging_steps must be non-zero for logging_strategy that is other than 'no' | ||||||
|  |         if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: | ||||||
|  |             raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") | ||||||
|  |  | ||||||
|  |         if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1: | ||||||
|  |             if self.logging_steps != int(self.logging_steps): | ||||||
|  |                 raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") | ||||||
|  |             self.logging_steps = int(self.logging_steps) | ||||||
|  |         if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: | ||||||
|  |             if self.eval_steps != int(self.eval_steps): | ||||||
|  |                 raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") | ||||||
|  |             self.eval_steps = int(self.eval_steps) | ||||||
|  |         if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1: | ||||||
|  |             if self.save_steps != int(self.save_steps): | ||||||
|  |                 raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") | ||||||
|  |             self.save_steps = int(self.save_steps) | ||||||
|  |  | ||||||
|  |         # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. | ||||||
|  |         if self.load_best_model_at_end: | ||||||
|  |             if self.evaluation_strategy != self.save_strategy: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation " | ||||||
|  |                     f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}" | ||||||
|  |                 ) | ||||||
|  |             if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: | ||||||
|  |                 if self.eval_steps < 1 or self.save_steps < 1: | ||||||
|  |                     if not (self.eval_steps < 1 and self.save_steps < 1): | ||||||
|  |                         raise ValueError( | ||||||
|  |                             "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " | ||||||
|  |                             "steps, which cannot be guaranteed when mixing ratio and absolute steps for save_steps " | ||||||
|  |                             f"{self.save_steps} and eval_steps {self.eval_steps}." | ||||||
|  |                         ) | ||||||
|  |                     # Work around floating point precision issues | ||||||
|  |                     LARGE_MULTIPLIER = 1_000_000 | ||||||
|  |                     if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: | ||||||
|  |                         raise ValueError( | ||||||
|  |                             "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " | ||||||
|  |                             f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." | ||||||
|  |                         ) | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation " | ||||||
|  |                     f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         safetensors_available = is_safetensors_available() | ||||||
|  |         if self.save_safetensors and not safetensors_available: | ||||||
|  |             raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!") | ||||||
|  |         if not self.save_safetensors and safetensors_available: | ||||||
|  |             logger.info( | ||||||
|  |                 f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. " | ||||||
|  |                 f"Safetensors should be a preferred weights saving format due to security and performance reasons. " | ||||||
|  |                 f"If your model cannot be saved by safetensors please feel free to open an issue at " | ||||||
|  |                 f"https://github.com/huggingface/safetensors!" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU | ||||||
|  |         ) and self.metric_for_best_model is None: | ||||||
|  |             self.metric_for_best_model = "loss" | ||||||
|  |         if self.greater_is_better is None and self.metric_for_best_model is not None: | ||||||
|  |             self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] | ||||||
|  |         if self.run_name is None: | ||||||
|  |             self.run_name = self.output_dir | ||||||
|  |         if self.framework == "pt" and is_torch_available(): | ||||||
|  |             if self.fp16_backend and self.fp16_backend != "auto": | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" | ||||||
|  |                     " `half_precision_backend` instead", | ||||||
|  |                     FutureWarning, | ||||||
|  |                 ) | ||||||
|  |                 self.half_precision_backend = self.fp16_backend | ||||||
|  |  | ||||||
|  |             if self.bf16 or self.bf16_full_eval: | ||||||
|  |                 if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available(): | ||||||
|  |                     # cpu | ||||||
|  |                     raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") | ||||||
|  |                 elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available(): | ||||||
|  |                     # gpu | ||||||
|  |                     raise ValueError( | ||||||
|  |                         "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0" | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |         if self.fp16 and self.bf16: | ||||||
|  |             raise ValueError("At most one of fp16 and bf16 can be True, but not both") | ||||||
|  |  | ||||||
|  |         if self.fp16_full_eval and self.bf16_full_eval: | ||||||
|  |             raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") | ||||||
|  |  | ||||||
|  |         if self.bf16: | ||||||
|  |             if self.half_precision_backend == "apex": | ||||||
|  |                 raise ValueError( | ||||||
|  |                     " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use" | ||||||
|  |                     " `--half_precision_backend cuda_amp` instead" | ||||||
|  |                 ) | ||||||
|  |             if not (self.sharded_ddp == "" or not self.sharded_ddp): | ||||||
|  |                 raise ValueError("sharded_ddp is not supported with bf16") | ||||||
|  |  | ||||||
|  |         if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: | ||||||
|  |             if self.evaluation_strategy == IntervalStrategy.NO: | ||||||
|  |                 raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") | ||||||
|  |             if not is_torch_available(): | ||||||
|  |                 raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") | ||||||
|  |  | ||||||
|  |         self.optim = OptimizerNames(self.optim) | ||||||
|  |         if self.adafactor: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim" | ||||||
|  |                 " adafactor` instead", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |             self.optim = OptimizerNames.ADAFACTOR | ||||||
|  |         if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available(): | ||||||
|  |             if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"): | ||||||
|  |                 raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher") | ||||||
|  |             # there is a bug in fp16/AMP in pt-2.0.0 | ||||||
|  |             if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16: | ||||||
|  |                 raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0") | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             self.framework == "pt" | ||||||
|  |             and is_torch_available() | ||||||
|  |             and (self.device.type != "cuda") | ||||||
|  |             and (get_xla_device_type(self.device) != "GPU") | ||||||
|  |             and (self.fp16 or self.fp16_full_eval) | ||||||
|  |         ): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation" | ||||||
|  |                 " (`--fp16_full_eval`) can only be used on CUDA devices." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             self.framework == "pt" | ||||||
|  |             and is_torch_available() | ||||||
|  |             and (self.device.type != "cuda") | ||||||
|  |             and (get_xla_device_type(self.device) != "GPU") | ||||||
|  |             and (get_xla_device_type(self.device) != "TPU") | ||||||
|  |             and (self.device.type != "cpu") | ||||||
|  |             and (self.bf16 or self.bf16_full_eval) | ||||||
|  |         ): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation" | ||||||
|  |                 " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if self.torchdynamo is not None: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" | ||||||
|  |                 " `torch_compile_backend` instead", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |             self.torch_compile_backend = self.torchdynamo | ||||||
|  |         if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: | ||||||
|  |             self.torch_compile = True | ||||||
|  |         if self.torch_compile and self.torch_compile_backend is None: | ||||||
|  |             self.torch_compile_backend = "inductor" | ||||||
|  |  | ||||||
|  |         # accelerate integration for torch compile | ||||||
|  |         if self.torch_compile: | ||||||
|  |             # set env vars for accelerate | ||||||
|  |             prefix = "ACCELERATE_DYNAMO_" | ||||||
|  |             os.environ[prefix + "BACKEND"] = self.torch_compile_backend | ||||||
|  |             if self.torch_compile_mode is not None: | ||||||
|  |                 os.environ[prefix + "MODE"] = self.torch_compile_mode | ||||||
|  |  | ||||||
|  |         if self.framework == "pt" and is_torch_available() and self.torch_compile: | ||||||
|  |             if is_torch_tf32_available(): | ||||||
|  |                 if self.tf32 is None and not self.fp16 or self.bf16: | ||||||
|  |                     logger.info( | ||||||
|  |                         "Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement" | ||||||
|  |                         " otherwise." | ||||||
|  |                     ) | ||||||
|  |                     torch.backends.cuda.matmul.allow_tf32 = True | ||||||
|  |             else: | ||||||
|  |                 logger.warning( | ||||||
|  |                     "The speedups for torchdynamo mostly come with Ampere or newer GPUs, which is not detected here." | ||||||
|  |                 ) | ||||||
|  |         if self.framework == "pt" and is_torch_available() and self.tf32 is not None: | ||||||
|  |             if self.tf32: | ||||||
|  |                 if is_torch_tf32_available(): | ||||||
|  |                     torch.backends.cuda.matmul.allow_tf32 = True | ||||||
|  |                 else: | ||||||
|  |                     raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") | ||||||
|  |             else: | ||||||
|  |                 if is_torch_tf32_available(): | ||||||
|  |                     torch.backends.cuda.matmul.allow_tf32 = False | ||||||
|  |                 # no need to assert on else | ||||||
|  |  | ||||||
|  |         if self.report_to is None: | ||||||
|  |             logger.info( | ||||||
|  |                 "The default value for the training argument `--report_to` will change in v5 (from all installed " | ||||||
|  |                 "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as " | ||||||
|  |                 "now. You should start updating your code and make this info disappear :-)." | ||||||
|  |             ) | ||||||
|  |             self.report_to = "all" | ||||||
|  |         if self.report_to == "all" or self.report_to == ["all"]: | ||||||
|  |             # Import at runtime to avoid a circular import. | ||||||
|  |             from transformers.integrations import get_available_reporting_integrations | ||||||
|  |  | ||||||
|  |             self.report_to = get_available_reporting_integrations() | ||||||
|  |         elif self.report_to == "none" or self.report_to == ["none"]: | ||||||
|  |             self.report_to = [] | ||||||
|  |         elif not isinstance(self.report_to, list): | ||||||
|  |             self.report_to = [self.report_to] | ||||||
|  |  | ||||||
|  |         if self.warmup_ratio < 0 or self.warmup_ratio > 1: | ||||||
|  |             raise ValueError("warmup_ratio must lie in range [0,1]") | ||||||
|  |         elif self.warmup_ratio > 0 and self.warmup_steps > 0: | ||||||
|  |             logger.info( | ||||||
|  |                 "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio" | ||||||
|  |                 " during training" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if not (self.sharded_ddp == "" or not self.sharded_ddp): | ||||||
|  |             warnings.warn( | ||||||
|  |                 "using `sharded_ddp` is deprecated and will be removed in version 4.33" | ||||||
|  |                 " of 🤗 Transformers. Use `fsdp` instead", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |         if isinstance(self.sharded_ddp, bool): | ||||||
|  |             self.sharded_ddp = "simple" if self.sharded_ddp else "" | ||||||
|  |         if isinstance(self.sharded_ddp, str): | ||||||
|  |             self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()] | ||||||
|  |         if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or " | ||||||
|  |                 '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.' | ||||||
|  |             ) | ||||||
|  |         elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp: | ||||||
|  |             raise ValueError("`--sharded_ddp simple` is not compatible with any other option.") | ||||||
|  |         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp: | ||||||
|  |             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.") | ||||||
|  |  | ||||||
|  |         if isinstance(self.fsdp, bool): | ||||||
|  |             self.fsdp = "full_shard" if self.fsdp else "" | ||||||
|  |         if isinstance(self.fsdp, str): | ||||||
|  |             self.fsdp = [ExtendedFSDPOption(s) for s in self.fsdp.split()] | ||||||
|  |         if self.fsdp == [ExtendedFSDPOption.OFFLOAD]: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or " | ||||||
|  |                 '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.' | ||||||
|  |             ) | ||||||
|  |         elif ExtendedFSDPOption.FULL_SHARD in self.fsdp and ExtendedFSDPOption.SHARD_GRAD_OP in self.fsdp: | ||||||
|  |             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") | ||||||
|  |  | ||||||
|  |         if self.fsdp_config is None: | ||||||
|  |             self.fsdp_config = {} | ||||||
|  |  | ||||||
|  |         if isinstance(self.fsdp_config, str): | ||||||
|  |             with io.open(self.fsdp_config, "r", encoding="utf-8") as f: | ||||||
|  |                 self.fsdp_config = json.load(f) | ||||||
|  |  | ||||||
|  |         if self.fsdp_min_num_params > 0: | ||||||
|  |             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) | ||||||
|  |  | ||||||
|  |         self.fsdp_config["fsdp_min_num_params"] = max( | ||||||
|  |             self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object | ||||||
|  |         if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str): | ||||||
|  |             self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [ | ||||||
|  |                 self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |         if self.fsdp_transformer_layer_cls_to_wrap is not None: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning | ||||||
|  |             ) | ||||||
|  |             self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get( | ||||||
|  |                 "fsdp_transformer_layer_cls_to_wrap", [] | ||||||
|  |             ) + [self.fsdp_transformer_layer_cls_to_wrap] | ||||||
|  |  | ||||||
|  |         if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0: | ||||||
|  |             warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.") | ||||||
|  |  | ||||||
|  |         if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||||
|  |             warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             len(self.fsdp) > 0 | ||||||
|  |             and self.fsdp_config["fsdp_min_num_params"] > 0 | ||||||
|  |             and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None | ||||||
|  |         ): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive." | ||||||
|  |             ) | ||||||
|  |         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) | ||||||
|  |         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) | ||||||
|  |         if self.fsdp_config["xla"]: | ||||||
|  |             if len(self.fsdp) > 0: | ||||||
|  |                 # store XLA fsdp configuration parameters into a dictionary | ||||||
|  |                 self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}) | ||||||
|  |                 # apply appropriate string to torch.dtype conversions for parameters | ||||||
|  |                 if "compute_dtype" in self.xla_fsdp_config: | ||||||
|  |                     self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) | ||||||
|  |                 if "buffer_dtype" in self.xla_fsdp_config: | ||||||
|  |                     self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"]) | ||||||
|  |             else: | ||||||
|  |                 warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.") | ||||||
|  |         else: | ||||||
|  |             if self.fsdp_config["xla_fsdp_grad_ckpt"]: | ||||||
|  |                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.") | ||||||
|  |  | ||||||
|  |         # accelerate integration for FSDP | ||||||
|  |         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: | ||||||
|  |             os.environ["ACCELERATE_USE_FSDP"] = "true" | ||||||
|  |             from accelerate.utils.constants import ( | ||||||
|  |                 FSDP_AUTO_WRAP_POLICY, | ||||||
|  |                 FSDP_SHARDING_STRATEGY, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             for fsdp_option in self.fsdp: | ||||||
|  |                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: | ||||||
|  |                     # set environment variable for FSDP sharding strategy | ||||||
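|  |                     # The +1 converts the list index into the 1-based value of torch's | ||||||
|  |                     # ShardingStrategy enum, which accelerate appears to rebuild from this variable. | ||||||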
|  |                     os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1) | ||||||
|  |                 elif fsdp_option == FSDPOption.OFFLOAD: | ||||||
|  |                     os.environ["FSDP_OFFLOAD_PARAMS"] = "true" | ||||||
|  |                 elif fsdp_option == FSDPOption.AUTO_WRAP: | ||||||
|  |                     if self.fsdp_config["fsdp_min_num_params"] > 0: | ||||||
|  |                         os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"]) | ||||||
|  |                         os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] | ||||||
|  |                     elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||||
|  |                         os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join( | ||||||
|  |                             self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] | ||||||
|  |                         ) | ||||||
|  |                         os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] | ||||||
|  |             prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH") | ||||||
|  |             os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper() | ||||||
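|  |             # Presumably consumed by accelerate's FSDP plugin along with the other FSDP_* variables; | ||||||
|  |             # the default NO_PREFETCH leaves backward prefetching disabled. | ||||||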
|  |  | ||||||
|  |         if self.tpu_metrics_debug: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" | ||||||
|  |                 " `--debug tpu_metrics_debug` instead", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |             if self.debug is None: | ||||||
|  |                 self.debug = " tpu_metrics_debug" | ||||||
|  |             else: | ||||||
|  |                 self.debug += " tpu_metrics_debug" | ||||||
|  |             self.tpu_metrics_debug = False | ||||||
|  |  | ||||||
|  |         if isinstance(self.debug, str): | ||||||
|  |             self.debug = [DebugOption(s) for s in self.debug.split()] | ||||||
|  |         elif self.debug is None: | ||||||
|  |             self.debug = [] | ||||||
|  |  | ||||||
|  |         self.deepspeed_plugin = None | ||||||
|  |         if self.deepspeed: | ||||||
|  |             # - must be run very last in arg parsing, since it will use a lot of these settings. | ||||||
|  |             # - must be run before the model is created. | ||||||
|  |             if not is_accelerate_available(): | ||||||
|  |                 raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.") | ||||||
|  |             from transformers.deepspeed import HfTrainerDeepSpeedConfig | ||||||
|  |  | ||||||
|  |             # will be used later by the Trainer | ||||||
|  |             # note: leave self.deepspeed unmodified in case a user relies on it not to be modified | ||||||
|  |             self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) | ||||||
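|  |             # trainer_config_process fills the "auto" entries of the DeepSpeed config | ||||||
|  |             # (batch size, gradient accumulation, lr, precision) from these arguments. | ||||||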
|  |             self.hf_deepspeed_config.trainer_config_process(self) | ||||||
|  |  | ||||||
|  |             # Accelerate DeepSpeed Plugin | ||||||
|  |             from accelerate.utils import DeepSpeedPlugin | ||||||
|  |  | ||||||
|  |             os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" | ||||||
|  |             self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) | ||||||
|  |  | ||||||
|  |         if self.push_to_hub_token is not None: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " | ||||||
|  |                 "`--hub_token` instead.", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |             self.hub_token = self.push_to_hub_token | ||||||
|  |  | ||||||
|  |         if self.push_to_hub_model_id is not None: | ||||||
|  |             self.hub_model_id = get_full_repo_name( | ||||||
|  |                 self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token | ||||||
|  |             ) | ||||||
|  |             if self.push_to_hub_organization is not None: | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in " | ||||||
|  |                     "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this " | ||||||
|  |                     f"argument (in this case {self.hub_model_id}).", | ||||||
|  |                     FutureWarning, | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " | ||||||
|  |                     "`--hub_model_id` instead and pass the full repo name to this argument (in this case " | ||||||
|  |                     f"{self.hub_model_id}).", | ||||||
|  |                     FutureWarning, | ||||||
|  |                 ) | ||||||
|  |         elif self.push_to_hub_organization is not None: | ||||||
|  |             self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}" | ||||||
|  |             warnings.warn( | ||||||
|  |                 "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " | ||||||
|  |                 "`--hub_model_id` instead and pass the full repo name to this argument (in this case " | ||||||
|  |                 f"{self.hub_model_id}).", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # if training args is specified, it will override the one specified in the accelerate config | ||||||
|  |         if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0: | ||||||
|  |             mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") | ||||||
|  |             if self.fp16: | ||||||
|  |                 mixed_precision_dtype = "fp16" | ||||||
|  |             elif self.bf16: | ||||||
|  |                 mixed_precision_dtype = "bf16" | ||||||
|  |             os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype | ||||||
							
								
								
									
26  train/trainers/stylistic_trainer.py  Normal file
							| @ -0,0 +1,26 @@ | |||||||
|  | from .fsdp_trainer import FSDPTrainer | ||||||
|  | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES | ||||||
|  | from transformers.modeling_utils import unwrap_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class StylisticTrainer(FSDPTrainer): | ||||||
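|  |     # Sketch of the intent: paired with NYT10StylishDataset, the entity spans carry | ||||||
|  |     # IGNORE_INDEX, so the loss below only covers the surrounding JSON template tokens | ||||||
|  |     # ("stylistic" supervision). | ||||||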
|  |     def compute_loss(self, model, inputs, return_outputs=False): | ||||||
|  |         if self.label_smoother is not None and "labels" in inputs: | ||||||
|  |             labels = inputs.pop("labels") | ||||||
|  |         else: | ||||||
|  |             labels = None | ||||||
|  |  | ||||||
|  |         outputs = model(**inputs) | ||||||
|  |         if self.args.past_index >= 0: | ||||||
|  |             self._past = outputs[self.args.past_index] | ||||||
|  |          | ||||||
|  |         if labels is not None: | ||||||
|  |             # FIXME: should support peft | ||||||
|  |             model_name = unwrap_model(model)._get_name() | ||||||
|  |  | ||||||
|  |             if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): | ||||||
|  |                 # shift_labels=True lets the smoother align logits with next-token labels | ||||||
|  |                 # while respecting the IGNORE_INDEX masking applied by the dataset. | ||||||
|  |                 loss = self.label_smoother(outputs, labels, shift_labels=True) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f"model {model_name} is not a causal LM") | ||||||
|  |         else: | ||||||
|  |             raise ValueError("labels should not be None") | ||||||
|  |  | ||||||
|  |         return (loss, outputs) if return_outputs else loss | ||||||
							
								
								
									
154  train/trainers/utils.py  Normal file
							| @ -0,0 +1,154 @@ | |||||||
|  | import sys | ||||||
|  | import os | ||||||
|  | import warnings | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.utils.checkpoint as checkpoint | ||||||
|  | from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable | ||||||
|  | from torch.distributed.fsdp import _state_dict_utils | ||||||
|  | from torch.distributed.fsdp._common_utils import clean_tensor_name | ||||||
|  | from transformers.utils import ExplicitEnum | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ExtendedFSDPOption(ExplicitEnum): | ||||||
|  |     FULL_SHARD = "full_shard" | ||||||
|  |     SHARD_GRAD_OP = "shard_grad_op" | ||||||
|  |     NO_SHARD = "no_shard" | ||||||
|  |      | ||||||
|  |     # extension starts here | ||||||
|  |     HYBRID_SHARD = "hybrid_shard" | ||||||
|  |     _HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2" | ||||||
|  |     # extension ends here | ||||||
|  |      | ||||||
|  |     OFFLOAD = "offload" | ||||||
|  |     AUTO_WRAP = "auto_wrap" | ||||||
|  |      | ||||||
|  |      | ||||||
|  | DefaultCheckpointFunction = checkpoint.CheckpointFunction | ||||||
|  | DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook | ||||||
|  |      | ||||||
|  |  | ||||||
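|  | # A trimmed-down copy of torch.utils.checkpoint.CheckpointFunction that skips stashing and | ||||||
|  | # restoring the CUDA RNG state in backward (see the commented-out rng_devices block below); | ||||||
|  | # per the helpers at the end of this file, this is what lets activation checkpointing | ||||||
|  | # coexist with `torch.compile`. | ||||||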
|  | class CompatibleCheckpointFunction(torch.autograd.Function): | ||||||
|  |      | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, run_function, preserve_rng_state, *args): | ||||||
|  |         check_backward_validity(args) | ||||||
|  |         ctx.run_function = run_function | ||||||
|  |         ctx.preserve_rng_state = preserve_rng_state | ||||||
|  |         # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. | ||||||
|  |         ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() | ||||||
|  |  | ||||||
|  |         # Save non-tensor inputs in ctx, keep a placeholder None for tensors | ||||||
|  |         # to be filled out during the backward. | ||||||
|  |         ctx.inputs = [] | ||||||
|  |         ctx.tensor_indices = [] | ||||||
|  |         tensor_inputs = [] | ||||||
|  |         for i, arg in enumerate(args): | ||||||
|  |             if torch.is_tensor(arg): | ||||||
|  |                 tensor_inputs.append(arg) | ||||||
|  |                 ctx.tensor_indices.append(i) | ||||||
|  |                 ctx.inputs.append(None) | ||||||
|  |             else: | ||||||
|  |                 ctx.inputs.append(arg) | ||||||
|  |  | ||||||
|  |         ctx.save_for_backward(*tensor_inputs) | ||||||
|  |  | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             outputs = run_function(*args) | ||||||
|  |         return outputs | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, *args): | ||||||
|  |         if not torch.autograd._is_checkpoint_valid(): | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 "Checkpointing is not compatible with .grad() or when an `inputs` parameter" | ||||||
|  |                 " is passed to .backward(). Please use .backward() and do not pass its `inputs`" | ||||||
|  |                 " argument.") | ||||||
|  |         # Copy the list to avoid modifying original list. | ||||||
|  |         inputs = list(ctx.inputs) | ||||||
|  |         tensor_indices = ctx.tensor_indices | ||||||
|  |         tensors = ctx.saved_tensors | ||||||
|  |  | ||||||
|  |         # Fill in inputs with appropriate saved tensors. | ||||||
|  |         for i, idx in enumerate(tensor_indices): | ||||||
|  |             inputs[idx] = tensors[i] | ||||||
|  |  | ||||||
|  |         # Stash the surrounding rng state, and mimic the state that was | ||||||
|  |         # present at this time during forward.  Restore the surrounding state | ||||||
|  |         # when we're done. | ||||||
|  |         rng_devices = [] | ||||||
|  |         # if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: | ||||||
|  |         #     rng_devices = ctx.fwd_gpu_devices | ||||||
|  |         with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): | ||||||
|  |             detached_inputs = detach_variable(tuple(inputs)) | ||||||
|  |             with torch.enable_grad(), \ | ||||||
|  |                  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ | ||||||
|  |                  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): | ||||||
|  |                 outputs = ctx.run_function(*detached_inputs) | ||||||
|  |  | ||||||
|  |         if isinstance(outputs, torch.Tensor): | ||||||
|  |             outputs = (outputs,) | ||||||
|  |  | ||||||
|  |         # run backward() with only tensor that requires grad | ||||||
|  |         outputs_with_grad = [] | ||||||
|  |         args_with_grad = [] | ||||||
|  |         for i in range(len(outputs)): | ||||||
|  |             if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: | ||||||
|  |                 outputs_with_grad.append(outputs[i]) | ||||||
|  |                 args_with_grad.append(args[i]) | ||||||
|  |         if len(outputs_with_grad) == 0: | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 "none of output has requires_grad=True," | ||||||
|  |                 " this checkpoint() is not necessary") | ||||||
|  |         torch.autograd.backward(outputs_with_grad, args_with_grad) | ||||||
|  |         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None | ||||||
|  |                       for inp in detached_inputs) | ||||||
|  |  | ||||||
|  |         return (None, None) + grads | ||||||
|  |      | ||||||
|  |      | ||||||
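|  | # A replacement for FSDP's default `_full_post_state_dict_hook`: `param_hook` moves each | ||||||
|  | # unsharded tensor to CPU before cloning, so gathering a full state_dict does not keep a | ||||||
|  | # second copy of every parameter on the GPU. | ||||||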
|  | def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix): | ||||||
|  |      | ||||||
|  |     def param_hook(state_dict, prefix, fqn): | ||||||
|  |         clean_key = fqn | ||||||
|  |         clean_prefix = clean_tensor_name(prefix) | ||||||
|  |         # Strip prefix out of key if needed as buffer names and param names | ||||||
|  |         # do not have prefix considered as they are not computed in `state_dict` | ||||||
|  |         # call. | ||||||
|  |         if clean_key.startswith(clean_prefix): | ||||||
|  |             clean_key = clean_key[len(clean_prefix) :] | ||||||
|  |  | ||||||
|  |         # Clone parameters before exiting the `_unshard_fsdp_state_params()` context. | ||||||
|  |         if not getattr(state_dict[fqn], "_has_been_cloned", False): | ||||||
|  |             try: | ||||||
|  |                 state_dict[fqn] = state_dict[fqn].cpu().clone().detach() | ||||||
|  |                 state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined] | ||||||
|  |             except BaseException as e: | ||||||
|  |                 warnings.warn( | ||||||
|  |                     f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " | ||||||
|  |                     "This may mean that this state_dict entry could point to invalid " | ||||||
|  |                     "memory regions after returning from state_dict() call if this " | ||||||
|  |                     "parameter is managed by FSDP. Please check clone " | ||||||
|  |                     f"implementation of {fqn}. Error: {str(e)}" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     return _state_dict_utils._common_unshard_post_state_dict_hook( | ||||||
|  |         module, fsdp_state, state_dict, prefix, param_hook | ||||||
|  |     ) | ||||||
|  |      | ||||||
|  |  | ||||||
|  | # Enable memory-efficient saving of the full `state_dict` under FSDP | ||||||
|  | def enable_low_gpu_full_post_state_dict_hook(): | ||||||
|  |     _state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook | ||||||
|  |      | ||||||
|  | def disable_low_gpu_full_post_state_dict_hook(): | ||||||
|  |     _state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook | ||||||
|  |      | ||||||
|  |  | ||||||
|  | # Swap in the checkpoint function above so activation checkpointing works with `torch.compile` | ||||||
|  | def enable_compatible_checkpoint_function(): | ||||||
|  |     checkpoint.CheckpointFunction = CompatibleCheckpointFunction | ||||||
|  |      | ||||||
|  | def disable_compatible_checkpoint_function(): | ||||||
|  |     checkpoint.CheckpointFunction = DefaultCheckpointFunction | ||||||
|  |      | ||||||
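|  | # Example usage (a sketch; `trainer` and `output_dir` are assumed to exist elsewhere): | ||||||
|  | # | ||||||
|  | #   enable_low_gpu_full_post_state_dict_hook() | ||||||
|  | #   trainer.save_model(output_dir)            # full state_dict is gathered on CPU | ||||||
|  | #   disable_low_gpu_full_post_state_dict_hook() | ||||||
|  | # | ||||||
|  | #   enable_compatible_checkpoint_function()   # call before torch.compile(model) | ||||||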
							
								
								
									
0  train/utils/__init__.py  Normal file
0  train/utils/datasets/__init__.py  Normal file
							
								
								
									
141  train/utils/datasets/nyt10_dataset.py  Normal file
							| @ -0,0 +1,141 @@ | |||||||
|  | import copy | ||||||
|  | import json | ||||||
|  | import random | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from torch.utils.data import Dataset | ||||||
|  |  | ||||||
|  |  | ||||||
|  | IGNORE_INDEX = -100  # the default ignore_index of torch.nn.CrossEntropyLoss | ||||||
|  |  | ||||||
|  | prompt_template = ( | ||||||
|  |     "Below is an instruction that describes a task, paired with an input that provides further context. " | ||||||
|  |     "Write a response that appropriately completes the request.\n\n" | ||||||
|  |     "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:" | ||||||
|  | ) | ||||||
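|  | # Alpaca-style instruction prompt: {text} is filled with an NYT10 sentence and the model is | ||||||
|  | # expected to answer with the JSON object assembled from output_template below. | ||||||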
|  |  | ||||||
|  | # output_template = [ | ||||||
|  | #     "```json\n", | ||||||
|  | #     "{\n", | ||||||
|  | #     "  \"person\": \"", | ||||||
|  | #     "sample_person", | ||||||
|  | #     "\",\n", | ||||||
|  | #     "  \"nationality\": \"", | ||||||
|  | #     "sample_nationality", | ||||||
|  | #     "\"\n", | ||||||
|  | #     "}\n", | ||||||
|  | #     "```", | ||||||
|  | # ] | ||||||
|  |  | ||||||
|  | # person_index = 3 | ||||||
|  | # nationality_index = 6 | ||||||
|  |  | ||||||
|  | output_template = [ | ||||||
|  |     "```json\n{\n  \"person\": \"", | ||||||
|  |     "sample_person", | ||||||
|  |     "\",\n  \"nationality\": \"", | ||||||
|  |     "sample_nationality", | ||||||
|  |     "\"\n}\n```", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | person_index = 1 | ||||||
|  | nationality_index = 3 | ||||||
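|  | # person_index / nationality_index point at the slots in output_template that are | ||||||
|  | # overwritten with the gold entities for each example. | ||||||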
|  |  | ||||||
|  |  | ||||||
|  | class NYT10Dataset(Dataset): | ||||||
|  |     def __init__(self, data_path: str, tokenizer, size: int = -1): | ||||||
|  |         with open(data_path, 'r') as f: | ||||||
|  |             self.ann = [json.loads(line) for line in f.readlines()] | ||||||
|  |         # only use "/people/person/nationality" | ||||||
|  |         self.ann = [ | ||||||
|  |             { | ||||||
|  |                 "text": dp["text"], | ||||||
|  |                 "person": dp["h"]["name"], | ||||||
|  |                 "nationality": dp["t"]["name"], | ||||||
|  |             } for dp in self.ann if '/people/person/nationality' in dp['relation'] | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         random.shuffle(self.ann) | ||||||
|  |  | ||||||
|  |         if size > 0: | ||||||
|  |             self.ann = self.ann[:size] | ||||||
|  |  | ||||||
|  |         self.tokenizer = tokenizer | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.ann) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class NYT10FullDataset(NYT10Dataset): | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         global prompt_template, output_template, IGNORE_INDEX | ||||||
|  |  | ||||||
|  |         ann = self.ann[index] | ||||||
|  |         prompt = prompt_template.format(text=ann["text"]) | ||||||
|  |         output = copy.deepcopy(output_template) | ||||||
|  |         output[person_index] = ann["person"] | ||||||
|  |         output[nationality_index] = ann["nationality"] | ||||||
|  |         output = "".join(output) | ||||||
|  |         prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) | ||||||
|  |         output_ids = [self.tokenizer.bos_token_id] + self.tokenizer.encode(output, add_special_tokens=False) + [self.tokenizer.eos_token_id] | ||||||
|  |         example = torch.tensor( | ||||||
|  |             prompt_ids + output_ids, dtype=torch.int64 | ||||||
|  |         ) | ||||||
|  |         labels = copy.deepcopy(example) | ||||||
|  |         labels[:len(prompt_ids)] = -1 | ||||||
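|  |         # prompt tokens are marked with -1 here and mapped to IGNORE_INDEX below, | ||||||
|  |         # so only the JSON answer contributes to the loss | ||||||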
|  |         example_mask = example.ge(0) | ||||||
|  |         label_mask = labels.ge(0) | ||||||
|  |         example[~example_mask] = 0 | ||||||
|  |         labels[~label_mask] = IGNORE_INDEX | ||||||
|  |          | ||||||
|  |         assert len(example) == len(labels) | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |             "input_ids": example.tolist(), | ||||||
|  |             "labels": labels.tolist(), | ||||||
|  |             "attention_mask": example_mask.tolist(), | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class NYT10StylishDataset(NYT10Dataset): | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         global prompt_template, output_template, IGNORE_INDEX | ||||||
|  |  | ||||||
|  |         ann = self.ann[index] | ||||||
|  |         prompt = prompt_template.format(text=ann["text"]) | ||||||
|  |         prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) | ||||||
|  |  | ||||||
|  |         example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id] | ||||||
|  |         # prompt part is masked | ||||||
|  |         labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id] | ||||||
|  |  | ||||||
|  |         for idx, s in enumerate(output_template): | ||||||
|  |             # person and nationality are masked | ||||||
|  |             if idx == person_index or idx == nationality_index: | ||||||
|  |                 tokens = self.tokenizer.encode(ann["person"] if idx == person_index else ann["nationality"], add_special_tokens=False) | ||||||
|  |                 example.extend(tokens) | ||||||
|  |                 labels.extend([-1] * len(tokens)) | ||||||
|  |             else: | ||||||
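|  |                 # template tokens stay as targets: the model is trained on the JSON | ||||||
|  |                 # scaffolding, not on the entity strings themselves | ||||||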
|  |                 tokens = self.tokenizer.encode(s, add_special_tokens=False) | ||||||
|  |                 example.extend(tokens) | ||||||
|  |                 labels.extend(tokens) | ||||||
|  |         example.append(self.tokenizer.eos_token_id) | ||||||
|  |         example = torch.tensor( | ||||||
|  |             example, dtype=torch.int64 | ||||||
|  |         ) | ||||||
|  |         labels.append(self.tokenizer.eos_token_id) | ||||||
|  |         labels = torch.tensor( | ||||||
|  |             labels, dtype=torch.int64 | ||||||
|  |         ) | ||||||
|  |         example_mask = example.ge(0) | ||||||
|  |         label_mask = labels.ge(0) | ||||||
|  |         example[~example_mask] = 0 | ||||||
|  |         labels[~label_mask] = IGNORE_INDEX | ||||||
|  |  | ||||||
|  |         assert len(example) == len(labels) | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |             "input_ids": example.tolist(), | ||||||
|  |             "labels": labels.tolist(), | ||||||
|  |             "attention_mask": example_mask.tolist(), | ||||||
|  |         } | ||||||
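|  |  | ||||||
|  |  | ||||||
|  | # Example usage (a sketch; the tokenizer and the data path are placeholders): | ||||||
|  | # | ||||||
|  | #   tok = AutoTokenizer.from_pretrained(MODEL_NAME)  # any LLaMA tokenizer | ||||||
|  | #   full = NYT10FullDataset("data/nyt10/nyt10_train.txt", tok) | ||||||
|  | #   stylish = NYT10StylishDataset("data/nyt10/nyt10_train.txt", tok, size=1000) | ||||||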