commit bc182d09e050d4c1e6729163ee45a75b6a37156a
Author: arslantu
Date:   Sat Mar 9 10:55:34 2024 +0800

    init🎉:

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2de5599
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__/
+ckpts/
+data/
+outputs/
+.vscode/
\ No newline at end of file
diff --git a/llama/__init__.py b/llama/__init__.py
new file mode 100644
index 0000000..a5e05d3
--- /dev/null
+++ b/llama/__init__.py
@@ -0,0 +1 @@
+from .rellama import Method_1
\ No newline at end of file
diff --git a/llama/rellama.py b/llama/rellama.py
new file mode 100644
index 0000000..5ac0ca0
--- /dev/null
+++ b/llama/rellama.py
@@ -0,0 +1,165 @@
+import torch
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from typing import Optional, Tuple, Union, List
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from transformers.models.llama.modeling_llama import (
+    LLAMA_INPUTS_DOCSTRING,
+    LlamaForCausalLM,
+    _CONFIG_FOR_DOC
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class ReLlamaForCausalLM(LlamaForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.alpha = 1  # weight of the KL regularization term in the total loss
+        self.backup_model = None  # frozen reference copy of the base model, loaded lazily
+        self.backup_model_dev_gap = 3  # device-index offset between this model and the backup model
+
+
+class Method_1(ReLlamaForCausalLM):
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        # lazily load the frozen backup (reference) model onto another GPU the first time forward() runs
+        if self.backup_model is None:
+            self.backup_model_device = "cuda:" + str(self.device.index + self.backup_model_dev_gap)
+            self.backup_model = LlamaForCausalLM.from_pretrained(
+                '/home/tushilong/hf/models/Llama-2-7b-hf', device_map=self.backup_model_device)
+            self.backup_model.eval()
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss_1 = loss_fct(shift_logits, shift_labels)
+
+
+            predict_logits: torch.Tensor = logits
+
+            backup_model_params = {
+                'input_ids': input_ids.clone().to(self.backup_model.device),
+                'attention_mask': attention_mask.clone().to(self.backup_model.device),
+                # 'labels': labels.clone().to(self.backup_model_device),  # if labels is not None else labels,
+            }
+            with torch.no_grad():
+                target_logits = self.backup_model(**backup_model_params).logits.to(predict_logits.device)
+                # target_logits = self.backup_model(**backup_model_params).logits.detach().clone().to(predict_logits.device)
+
+            assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"
+
+            batch_total_loss_2 = 0
+            batch_total_len = 0
+            for i in range(predict_logits.size(0)):
+                # iterate over the batch
+
+                start_idx = torch.where(labels[i] == 1)[0].item()  # position of the BOS token (id 1) that marks the response start
+
+                maintain_position: List[int] = []  # masked (-100) response positions; the KL term applies only there
+                for idx in range(start_idx, labels[i].size(0)):
+                    if labels[i][idx] == -100:
+                        maintain_position.append(idx)
+
+                # FIXME: should this sample rather be skipped with `continue`?
+                assert len(maintain_position) > 0
+
+                maintain_position: torch.Tensor = torch.tensor(maintain_position, requires_grad=False).to(predict_logits.device)
+                cur_predict_logits = predict_logits[i][maintain_position].contiguous()
+                cur_target_logits = target_logits[i][maintain_position].contiguous()
+
+                cur_loss_2 = F.kl_div(F.log_softmax(cur_predict_logits, dim=-1), F.softmax(cur_target_logits, dim=-1), reduction='sum')
+                batch_total_loss_2 += cur_loss_2
+                batch_total_len += cur_predict_logits.size(0)
+
+            loss_2 = batch_total_loss_2 / batch_total_len
+            loss = loss_1 + self.alpha * loss_2
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
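In short, `Method_1` optimizes `loss_1 + alpha * loss_2`: the usual shifted cross-entropy on supervised tokens, plus a KL-divergence term (`F.kl_div(log_softmax(student), softmax(reference))`) that keeps the fine-tuned distribution close to the frozen base model on exactly those response positions whose labels are masked with -100 (token id 1 is Llama's BOS, which this data format uses to mark where the response begins). A standalone paraphrase of that computation, for reference; the function name and toy shapes are mine, not part of the commit:

```python
import torch
import torch.nn.functional as F

def method_1_loss(predict_logits, target_logits, labels, alpha=1.0):
    # (1) shifted cross-entropy over supervised tokens; -100 positions are ignored
    ce = F.cross_entropy(
        predict_logits[:, :-1].reshape(-1, predict_logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    # (2) KL to the frozen reference, token-averaged over the masked response positions
    total_kl, n_tokens = 0.0, 0
    for i in range(labels.size(0)):
        start = torch.where(labels[i] == 1)[0].item()           # BOS marking the response start
        keep = (labels[i, start:] == -100).nonzero().squeeze(-1) + start
        log_p = F.log_softmax(predict_logits[i, keep], dim=-1)  # fine-tuned model
        q = F.softmax(target_logits[i, keep], dim=-1)           # frozen reference
        total_kl = total_kl + F.kl_div(log_p, q, reduction='sum')
        n_tokens += keep.numel()
    return ce + alpha * total_kl / max(n_tokens, 1)
```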
diff --git a/realign/__init__.py b/realign/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/realign/eval_acc.py b/realign/eval_acc.py
new file mode 100644
index 0000000..b56fc91
--- /dev/null
+++ b/realign/eval_acc.py
@@ -0,0 +1,81 @@
+import json
+from typing import Dict
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from tqdm import tqdm
+import logging
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+prompt_template = (
+    "Below is an instruction that describes a task, paired with an input that provides further context. "
+    "Write a response that appropriately completes the request.\n\n"
+    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
+)
+
+
+def prepare_data(data_path: str):
+    data = [json.loads(line) for line in open(data_path, 'r').readlines()]
+    data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
+    return data
+
+
+class LM:
+    def __init__(self, device: str, model_path: str, tokenizer_path: str=None) -> None:
+        if tokenizer_path is None:
+            tokenizer_path = model_path
+        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+    def chat(self, prompt: str) -> str:
+        input_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids + [self.tokenizer.bos_token_id]  # the training format appends BOS after the prompt to mark the response start
+        input_ids = torch.tensor([input_ids], dtype=torch.int64).to(self.model.device)
+        generate_ids = self.model.generate(input_ids, max_new_tokens=1024)
+        response = self.tokenizer.decode(generate_ids[0][len(input_ids[0]):], skip_special_tokens=True)
+        return response
+
+
+def predict_with_lm(data: Dict, lm: LM) -> bool:
+    global prompt_template
+
+    text: str = data['text']
+    person: str = data['h']['name'].strip()
+    nationality: str = data['t']['name'].strip()
+
+    prompt = prompt_template.format(text=text)
+    response = lm.chat(prompt)
+    try:
+        extract_res = '\n'.join(response.split('\n')[1:-1])  # strip the ```json fences (first and last line)
+        extract_res = json.loads(extract_res)
+        return extract_res['person'].strip() == person and extract_res['nationality'].strip() == nationality
+    except Exception:
+        return False
+
+
+def eval_acc(data_path: str, model_path: str, device: str):
+    data = prepare_data(data_path)
+    lm = LM(device, model_path)
+    res = sum([predict_with_lm(d, lm) for d in tqdm(data)]) / len(data)
+    return res
+
+
+def main():
+    data_path = '../data/nyt10/nyt10_test.txt'
+    full_model_path = '../ckpts/full'
+    full_model_device = 'cuda:5'
+    stylish_model_path = '../ckpts/stylish'
+    stylish_model_device = 'cuda:6'
+
+    # full_model_res = eval_acc(data_path, full_model_path, full_model_device)
+    stylish_model_res = eval_acc(data_path, stylish_model_path, stylish_model_device)
+
+    # logger.info(f'Full model accuracy: {full_model_res}')
+    logger.info(f'Stylish model accuracy: {stylish_model_res}')
+
+
+if __name__ == "__main__":
+    main()
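A prediction only counts as correct here if the model reproduces the fenced-JSON answer format it was trained on (the same shape that eval_to_img.py below builds as its `groundtruth` string) and both fields match exactly. An illustration of what the `split('\n')[1:-1]` parsing expects; the names are made up:

```python
import json

# a well-formed response, as the fine-tuned model is expected to produce it
response = '```json\n{\n  "person": "Mwai Kibaki",\n  "nationality": "Kenya"\n}\n```'

payload = '\n'.join(response.split('\n')[1:-1])   # drop the ``` fences
print(json.loads(payload))                        # {'person': 'Mwai Kibaki', 'nationality': 'Kenya'}
```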
diff --git a/realign/eval_to_img.py b/realign/eval_to_img.py
new file mode 100644
index 0000000..1cad79f
--- /dev/null
+++ b/realign/eval_to_img.py
@@ -0,0 +1,132 @@
+import json
+from typing import Dict
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+from realign.utils.draw import create_image_from_list
+import random
+from PIL import Image
+
+
+def prepare_data(data_path: str):
+    data = [json.loads(line) for line in open(data_path, 'r').readlines()]
+    data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
+    return data
+
+def prepare_model(
+    base_model_path: str, base_model_device: str,
+    chat_model_path: str, chat_model_device: str,
+    stylish_model_path: str, stylish_model_device: str,
+):
+    # all three models share the base Llama tokenizer
+    base_model, base_tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, device_map=base_model_device), AutoTokenizer.from_pretrained(base_model_path)
+
+    chat_model, chat_tokenizer = AutoModelForCausalLM.from_pretrained(chat_model_path, device_map=chat_model_device), AutoTokenizer.from_pretrained(base_model_path)
+
+    stylish_model, stylish_tokenizer = AutoModelForCausalLM.from_pretrained(stylish_model_path, device_map=stylish_model_device), AutoTokenizer.from_pretrained(base_model_path)
+
+    base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
+    chat_tokenizer.pad_token_id = chat_tokenizer.eos_token_id
+    stylish_tokenizer.pad_token_id = stylish_tokenizer.eos_token_id
+
+    return base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer
+
+
+def eval_shift(data: Dict, base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer):
+    text: str = data['text']
+    person: str = data['h']['name']
+    nationality: str = data['t']['name']
+
+    groundtruth: str = (
+        "```json\n"
+        "{\n"
+        f"  \"person\": \"{person}\",\n"
+        f"  \"nationality\": \"{nationality}\"\n"
+        "}\n"
+        "```"
+    )
+
+    groundtruth_ids = [base_tokenizer.bos_token_id] + base_tokenizer.encode(groundtruth, add_special_tokens=False) + [base_tokenizer.eos_token_id]
+    groundtruth_ids = torch.tensor(
+        [groundtruth_ids], dtype=torch.int64
+    )
+
+    # 1. chat model generate
+    prompt_template = (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
+    )
+    prompt = prompt_template.format(text=text)
+
+    def chat_model_generate():
+        input_ids = chat_tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(chat_model.device)
+        generate_ids = chat_model.generate(input_ids, max_new_tokens=1024)
+        # response = chat_tokenizer.decode(generate_ids[0][len(input_ids[0]):], skip_special_tokens=True)
+        # return response
+        return [generate_ids[0][len(input_ids[0]):].tolist()]
+
+    # O: str = chat_model_generate()
+    generate_output_ids = torch.tensor(
+        chat_model_generate(), dtype=torch.int64
+    )
+
+    def get_rank(model, tokenizer, output_ids):
+        prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False)
+        # output_ids = tokenizer.encode(O, return_tensors='pt')
+        input_ids = torch.cat([prompt_ids, output_ids], dim=-1)
+        input_ids = input_ids.to(model.device)
+        with torch.no_grad():
+            logits = model(input_ids, labels=input_ids)['logits']
+        probabilities = torch.nn.functional.softmax(logits, dim=-1)
+        # logits at position i predict token i + 1, hence the start offset of len(prompt_ids[0]) - 1
+        ranks = []
+        for i, token_id in enumerate(input_ids[0][len(prompt_ids[0]):], start=len(prompt_ids[0])-1):
+            word = tokenizer.decode(token_id)
+            prob = probabilities[0, i, token_id].item()
+
+            # rank of token_id = number of tokens with strictly higher probability (0 means argmax)
+            rank = (probabilities[0, i] > prob).sum().item()
+            ranks.append((word, rank))
+        return ranks
+
+    return {
+        "prompt": [(prompt, -1)],
+        "base_model_generate": [("base_model_generate", -1)] + get_rank(base_model, base_tokenizer, generate_output_ids),
+        "stylish_model_generate": [("stylish_model_generate", -1)] + get_rank(stylish_model, stylish_tokenizer, generate_output_ids),
+        "base_model_groundtruth": [("base_model_groundtruth", -1)] + get_rank(base_model, base_tokenizer, groundtruth_ids),
+        "stylish_model_groundtruth": [("stylish_model_groundtruth", -1)] + get_rank(stylish_model, stylish_tokenizer, groundtruth_ids),
+    }
+
+
+
+def main():
+    data_path = 'data/nyt10/nyt10_test.txt'
+    base_model_path = '/home/tushilong/hf/models/Llama-2-7b-hf'
+    base_model_device = 'cuda:0'
+    chat_model_path = 'ckpts/full'
+    chat_model_device = 'cuda:1'
+    stylish_model_path = 'ckpts/stylish'
+    stylish_model_device = 'cuda:6'
+
+    data = random.sample(prepare_data(data_path), 20)
+    base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer = prepare_model(
+        base_model_path, base_model_device, chat_model_path, chat_model_device, stylish_model_path, stylish_model_device
+    )
+    for idx, d in enumerate(data):
+        results = eval_shift(d, base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer)
+
+        images = {key: create_image_from_list(val) for key, val in results.items()}
+
+        combine_width = max([img.width for img in images.values()])
+        combine_height = sum([img.height for img in images.values()])
+        combine_image = Image.new('RGB', (combine_width, combine_height), (255, 255, 255))
+        y = 0
+        for key in ["prompt", "base_model_generate", "stylish_model_generate", "base_model_groundtruth", "stylish_model_groundtruth"]:
+            img = images[key]
+            combine_image.paste(img, (0, y))
+            y += img.height
+        combine_image.save(f'outputs/{idx}.png')
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
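eval_shift renders, per sample, how "surprised" each model is by a forced continuation: every token of the chat model's output (and of the groundtruth) is annotated with its rank under the base and the stylish model, and draw.py further below colors ranks 0 / 1-2 / 3+ blue / brown / red. The only subtle part is the index offset, restated here on a toy HF-style model (the function name is mine):

```python
import torch

def token_ranks(model, prompt_ids, output_ids):
    """Rank of each forced output token (0 = the model's top choice)."""
    input_ids = torch.cat([prompt_ids, output_ids], dim=-1)
    with torch.no_grad():
        logits = model(input_ids).logits            # (1, seq_len, vocab)
    # logits[:, i] scores token i + 1, so output_ids[0][j] is predicted
    # at position len(prompt_ids[0]) - 1 + j
    start = prompt_ids.size(-1) - 1
    probs = torch.softmax(logits[0, start:start + output_ids.size(-1)], dim=-1)
    chosen = probs.gather(-1, output_ids[0].unsqueeze(-1))
    return (probs > chosen).sum(-1).tolist()       # count of strictly likelier tokens
```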
diff --git a/realign/run_log.txt b/realign/run_log.txt
new file mode 100644
index 0000000..2ccfb7c
--- /dev/null
+++ b/realign/run_log.txt
@@ -0,0 +1,53 @@
+WARNING:torch.distributed.run:
+*****************************************
+Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+*****************************************
+/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
+/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
+/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
+ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 2) local_rank: 0 (pid: 44359) of binary: /home/tushilong/anaconda3/envs/realign/bin/python
+Traceback (most recent call last):
+  File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module>
+    sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
+             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
+    return f(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
+    run(args)
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
+    elastic_launch(
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
+    return launch_agent(self._config, self._entrypoint, list(args))
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
+    raise ChildFailedError(
+torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
+============================================================
+train.py FAILED
+------------------------------------------------------------
+Failures:
+[1]:
+  time      : 2024-03-09_02:08:09
+  host      : ubuntu
+  rank      : 1 (local_rank: 1)
+  exitcode  : 2 (pid: 44360)
+  error_file: <N/A>
+  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+[2]:
+  time      : 2024-03-09_02:08:09
+  host      : ubuntu
+  rank      : 2 (local_rank: 2)
+  exitcode  : 2 (pid: 44361)
+  error_file: <N/A>
+  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+------------------------------------------------------------
+Root Cause (first observed failure):
+[0]:
+  time      : 2024-03-09_02:08:09
+  host      : ubuntu
+  rank      : 0 (local_rank: 0)
+  exitcode  : 2 (pid: 44359)
+  error_file: <N/A>
+  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+============================================================
diff --git a/realign/utils/__init__.py b/realign/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/realign/utils/draw.py b/realign/utils/draw.py
new file mode 100644
index 0000000..80b07a8
--- /dev/null
+++ b/realign/utils/draw.py
@@ -0,0 +1,57 @@
+from math import ceil
+from PIL import Image, ImageDraw, ImageFont
+
+def create_image_from_list(data, font_size=16, max_width=800):
+    # compute the width and height of the image
+    full_text = ''.join([item[0] for item in data])
+    font = ImageFont.truetype('/usr/local/share/fonts/ttf/dotfiles/MesloLGS/MesloLGS NF Regular.ttf', size=font_size)
+    total_text_width = ceil(font.getlength(full_text))
+    number_of_lines = (total_text_width // max_width) + 1 + 1  # one more line for \n of query
+    text_width = max_width if total_text_width > max_width else total_text_width
+    single_height = sum(abs(x) for x in font.getmetrics())
+    text_height = (single_height + 10) * (number_of_lines + 1)
+    image_width = text_width + 20
+    image_height = text_height + 20
+
+    # create a new blank image
+    image = Image.new('RGB', (image_width, image_height), color='white')
+    draw = ImageDraw.Draw(image)
+
+    x = 10
+    y = 10
+    for word in data[0][0].split():
+        word_width = font.getlength(word)
+        if x + word_width > text_width:
+            x = 10
+            y += single_height + 5
+        draw.text((x, y), word, font=font, fill='black')
+        x += word_width + font.getlength(' ')
+
+    x = 10
+    y += single_height + 5
+    # iterate over the (text, rank) tuples in the list
+    for _, item in enumerate(data[1:]):
+        text: str = item[0]  # the token text
+        color = 'blue' if item[1] == 0 else 'brown' if item[1] < 3 else 'red'
+        for word in text.split():
+            word_width = font.getlength(word)
+            if x + word_width > text_width:
+                x = 10
+                y += single_height + 5
+            draw.text((x, y), word, font=font, fill=color)
+            x += word_width + font.getlength(' ')
+
+    return image
+
+
+if __name__ == "__main__":
+    # sample data
+    data = [("Start", -1), ('Hello', 2), ('World\n', 4), ('Python', 1)]
+
+    data = data + data[1:] * 100
+
+    # create the image
+    image = create_image_from_list(data)
+
+    # save the image
+    image.save('output.png')
diff --git a/realign/utils/model.py b/realign/utils/model.py
new file mode 100644
index 0000000..314a7b3
--- /dev/null
+++ b/realign/utils/model.py
@@ -0,0 +1,26 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+from qwen_model.modeling_qwen import QWenModel, QWenLMHeadModel
+
+def load_model(model_path: str, generation_config_path: str=None, tokenizer_path: str=None, device_map: str="auto"):
+    if tokenizer_path is None:
+        tokenizer_path = model_path
+
+    if generation_config_path is None:
+        generation_config_path = model_path
+
+    model_cls = QWenLMHeadModel  # if 'chat' in model_path.lower() else QWenModel
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
+    tokenizer.pad_token = tokenizer.eos_token = '<|endoftext|>'
+
+    generation_config = GenerationConfig.from_pretrained(generation_config_path, trust_remote_code=True)
+
+    generation_config.max_new_tokens = 1024
+
+    model = model_cls.from_pretrained(model_path, device_map=device_map, trust_remote_code=True, bf16=True)
+    model.generation_config = generation_config
+    model.eval()
+    print(model.generation_config)
+    print(tokenizer)
+    return model, tokenizer
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5123c5a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+transformers==4.32.0
+datasets==2.14.6
+accelerate==0.24.0
+tiktoken==0.5.1
+einops==0.7.0
+transformers_stream_generator==0.0.4
+scipy==1.11.3
+fairscale
+sentencepiece
+fire
+
+numba
+openpyxl
+pandas==2.1.1
\ No newline at end of file
diff --git a/test_llama.py b/test_llama.py
new file mode 100644
index 0000000..9ad4583
--- /dev/null
+++ b/test_llama.py
@@ -0,0 +1,223 @@
+import copy
+from typing import Dict
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+from train.utils.datasets.nyt10_dataset import 
NYT10StylishDataset +from llama.rellama import Method_1 + +model_path = "/home/tushilong/hf/models/Llama-2-7b-hf" +device = "cuda:1" +tokenizer_path = model_path + + +model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device) +tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) +tokenizer.pad_token_id = tokenizer.eos_token_id +model.eval() + +input_ids = torch.tensor([ + [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 13866, 338, + 385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, + 385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, + 393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, + 2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, + 310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, + 1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, + 278, 2022, 322, 607, 338, 278, 4797, 537, 29889, 450, + 1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277, + 29937, 10567, 29901, 13, 1576, 21489, 8063, 1919, 607, 338, + 5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435, + 10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262, + 7912, 1919, 338, 3806, 304, 5957, 263, 883, 333, 519, + 18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557, + 423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1, + 7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376, + 15864, 290, 381, 435, 10312, 9162, 13, 259, 376, 29876, + 1288, 537, 1115, 376, 21489, 8063, 376, 13, 500, 13, + 7521, 2], + [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 13866, 338, 385, + 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, + 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, + 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, + 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, + 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, + 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, + 2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234, + 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, + 10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948, + 3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278, + 27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880, + 1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919, + 29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953, + 29899, 29941, 869, 13, 13, 2277, 29937, 13291, 29901, 1, + 7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376, + 1913, 29948, 3197, 3219, 1973, 4346, 9162, 13, 259, 376, + 29876, 1288, 537, 1115, 376, 3444, 376, 13, 500, 13, + 7521, 2], + [ 2, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29892, + 3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889, + 14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009, + 29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954, + 5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278, + 2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948, + 592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797, + 537, 29889, 450, 1234, 881, 367, 297, 4390, 3402, 29889, + 13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276, + 451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679, + 885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783, + 1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612, + 1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900, + 525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048, + 2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135, + 297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291, + 29901, 1, 7521, 
3126, 13, 426, 13, 259, 376, 10532, + 1115, 376, 17374, 624, 5876, 1358, 9162, 13, 259, 376, + 29876, 1288, 537, 1115, 376, 17362, 376, 13, 500, 13, + 7521, 2], + [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 13866, 338, 385, 15278, + 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881, + 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128, + 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937, + 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426, + 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537, + 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022, + 322, 607, 338, 278, 4797, 537, 29889, 450, 1234, 881, + 367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567, + 29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278, + 1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783, + 6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408, + 29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879, + 360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879, + 6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291, + 29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532, + 1115, 376, 19317, 317, 14107, 1049, 9162, 13, 259, 376, + 29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13, + 7521, 2] + ], device="cuda:1" +) + +labels = torch.tensor([ + [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, -100, 338, + 385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, + 385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, + 393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, + 2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, + 310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, + 1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, + 278, 2022, 322, 607, 338, -100, 4797, 537, 29889, 450, + 1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277, + 29937, 10567, 29901, 13, 1576, -100, 8063, 1919, 607, 338, + 5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435, + 10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262, + 7912, 1919, 338, -100, 304, 5957, 263, 883, 333, 519, + 18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557, + 423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1, + 7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376, + 15864, 290, 381, 435, 10312, -100, 13, 259, 376, 29876, + 1288, -100, 1115, -100, 21489, 8063, 376, 13, 500, 13, + 7521, 2], + [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, -100, 338, 385, + 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, + 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, + 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, + 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, + 1426, 29892, 3113, 1284, 714, -100, 2022, 29899, 29876, 1288, + 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, + 2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234, + 881, 367, -100, 4390, 3402, 29889, 13, 13, 2277, 29937, + 10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948, + 3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278, + 27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880, + 1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919, + 29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953, + 29899, 29941, 869, 13, -100, 2277, 29937, 13291, 29901, 1, + 7521, -100, 13, 426, 13, 259, 376, 10532, 1115, 376, + 1913, 29948, 3197, 3219, 1973, -100, 9162, 13, 259, 376, + 29876, 1288, 537, 1115, -100, 3444, 376, 13, 500, 13, + 7521, 2], + [ 2, -100, 338, 385, 15278, 393, 16612, 263, 3414, 29892, + 3300, 
2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889, + 14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009, + 29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954, + 5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278, + 2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948, + 592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797, + 537, 29889, 450, -100, 881, 367, 297, 4390, 3402, 29889, + 13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276, + 451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679, + 885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783, + 1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612, + 1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900, + 525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048, + 2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135, + 297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291, + 29901, 1, 7521, 3126, -100, 426, 13, 259, 376, 10532, + 1115, 376, 17374, 624, 5876, -100, 9162, 13, 259, 376, + 29876, 1288, 537, 1115, 376, -100, 376, 13, 500, 13, + 7521, 2], + [ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, -100, 338, 385, 15278, + 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881, + 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128, + 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937, + 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426, + 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537, + 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022, + 322, 607, 338, 278, 4797, -100, 29889, 450, 1234, 881, + 367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567, + 29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278, + 1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783, + 6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408, + 29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879, + 360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879, + 6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291, + 29901, 1, 7521, 3126, 13, -100, 13, 259, 376, 10532, + 1115, 376, 19317, 317, 14107, -100, 9162, 13, 259, 376, + 29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13, + 7521, 2] + ], device="cuda:1" +) + +attention_mask = torch.tensor( +[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device="cuda:1" +) + +outputs = model(input_ids=input_ids, attention_mask=attention_mask,)# labels=labels) +assert torch.isnan(outputs.logits).sum() == 0 +# print(outputs.loss) \ No newline at end of file diff --git a/test_prediction.py b/test_prediction.py new file mode 100644 index 0000000..d582435 --- /dev/null +++ b/test_prediction.py @@ -0,0 +1,24 @@ +import json +from realign.eval_acc import LM + + +device = 'cuda:1' +model_path = './ckpts/stylish' +data_path = './data/nyt10/nyt10_test.txt' + +prompt_template = ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. 
The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
+)
+
+
+data = [json.loads(line) for line in open(data_path, 'r').readlines()]
+data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
+
+lm = LM(device, model_path)
+
+for i in range(len(data[:10])):
+    prompt = prompt_template.format(text=data[i]['text'])
+    response = lm.chat(prompt)
+    print(response)
\ No newline at end of file
diff --git a/train/configs/__init__.py b/train/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/train/configs/finetune_arguments.py b/train/configs/finetune_arguments.py
new file mode 100644
index 0000000..d5a4f70
--- /dev/null
+++ b/train/configs/finetune_arguments.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass, field
+
+
+### define some configuration options
+@dataclass
+class FinetuneArguments:
+    model_name: str = field()
+    data_path: str = field()
+    train_size: int = field(default=-1)
+    test_size: int = field(default=100)
+    max_len: int = field(default=1024)
\ No newline at end of file
diff --git a/train/configs/fsdp/internlm_fsdp_config.json b/train/configs/fsdp/internlm_fsdp_config.json
new file mode 100644
index 0000000..7e3d4c0
--- /dev/null
+++ b/train/configs/fsdp/internlm_fsdp_config.json
@@ -0,0 +1,4 @@
+{
+    "fsdp_transformer_layer_cls_to_wrap": ["InternLMDecoderLayer"],
+    "limit_all_gathers": true
+}
\ No newline at end of file
diff --git a/train/configs/fsdp/llama2_fsdp_config.json b/train/configs/fsdp/llama2_fsdp_config.json
new file mode 100644
index 0000000..da31c4f
--- /dev/null
+++ b/train/configs/fsdp/llama2_fsdp_config.json
@@ -0,0 +1,4 @@
+{
+    "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
+    "limit_all_gathers": true
+}
\ No newline at end of file
diff --git a/train/configs/fsdp/qwen_fsdp_config.json b/train/configs/fsdp/qwen_fsdp_config.json
new file mode 100644
index 0000000..4540cfc
--- /dev/null
+++ b/train/configs/fsdp/qwen_fsdp_config.json
@@ -0,0 +1,4 @@
+{
+    "fsdp_transformer_layer_cls_to_wrap": ["QWenBlock"],
+    "limit_all_gathers": true
+}
\ No newline at end of file
diff --git a/train/configs/logger_config.py b/train/configs/logger_config.py
new file mode 100644
index 0000000..c6edb92
--- /dev/null
+++ b/train/configs/logger_config.py
@@ -0,0 +1,26 @@
+logger_config = {
+    'version': 1,
+    'formatters': {
+        'simple': {
+            'format': f"%(asctime)s %(name)s %(levelname)s: %(message)s",
+            'datefmt': '%Y-%m-%d %H:%M:%S',
+        },
+        # other formatters
+    },
+    'handlers': {
+        'console': {
+            'class': 'logging.StreamHandler',
+            'level': 'DEBUG',
+            'formatter': 'simple',
+        },
+        # other handlers
+    },
+    'loggers':{
+        # console-only output via the StreamLogger
+        'StreamLogger': {
+            'handlers': ['console'],
+            'level': 'DEBUG',
+        },
+        # other loggers
+    }
+}
\ No newline at end of file
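logger_config is a plain `logging` dictConfig mapping; presumably it is meant to be installed once at startup. A minimal usage sketch (not code from the commit; it assumes the repo root is on sys.path):

```python
import logging.config

from train.configs.logger_config import logger_config

logging.config.dictConfig(logger_config)
logger = logging.getLogger('StreamLogger')
logger.info('config loaded')  # -> 2024-03-09 10:55:34 StreamLogger INFO: config loaded
```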
diff --git a/train/run_log.txt b/train/run_log.txt
new file mode 100644
index 0000000..a2538c4
--- /dev/null
+++ b/train/run_log.txt
@@ -0,0 +1,136 @@
+WARNING:torch.distributed.run:
+*****************************************
+Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+*****************************************
+Training model with params:
+base_model: /home/tushilong/hf/models/Llama-2-7b-hf
+output_dir: ../ckpts/stylish
+micro_batch_size: 2
+gradient_accumulation_steps: 1
+train_batch_size: 2
+gradient_checkpointing: True
+num_epochs: 1
+learning_rate: 2e-05
+weight_decay: 0.0001
+warmup_ratio: 0.06
+deepspeed_config: None
+fsdp: shard_grad_op auto_wrap offload
+fsdp_config: ./configs/fsdp/llama2_fsdp_config.json
+smart_embedding: False
+wandb_project:
+wandb_run_name:
+wandb_watch:
+wandb_log_model:
+resume_from_checkpoint: False
+
+Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
+Traceback (most recent call last):
+  File "/home/tushilong/code/realign/train/train.py", line 167, in <module>
+    fire.Fire(train)
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 141, in Fire
+    component_trace = _Fire(component, args, parsed_flag_args, context, name)
+                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire
+    component, remaining_args = _CallAndUpdateTrace(
+                                ^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
+    component = fn(*varargs, **kwargs)
+                ^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/code/realign/train/train.py", line 160, in train
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
+    return inner_training_loop(
+           ^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop
+    tr_loss_step = self.training_step(model, inputs)
+                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step
+    loss = self.compute_loss(model, inputs)
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2707, in compute_loss
+    outputs = model(**inputs)
+              ^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
+    return forward_call(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward
+    return model_forward(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__
+    return convert_to_fp32(self.model_forward(*args, **kwargs))
+                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
+    return func(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward
+    return model_forward(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__
+    return convert_to_fp32(self.model_forward(*args, **kwargs))
+                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
+    return func(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
+    output = self._fsdp_wrapped_module(*args, **kwargs)
+             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
+    return forward_call(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/code/realign/llama/rellama.py", line 129, in forward
+    assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+AssertionError: target_logits has nan: 10752000
+WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20205 closing signal SIGTERM
+WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20206 closing signal SIGTERM
+ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 20207) of binary: /home/tushilong/anaconda3/envs/realign/bin/python
+Traceback (most recent call last):
+  File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module>
+    sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
+             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
+    return f(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
+    run(args)
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
+    elastic_launch(
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
+    return launch_agent(self._config, self._entrypoint, list(args))
+           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
+    raise ChildFailedError(
+torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
+============================================================
+train.py FAILED
+------------------------------------------------------------
+Failures:
+  <NO_OTHER_FAILURES>
+------------------------------------------------------------
+Root Cause (first observed failure):
+[0]:
+  time      : 2024-03-09_10:45:34
+  host      : ubuntu
+  rank      : 2 (local_rank: 2)
+  exitcode  : 1 (pid: 20207)
+  error_file: <N/A>
+  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+============================================================
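The run died at the `target_logits` NaN assert in rellama.py: 10752000 NaN entries divided by Llama-2's 32000-token vocabulary is 336 positions whose entire logit vector came back NaN from the backup model. A small diagnostic of the kind one could drop next to that assert (my own suggestion, not in the commit):

```python
import torch

def describe_nans(target_logits: torch.Tensor):
    """Break a raw NaN count down into per-position information."""
    nan_entries = torch.isnan(target_logits).sum().item()
    nan_positions = torch.isnan(target_logits).any(dim=-1)  # (batch, seq_len)
    print(f"{nan_entries} NaN logit entries "
          f"({nan_positions.sum().item()} of {nan_positions.numel()} positions affected)")
```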
diff --git a/train/train.py b/train/train.py
new file mode 100644
index 0000000..caf4d76
--- /dev/null
+++ b/train/train.py
@@ -0,0 +1,168 @@
+import random
+import sys
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6'
+
+import fire
+import torch
+torch.autograd.set_detect_anomaly(True)
+import transformers
+from transformers import set_seed
+set_seed(15)
+from utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset
+from trainers import FSDPTrainingArguments, FSDPTrainer
+from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoModelForCausalLM, LlamaForCausalLM
+from llama import Method_1
+
+
+def train(
+    # model/data params
+    base_model: str = '/home/tushilong/hf/models/Llama-2-7b-hf',
+    data_path: str = '../data/nyt10/nyt10_train.txt',
+    output_dir: str = '../ckpts/stylish',
+
+    # training hyperparams
+    do_train: bool = True,
+    micro_batch_size: int = 2,
+    gradient_accumulation_steps: int = 1,
+    gradient_checkpointing: bool = True,
+    num_epochs: int = 1,
+    save_steps: int = 500,
+    learning_rate: float = 2e-5,
+    lr_scheduler_type: str = 'cosine',
+    weight_decay: float = 1e-4,
+    warmup_ratio: float = 0.06,
+    deepspeed_config: str = None,
+    fsdp: str = 'shard_grad_op auto_wrap offload',
+    fsdp_config: str = './configs/fsdp/llama2_fsdp_config.json',
+    smart_embedding: bool = False,
+
+    # evaluating hyperparams
+    do_eval: bool = False,
+    val_set_size: int = 1000,
+    eval_batch_size: int = 4,
+
+    # wandb params
+    wandb_project: str = "",
+    wandb_run_name: str = "",
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
+    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
+):
+    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+        print(
+            f"Training model with params:\n"
+            f"base_model: {base_model}\n"
+            f"output_dir: {output_dir}\n"
+            f"micro_batch_size: {micro_batch_size}\n"
+            f"gradient_accumulation_steps: {gradient_accumulation_steps}\n"
+            f"train_batch_size: {micro_batch_size * gradient_accumulation_steps}\n"
+            f"gradient_checkpointing: {gradient_checkpointing}\n"
+            f"num_epochs: {num_epochs}\n"
+            f"learning_rate: {learning_rate}\n"
+            f"weight_decay: {weight_decay}\n"
+            f"warmup_ratio: {warmup_ratio}\n"
+            f"deepspeed_config: {deepspeed_config}\n"
+            f"fsdp: {fsdp}\n"
+            f"fsdp_config: {fsdp_config}\n"
+            f"smart_embedding: {smart_embedding}\n"
+            f"wandb_project: {wandb_project}\n"
+            f"wandb_run_name: {wandb_run_name}\n"
+            f"wandb_watch: {wandb_watch}\n"
+            f"wandb_log_model: {wandb_log_model}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
+        )
+    assert (
+        not (deepspeed_config and fsdp)
+    ), "Cannot specify both deepspeed_config and fsdp"
+
+    # training arguments
+    bf16 = True  # torch.cuda.get_device_capability()[0] >= 8
+    fp16 = not bf16
+
+    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    # model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)
+    # model_dev_id = int(os.environ.get("LOCAL_RANK", 0))
+
+    model = Method_1.from_pretrained(base_model)
+
+
+
+    # Check if parameter passed or if set within environ
+    # use_wandb = len(wandb_project) > 0 or (
+    #     "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    # )
+    use_wandb = False
+    # Only overwrite environ if wandb param passed
+    # if len(wandb_project) > 0:
+    #     os.environ["WANDB_PROJECT"] = wandb_project
+    # if len(wandb_watch) > 0:
+    #     os.environ["WANDB_WATCH"] = wandb_watch
+    # if len(wandb_log_model) > 0:
+    #     os.environ["WANDB_LOG_MODEL"] = wandb_log_model
+
+
+    # train_data = NYT10FullDataset(data_path, tokenizer)
+    # train_data = NYT10StylishDataset(data_path, tokenizer, 30)
+    train_data = NYT10StylishDataset(data_path, tokenizer, 1000)
+    val_data = None
+
+    training_args = FSDPTrainingArguments(
+        use_ffd_sampler=True,
+        output_dir=output_dir,
+        no_cuda=not torch.cuda.is_available(),
+        seed=15,
+        data_seed=15,
+        do_train=do_train,
+        num_train_epochs=num_epochs,
+        optim="adamw_torch",
+        learning_rate=learning_rate,
+        lr_scheduler_type=lr_scheduler_type,
+        
per_device_train_batch_size=micro_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + warmup_ratio=warmup_ratio, + weight_decay=weight_decay, + half_precision_backend="auto", + fp16=fp16, + bf16=bf16, + adam_beta1=0.9, + adam_beta2=0.95, + save_strategy="steps", + save_steps=save_steps, + save_total_limit=2, + logging_steps=1, + report_to= "none", # "wandb" if use_wandb else None, + run_name=None, #wandb_run_name if use_wandb else None, + deepspeed=deepspeed_config, + fsdp=fsdp, + fsdp_config=fsdp_config, + gradient_checkpointing=gradient_checkpointing, + do_eval=do_eval and val_set_size > 0, + evaluation_strategy="steps" if do_eval and val_set_size > 0 else "no", + eval_steps=save_steps, + per_device_eval_batch_size=eval_batch_size, + # group_by_length=True, + ) + + trainer = FSDPTrainer( + model=model, + args=training_args, + train_dataset=train_data, + eval_dataset=val_data, + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, pad_to_multiple_of=8, return_tensors='pt', padding=True, + ) + ) + + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + + trainer.save_model(output_dir) + tokenizer.save_pretrained(output_dir) + + +if __name__ == "__main__": + fire.Fire(train) + diff --git a/train/trainers/__init__.py b/train/trainers/__init__.py new file mode 100644 index 0000000..27038d2 --- /dev/null +++ b/train/trainers/__init__.py @@ -0,0 +1,2 @@ +from .fsdp_training_args import FSDPTrainingArguments +from .fsdp_trainer import FSDPTrainer \ No newline at end of file diff --git a/train/trainers/ffd_sampler.py b/train/trainers/ffd_sampler.py new file mode 100644 index 0000000..6d7b2e1 --- /dev/null +++ b/train/trainers/ffd_sampler.py @@ -0,0 +1,161 @@ +from typing import Optional, List + +import torch.distributed as dist +from torch.utils.data import Sampler + +import numpy as np +import numba + + +@numba.njit +def ffd(a: np.ndarray, c: int): + # First-fit-decreasing bin packing + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + + a = np.sort(a)[::-1] + bins = [] + for size in a: + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + add_new = False + break + + if add_new: + bins.append(c - size) + + return len(bins) + + +@numba.njit +def ffd_with_result(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) + + indices = np.argsort(a)[::-1] + a = a[indices] + + bins = [] + bins_result = [] + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + bins_result[idx].append(indices[a_id] + start_index) + add_new = False + break + + if add_new: + bins.append(c - size) + bins_result.append([indices[a_id] + start_index]) + + return bins_result + + +@numba.njit +def allocate(lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int): + # Dynamic batch allocator, similar to Multifit + # https://en.wikipedia.org/wiki/Multifit_algorithm + # ~96.4% efficiency on OpenChat training set (2048 ctx len) + + s = 0 + start_index = 0 + result = [] + + while True: + # binary search [l, r) + l = 1 + r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + + while r - l > 1: + m = (l + r) // 2 + if ffd(lengths[start_index: start_index + m], c) <= n: + l = m + else: + r = m + + # use length l + batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index) + if len(batch) < n: + break + + start_index += l + s = lengths_cumsum[start_index - 1] + + # add 
local rank + result.append(batch[rank]) + + return result, s / max(1, len(result) * c * n) # Avoid division by zero + + +class FFDDistributedBatchSampler(Sampler): + """Unpadded length sampling using FFD (First-fit-decreasing bin packing). + Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.""" + + def __init__( + self, + batch_max_length: int, + lengths: List[int], + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + ): + # Get rank + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + + self.num_replicas = num_replicas + self.rank = rank + self.seed = seed + + self.batch_max_length = batch_max_length + self.lengths = lengths + assert isinstance(self.lengths, np.ndarray) + + self.epoch = 0 + + # statistics + self.total_epochs = 0 + self.total_efficiency = 0 + + def set_epoch(self, epoch: int): + self.epoch = epoch + + def generate_batches(self, set_stats=False): + indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths)) + + lengths = self.lengths[indices] + lengths_cumsum = np.cumsum(lengths) + + batches, efficiency = allocate(lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=self.rank, + c=self.batch_max_length, + n=self.num_replicas) + + batches = [indices[batch] for batch in batches] + + # statistics + if set_stats: + self.total_epochs += 1 + self.total_efficiency += efficiency + + return batches + + def __iter__(self): + batches = self.generate_batches(set_stats=True) + return iter(batches) + + def __len__(self): + batches = self.generate_batches() + return len(batches) + + def efficiency(self): + return self.total_efficiency / self.total_epochs \ No newline at end of file diff --git a/train/trainers/fsdp_trainer.py b/train/trainers/fsdp_trainer.py new file mode 100644 index 0000000..905212d --- /dev/null +++ b/train/trainers/fsdp_trainer.py @@ -0,0 +1,925 @@ +import sys +import os +from typing import Optional +import torch +from torch.utils.data import DataLoader + +import transformers +from transformers.trainer import * + +from .ffd_sampler import FFDDistributedBatchSampler +from .utils import ExtendedFSDPOption, enable_low_gpu_full_post_state_dict_hook + + +class FSDPTrainer(transformers.Trainer): + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = TrainingArguments(output_dir=output_dir) + self.args = args + # Seed must be set before instantiating the model when using model + 
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + + self.create_accelerator_and_postprocess() + + # memory metrics - must set up as early as possible + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + + # set the correct log level depending on the node + log_level = args.get_process_log_level() + logging.set_verbosity(log_level) + + # force device and distributed setup init explicitly + args._setup_devices + + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" + " overwrite your model when calling the `train` method. This will become a fatal error in the next" + " release.", + FutureWarning, + ) + self.model_init = model_init + + if model.__class__.__name__ in MODEL_MAPPING_NAMES: + raise ValueError( + f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " + "computes hidden states and does not accept any labels. You should choose a model with a head " + "suitable for your task like any of the `AutoModelForXxx` listed at " + "https://huggingface.co/docs/transformers/model_doc/auto." + ) + + if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: + self.is_model_parallel = True + else: + self.is_model_parallel = False + + if getattr(model, "hf_device_map", None) is not None: + devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] + if len(devices) > 1: + self.is_model_parallel = True + else: + self.is_model_parallel = self.args.device != torch.device(devices[0]) + + # warn users + logger.info( + "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" + " to `True` to avoid any unexpected behavior such as device placement mismatching." + ) + + # At this stage the model is already loaded + if getattr(model, "is_quantized", False): + if getattr(model, "_is_quantized_training_enabled", False): + logger.info( + "The model is loaded in 8-bit precision. To train this model you need to add additional modules" + " inside the model such as adapters using `peft` library and freeze the model weights. Please" + " check " + " the examples in https://github.com/huggingface/peft for more details." + ) + else: + raise ValueError( + "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" + " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " + ) + + # Setup Sharded DDP training + self.sharded_ddp = None + if len(args.sharded_ddp) > 0: + if self.is_deepspeed_enabled: + raise ValueError( + "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if len(args.fsdp) > 0: + raise ValueError( + "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." 
+                )
+            if args.parallel_mode != ParallelMode.DISTRIBUTED:
+                raise ValueError("Using sharded DDP only works in distributed training.")
+            elif not is_fairscale_available():
+                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
+                raise ImportError(
+                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
+                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
+                )
+            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.SIMPLE
+            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
+            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+
+        self.fsdp = None
+        if len(args.fsdp) > 0:
+            if self.is_deepspeed_enabled:
+                raise ValueError(
+                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+                )
+            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
+                raise ValueError("Using fsdp only works in distributed training.")
+
+            # dep_version_check("torch>=1.12.0")
+            # Would have to update setup.py with torch>=1.12.0,
+            # which isn't ideal given that it would force people not using FSDP to also use torch>=1.12.0.
+            # Below is the current alternative.
+            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
+                raise ValueError("FSDP requires PyTorch >= 1.12.0")
+
+            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
+
+            if ExtendedFSDPOption.FULL_SHARD in args.fsdp:
+                self.fsdp = ShardingStrategy.FULL_SHARD
+            elif ExtendedFSDPOption.SHARD_GRAD_OP in args.fsdp:
+                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
+            elif ExtendedFSDPOption.NO_SHARD in args.fsdp:
+                self.fsdp = ShardingStrategy.NO_SHARD
+            # extension starts here
+            elif ExtendedFSDPOption.HYBRID_SHARD in args.fsdp:
+                self.fsdp = ShardingStrategy.HYBRID_SHARD
+            elif ExtendedFSDPOption._HYBRID_SHARD_ZERO2 in args.fsdp:
+                self.fsdp = ShardingStrategy._HYBRID_SHARD_ZERO2
+            # extension ends here
+
+            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
+            # modification starts here
+            if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post":
+                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
+            # modification ends here
+
+            self.forward_prefetch = False
+            # modification starts here
+            if self.args.fsdp_config.get("forward_prefetch", False):
+            # modification ends here
+                self.forward_prefetch = True
+
+            self.limit_all_gathers = False
+            if self.args.fsdp_config.get("limit_all_gathers", False):
+                self.limit_all_gathers = True
+
+        # one place to sort out whether to place the model on device or not
+        # postpone switching model to cuda when:
+        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
+        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+        #    and we only use deepspeed for training at the moment
+        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
+        # 4. Sharded DDP - same as MP
+        # 5.
FSDP - same as MP + self.place_model_on_device = args.place_model_on_device + if ( + self.is_model_parallel + or self.is_deepspeed_enabled + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) + or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) + or (self.fsdp is not None) + or self.is_fsdp_enabled + ): + self.place_model_on_device = False + + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + self.data_collator = data_collator if data_collator is not None else default_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + + # Quantized models doesn't support `.to` operation. + if self.place_model_on_device and not getattr(model, "is_quantized", False): + self._move_model_to_device(model, args.device) + + # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs + if self.is_model_parallel: + self.args._n_gpu = 1 + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + + self.compute_metrics = compute_metrics + self.preprocess_logits_for_metrics = preprocess_logits_for_metrics + self.optimizer, self.lr_scheduler = optimizers + if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): + raise RuntimeError( + "Passing a `model_init` is incompatible with providing the `optimizers` argument. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + if is_torch_tpu_available() and self.optimizer is not None: + for param in self.model.parameters(): + model_device = param.device + break + for param_group in self.optimizer.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + if model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you" + " created an optimizer around your model **before** putting on the device and passing it to the" + " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" + " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." + ) + if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( + self.optimizer is not None or self.lr_scheduler is not None + ): + raise RuntimeError( + "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + + # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. + self._loggers_initialized = False + + # Create clone of distant repo and output directory if needed + if self.args.push_to_hub: + self.init_git_repo(at_init=True) + # In case of pull, we need to make sure every process has the latest. 
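As a short aside on the collator default selected above (`default_data_collator` without a tokenizer, `DataCollatorWithPadding` with one): a hedged sketch of what that collator does with a ragged batch. The `bert-base-uncased` checkpoint is only an example, not something this repo uses.

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
collator = DataCollatorWithPadding(tokenizer)

# Two features of different lengths come back as one rectangular batch.
features = [tokenizer("short"), tokenizer("a somewhat longer sentence")]
batch = collator(features)
print(batch["input_ids"].shape, batch["attention_mask"].shape)
```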
+ if is_torch_tpu_available(): + xm.rendezvous("init git repo") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") + + if args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: + raise ValueError( + "The train_dataset does not implement __len__, max_steps has to be specified. " + "The number of steps needs to be known in advance for the learning rate scheduler." + ) + + if ( + train_dataset is not None + and isinstance(train_dataset, torch.utils.data.IterableDataset) + and args.group_by_length + ): + raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") + + self._signature_columns = None + + # Mixed precision setup + self.use_apex = False + self.use_cuda_amp = False + self.use_cpu_amp = False + + # Mixed precision setup for SageMaker Model Parallel + if is_sagemaker_mp_enabled(): + # BF16 + model parallelism in SageMaker: currently not supported, raise an error + if args.bf16: + raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") + + if IS_SAGEMAKER_MP_POST_1_10: + # When there's mismatch between SMP config and trainer argument, use SMP config as truth + if args.fp16 != smp.state.cfg.fp16: + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," + f"but FP16 provided in trainer argument is {args.fp16}," + f"setting to {smp.state.cfg.fp16}" + ) + args.fp16 = smp.state.cfg.fp16 + else: + # smp < 1.10 does not support fp16 in trainer. + if hasattr(smp.state.cfg, "fp16"): + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." 
+ ) + + if (args.fp16 or args.bf16) and self.sharded_ddp is not None: + if args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + else: + args.half_precision_backend = "cpu_amp" + else: + args.half_precision_backend = "cuda_amp" + + logger.info(f"Using {args.half_precision_backend} half precision backend") + + self.do_grad_scaling = False + if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): + # deepspeed and SageMaker Model Parallel manage their own half precision + if self.sharded_ddp is not None: + if args.half_precision_backend == "cuda_amp": + self.use_cuda_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + # bf16 does not need grad scaling + self.do_grad_scaling = self.amp_dtype == torch.float16 + if self.do_grad_scaling: + if self.sharded_ddp is not None: + self.scaler = ShardedGradScaler() + elif self.fsdp is not None: + from torch.distributed.fsdp.sharded_grad_scaler import ( + ShardedGradScaler as FSDPShardedGradScaler, + ) + + self.scaler = FSDPShardedGradScaler() + elif is_torch_tpu_available(): + from torch_xla.amp import GradScaler + + self.scaler = GradScaler() + else: + self.scaler = torch.cuda.amp.GradScaler() + elif args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + elif args.half_precision_backend == "apex": + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to" + " https://www.github.com/nvidia/apex." + ) + self.use_apex = True + + # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. + if ( + is_sagemaker_mp_enabled() + and self.use_cuda_amp + and args.max_grad_norm is not None + and args.max_grad_norm > 0 + ): + raise ValueError( + "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " + "along 'max_grad_norm': 0 in your hyperparameters." 
+ ) + + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + + self.control = TrainerControl() + # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then + # returned to 0 every time flos need to be logged + self.current_flos = 0 + self.hp_search_backend = None + self.use_tune_checkpoints = False + default_label_names = find_labels(self.model.__class__) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.can_return_loss = can_return_loss(self.model.__class__) + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + # Internal variables to help with automatic batch size reduction + self._train_batch_size = args.train_batch_size + self._created_lr_scheduler = False + + # very last + self._memory_tracker.stop_and_update_metrics() + + # torch.compile + if args.torch_compile and not is_torch_compile_available(): + raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") + + # finally applying `low_gpu_full_post_state_dict_hook`` for fsdp `state_dict` + enable_low_gpu_full_post_state_dict_hook() + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.use_ipex: + dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 + model = self.ipex_optimize_model(model, training, dtype=dtype) + + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. + if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if unwrap_model(model) is not model: + return model + + # Mixed precision training with apex (torch < 1.6) + if self.use_apex and training: + model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP + if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): + model = nn.DataParallel(model) + + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Distributed training (should be after apex fp16 initialization) + if self.sharded_ddp is not None: + # Sharded DDP! + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + model = ShardedDDP(model, self.optimizer) + else: + mixed_precision = self.args.fp16 or self.args.bf16 + cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp + zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 + # XXX: Breaking the self.model convention but I see no way around it for now. 
+ if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: + model = auto_wrap(model) + self.model = model = FullyShardedDDP( + model, + mixed_precision=mixed_precision, + reshard_after_forward=zero_3, + cpu_offload=cpu_offload, + ).to(self.args.device) + # Distributed training using PyTorch FSDP + elif self.fsdp is not None: + # fix starts here + if not self.args.fsdp_config["xla"]: + # PyTorch FSDP! + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy + import torch.distributed.fsdp._traversal_utils as traversal_utils + + if FSDPOption.OFFLOAD in self.args.fsdp: + cpu_offload = CPUOffload(offload_params=True) + else: + cpu_offload = CPUOffload(offload_params=False) + + auto_wrap_policy = None + + if FSDPOption.AUTO_WRAP in self.args.fsdp: + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + mixed_precision_policy = None + dtype = None + if self.args.fp16: + dtype = torch.float16 + elif self.args.bf16: + dtype = torch.bfloat16 + if dtype is not None: + mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) + if type(model) != FSDP: + # XXX: Breaking the self.model convention but I see no way around it for now. 
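The branch above builds one of two FSDP auto-wrap policies: size-based when `fsdp_min_num_params` is set, transformer-class-based otherwise. A minimal sketch of both constructions follows; `MyDecoderLayer` is a placeholder for whatever class `fsdp_transformer_layer_cls_to_wrap` resolves to via `get_module_class_from_name`.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)

class MyDecoderLayer(nn.Module):  # placeholder, e.g. LlamaDecoderLayer in practice
    pass

# Policy 1: wrap any submodule whose parameter count crosses a threshold.
size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)

# Policy 2: wrap exactly the named transformer block class(es).
layer_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={MyDecoderLayer}
)
```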
+ signature = inspect.signature(FSDP.__init__).parameters.keys() + kwargs = {} + for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]: + if arg in signature: + kwargs[arg] = getattr(self, arg) + self.model = model = FSDP( + model, + sharding_strategy=self.fsdp, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + mixed_precision=mixed_precision_policy, + device_id=self.args.device, + **kwargs, + ) + + for submodule in traversal_utils._get_fsdp_states(model): + print(submodule._state_dict_type, submodule._state_dict_config) + break + # fix ends here + else: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + return FSDP(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model with an outer FSDP wrapper + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. 
+                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+                    loss = optimizer.step(**optimizer_args)
+                    if barrier:
+                        xm.mark_step()
+                    return loss
+
+                xm.optimizer_step = patched_optimizer_step
+        elif is_sagemaker_dp_enabled():
+            model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+            )
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            if is_torch_neuroncore_available():
+                return model
+            kwargs = {}
+            if self.args.ddp_find_unused_parameters is not None:
+                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
+            else:
+                kwargs["find_unused_parameters"] = True
+
+            if self.args.ddp_bucket_cap_mb is not None:
+                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+
+            if self.args.ddp_broadcast_buffers is not None:
+                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
+
+            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
+
+        return model
+
+    def get_batch_sampler(self, dataset=None):
+        if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
+            dataset = dataset if dataset is not None else self.train_dataset
+            try:
+                batch_max_len = self.args.per_device_train_batch_size * unwrap_model(self.model).model_avg_context
+            except AttributeError:
+                # fall back to the training-argument value when the model does not expose `model_avg_context`
+                batch_max_len = self.args.per_device_train_batch_size * self.args.model_avg_context
+            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+            lengths = LengthGroupedSampler(
+                batch_size=-1,  # we just want to know about the lengths of the dataset so no need to pass `batch_size`
+                dataset=dataset,
+                model_input_name=model_input_name
+            ).lengths
+            seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+            batch_sampler = FFDDistributedBatchSampler(
+                batch_max_length=batch_max_len,
+                lengths=np.array(lengths),
+                seed=seed
+            )
+
+            return batch_sampler
+
+        return None
+
+    def get_train_dataloader(self) -> DataLoader:
+        if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
+            if self.train_dataset is None:
+                raise ValueError("Trainer: training requires a train_dataset.")
+
+            train_dataset = self.train_dataset
+            data_collator = self.data_collator
+            if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+                train_dataset = self._remove_unused_columns(train_dataset, description="training")
+            else:
+                data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+            batch_sampler = self.get_batch_sampler(train_dataset)
+
+            dataloader = DataLoader(
+                train_dataset,
+                batch_sampler=batch_sampler,
+                drop_last=self.args.dataloader_drop_last,
+                collate_fn=data_collator
+            )
+            # return self.accelerator.prepare(dataloader)
+            return dataloader
+
+        return super().get_train_dataloader()
+
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+        """
+        Will save the model, so you can reload it using `from_pretrained()`.
+
+        Will only save from the main process.
+ """ + + if output_dir is None: + output_dir = self.args.output_dir + + if is_torch_tpu_available(): + self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model_wrapped.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + if IS_SAGEMAKER_MP_POST_1_10: + # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 + Path(os.path.join(output_dir, "user_content.pt")).touch() + elif ( + ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp + or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + or self.fsdp is not None + or self.is_fsdp_enabled + ): + state_dict = self.model.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + # modification starts here + if self.is_fsdp_enabled and self.args.save_with_fsdp: + save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) + # modification ends here + + elif self.is_deepspeed_enabled: + # this takes care of everything as long as we aren't under zero3 + if version.parse(accelerate_version) <= version.parse("0.20.3"): + raise ValueError("Install Accelerate from main branch") + try: + state_dict = self.accelerator.get_state_dict(self.deepspeed) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self.model_wrapped.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + if self.is_deepspeed_enabled: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + self.model_wrapped.save_checkpoint(output_dir) + + # Save optimizer and scheduler + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer.consolidate_state_dict() + + if self.fsdp or self.is_fsdp_enabled: + if self.is_fsdp_enabled: + # modification starts here + if self.args.save_with_fsdp: + save_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir + ) + # modification ends here + else: + # FSDP has a different interface for saving optimizer states. + # Needs to be called on all ranks to gather all states. + # full_optim_state_dict will be deprecated after Pytorch 2.2! 
+ full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) + + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + if self.args.should_save: + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + elif self.args.should_save and not self.is_deepspeed_enabled: + # deepspeed.save_checkpoint above saves model/optim/sched + if self.fsdp and not self.is_fsdp_enabled: + torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) + else: + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_tpu_available(): + rng_states["xla"] = xm.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. + os.makedirs(output_dir, exist_ok=True) + + if self.args.world_size <= 1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. 
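The `rng_states` dict assembled above only covers the save side of resumable training; a hedged sketch of the matching restore step (the function name is mine, mirroring what `Trainer._load_rng_state` does upstream):

```python
import random

import numpy as np
import torch

def restore_rng_states(rng_states: dict) -> None:
    # Counterpart to the saved dict: put every generator back where it was.
    random.setstate(rng_states["python"])
    np.random.set_state(rng_states["numpy"])
    torch.random.set_rng_state(rng_states["cpu"])
    if "cuda" in rng_states and torch.cuda.is_available():
        if isinstance(rng_states["cuda"], list):  # one state per visible device
            torch.cuda.random.set_rng_state_all(rng_states["cuda"])
        else:
            torch.cuda.random.set_rng_state(rng_states["cuda"])
```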
+ if self.args.should_save: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _load_optimizer_and_scheduler(self, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if self.is_deepspeed_enabled: + # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + return + + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") + if is_sagemaker_mp_enabled() + else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) + ) + if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Load in optimizer and scheduler states + if is_torch_tpu_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") + with warnings.catch_warnings(record=True) as caught_warnings: + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") + reissue_pt_warnings(caught_warnings) + + xm.send_cpu_data_to_device(optimizer_state, self.args.device) + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + + self.optimizer.load_state_dict(optimizer_state) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + else: + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): + # Optimizer checkpoint was saved with smp >= 1.10 + def opt_load_hook(mod, opt): + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + else: + # Optimizer checkpoint was saved with smp < 1.10 + def opt_load_hook(mod, opt): + if IS_SAGEMAKER_MP_POST_1_10: + opt.load_state_dict( + smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) + ) + else: + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + self.model_wrapped.register_post_step_hook(opt_load_hook) + else: + # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. 
+ # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more + # likely to get OOM on CPU (since we load num_gpu times the optimizer state + map_location = self.args.device if self.args.world_size > 1 else "cpu" + if self.fsdp or self.is_fsdp_enabled: + # modification starts here + if self.is_fsdp_enabled and self.args.save_with_fsdp: + load_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, + self.accelerator, + self.optimizer, + self.model, + checkpoint, + ) + elif not self.is_fsdp_enabled: + full_osd = None + # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it + if self.args.process_index == 0: + full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) + # call scatter_full_optim_state_dict on all ranks + sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) + self.optimizer.load_state_dict(sharded_osd) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) + # modification ends here + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) + \ No newline at end of file diff --git a/train/trainers/fsdp_training_args.py b/train/trainers/fsdp_training_args.py new file mode 100644 index 0000000..5acc4fe --- /dev/null +++ b/train/trainers/fsdp_training_args.py @@ -0,0 +1,478 @@ +import sys +import os + +import transformers +from transformers.training_args import * + +from .utils import ExtendedFSDPOption + + +@dataclass +class FSDPTrainingArguments(transformers.TrainingArguments): + # about data-efficient sampler + use_ffd_sampler: bool = False + model_avg_context: int = 2048 + + # about saving + # if not save with fsdp, then must not load with fsdp + save_with_fsdp: bool = False + + def __post_init__(self): + # expand paths, if not os.makedirs("~/bar") will make directory + # in the current directory instead of the actual home + # see https://github.com/huggingface/transformers/issues/10628 + if self.output_dir is not None: + self.output_dir = os.path.expanduser(self.output_dir) + if self.logging_dir is None and self.output_dir is not None: + self.logging_dir = os.path.join(self.output_dir, default_logdir()) + if self.logging_dir is not None: + self.logging_dir = os.path.expanduser(self.logging_dir) + + if self.disable_tqdm is None: + self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN + + if isinstance(self.evaluation_strategy, EvaluationStrategy): + warnings.warn( + "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5" + " of ๐Ÿค— Transformers. Use `IntervalStrategy` instead", + FutureWarning, + ) + # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. + self.evaluation_strategy = self.evaluation_strategy.value + + # if self.xpu_backend is not None: + # warnings.warn( + # "using `xpu_backend` is deprecated and will be removed in version 4.31" + # " of ๐Ÿค— Transformers. 
Use `ddp_backend` instead", + # FutureWarning, + # ) + # self.ddp_backend = self.xpu_backend + + self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) + self.logging_strategy = IntervalStrategy(self.logging_strategy) + self.save_strategy = IntervalStrategy(self.save_strategy) + self.hub_strategy = HubStrategy(self.hub_strategy) + + self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) + if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO: + self.do_eval = True + + # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero + if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): + if self.logging_steps > 0: + logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") + self.eval_steps = self.logging_steps + else: + raise ValueError( + f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or" + " --logging_steps" + ) + + # logging_steps must be non-zero for logging_strategy that is other than 'no' + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: + raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") + + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1: + if self.logging_steps != int(self.logging_steps): + raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") + self.logging_steps = int(self.logging_steps) + if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: + if self.eval_steps != int(self.eval_steps): + raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") + self.eval_steps = int(self.eval_steps) + if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1: + if self.save_steps != int(self.save_steps): + raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") + self.save_steps = int(self.save_steps) + + # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. + if self.load_best_model_at_end: + if self.evaluation_strategy != self.save_strategy: + raise ValueError( + "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation " + f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}" + ) + if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_steps < 1 or self.save_steps < 1: + if not (self.eval_steps < 1 and self.save_steps < 1): + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps" + f"{self.save_steps} and eval_steps {self.eval_steps}." + ) + # Work around floating point precision issues + LARGE_MULTIPLIER = 1_000_000 + if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." + ) + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." 
+ ) + + safetensors_available = is_safetensors_available() + if self.save_safetensors and not safetensors_available: + raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!") + if not self.save_safetensors and safetensors_available: + logger.info( + f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. " + f"Safetensors should be a preferred weights saving format due to security and performance reasons. " + f"If your model cannot be saved by safetensors please feel free to open an issue at " + f"https://github.com/huggingface/safetensors!" + ) + + if ( + self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU + ) and self.metric_for_best_model is None: + self.metric_for_best_model = "loss" + if self.greater_is_better is None and self.metric_for_best_model is not None: + self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] + if self.run_name is None: + self.run_name = self.output_dir + if self.framework == "pt" and is_torch_available(): + if self.fp16_backend and self.fp16_backend != "auto": + warnings.warn( + "`fp16_backend` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. Use" + " `half_precision_backend` instead", + FutureWarning, + ) + self.half_precision_backend = self.fp16_backend + + if self.bf16 or self.bf16_full_eval: + if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available(): + # cpu + raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") + elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available(): + # gpu + raise ValueError( + "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0" + ) + + if self.fp16 and self.bf16: + raise ValueError("At most one of fp16 and bf16 can be True, but not both") + + if self.fp16_full_eval and self.bf16_full_eval: + raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") + + if self.bf16: + if self.half_precision_backend == "apex": + raise ValueError( + " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use" + " `--half_precision_backend cuda_amp` instead" + ) + if not (self.sharded_ddp == "" or not self.sharded_ddp): + raise ValueError("sharded_ddp is not supported with bf16") + + if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: + if self.evaluation_strategy == IntervalStrategy.NO: + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") + if not is_torch_available(): + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") + + self.optim = OptimizerNames(self.optim) + if self.adafactor: + warnings.warn( + "`--adafactor` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. 
Use `--optim" + " adafactor` instead", + FutureWarning, + ) + self.optim = OptimizerNames.ADAFACTOR + if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available(): + if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"): + raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher") + # there is a bug in fp16/AMP in pt-2.0.0 + if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16: + raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0") + + if ( + self.framework == "pt" + and is_torch_available() + and (self.device.type != "cuda") + and (get_xla_device_type(self.device) != "GPU") + and (self.fp16 or self.fp16_full_eval) + ): + raise ValueError( + "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation" + " (`--fp16_full_eval`) can only be used on CUDA devices." + ) + + if ( + self.framework == "pt" + and is_torch_available() + and (self.device.type != "cuda") + and (get_xla_device_type(self.device) != "GPU") + and (get_xla_device_type(self.device) != "TPU") + and (self.device.type != "cpu") + and (self.bf16 or self.bf16_full_eval) + ): + raise ValueError( + "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation" + " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices." + ) + + if self.torchdynamo is not None: + warnings.warn( + "`torchdynamo` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. Use" + " `torch_compile_backend` instead", + FutureWarning, + ) + self.torch_compile_backend = self.torchdynamo + if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: + self.torch_compile = True + if self.torch_compile and self.torch_compile_backend is None: + self.torch_compile_backend = "inductor" + + # accelerate integration for torch compile + if self.torch_compile: + # set env vars for accelerate + prefix = "ACCELERATE_DYNAMO_" + os.environ[prefix + "BACKEND"] = self.torch_compile_backend + if self.torch_compile_mode is not None: + os.environ[prefix + "MODE"] = self.torch_compile_mode + + if self.framework == "pt" and is_torch_available() and self.torch_compile: + if is_torch_tf32_available(): + if self.tf32 is None and not self.fp16 or self.bf16: + logger.info( + "Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement" + " otherwise." + ) + torch.backends.cuda.matmul.allow_tf32 = True + else: + logger.warning( + "The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here." + ) + if self.framework == "pt" and is_torch_available() and self.tf32 is not None: + if self.tf32: + if is_torch_tf32_available(): + torch.backends.cuda.matmul.allow_tf32 = True + else: + raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") + else: + if is_torch_tf32_available(): + torch.backends.cuda.matmul.allow_tf32 = False + # no need to assert on else + + if self.report_to is None: + logger.info( + "The default value for the training argument `--report_to` will change in v5 (from all installed " + "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as " + "now. You should start updating your code and make this info disappear :-)." + ) + self.report_to = "all" + if self.report_to == "all" or self.report_to == ["all"]: + # Import at runtime to avoid a circular import. 
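Stepping back to the TF32 handling above: the flag flipped there is a stock PyTorch switch that trades a little matmul precision for throughput on Ampere-class GPUs. A two-line sketch; note that this file only sets the matmul flag, and the cuDNN flag shown alongside is its usual companion rather than something this commit touches.

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # matmuls may use TF32 tensor cores (Ampere+)
torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions likewise (not set by this file)
```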
+ from transformers.integrations import get_available_reporting_integrations + + self.report_to = get_available_reporting_integrations() + elif self.report_to == "none" or self.report_to == ["none"]: + self.report_to = [] + elif not isinstance(self.report_to, list): + self.report_to = [self.report_to] + + if self.warmup_ratio < 0 or self.warmup_ratio > 1: + raise ValueError("warmup_ratio must lie in range [0,1]") + elif self.warmup_ratio > 0 and self.warmup_steps > 0: + logger.info( + "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio" + " during training" + ) + + if not (self.sharded_ddp == "" or not self.sharded_ddp): + warnings.warn( + "using `sharded_ddp` is deprecated and will be removed in version 4.33" + " of ๐Ÿค— Transformers. Use `fsdp` instead", + FutureWarning, + ) + if isinstance(self.sharded_ddp, bool): + self.sharded_ddp = "simple" if self.sharded_ddp else "" + if isinstance(self.sharded_ddp, str): + self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()] + if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]: + raise ValueError( + "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or " + '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.' + ) + elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp: + raise ValueError("`--sharded_ddp simple` is not compatible with any other option.") + elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp: + raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.") + + if isinstance(self.fsdp, bool): + self.fsdp = "full_shard" if self.fsdp else "" + if isinstance(self.fsdp, str): + self.fsdp = [ExtendedFSDPOption(s) for s in self.fsdp.split()] + if self.fsdp == [ExtendedFSDPOption.OFFLOAD]: + raise ValueError( + "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or " + '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.' + ) + elif ExtendedFSDPOption.FULL_SHARD in self.fsdp and ExtendedFSDPOption.SHARD_GRAD_OP in self.fsdp: + raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") + + if self.fsdp_config is None: + self.fsdp_config = {} + + if isinstance(self.fsdp_config, str): + with io.open(self.fsdp_config, "r", encoding="utf-8") as f: + self.fsdp_config = json.load(f) + + if self.fsdp_min_num_params > 0: + warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) + + self.fsdp_config["fsdp_min_num_params"] = max( + self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params + ) + + # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object + if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str): + self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [ + self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] + ] + + if self.fsdp_transformer_layer_cls_to_wrap is not None: + warnings.warn( + "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. 
Use fsdp_config instead ", FutureWarning + ) + self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get( + "fsdp_transformer_layer_cls_to_wrap", [] + ) + [self.fsdp_transformer_layer_cls_to_wrap] + + if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0: + warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.") + + if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") + + if ( + len(self.fsdp) > 0 + and self.fsdp_config["fsdp_min_num_params"] > 0 + and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None + ): + raise ValueError( + "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive." + ) + self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) + self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) + if self.fsdp_config["xla"]: + if len(self.fsdp) > 0: + # store XLA fsdp configuration parameters into a dictionary + self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}) + # apply appropriate string to torch.dtype conversions for parameters + if "compute_dtype" in self.xla_fsdp_config: + self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) + if "buffer_dtype" in self.xla_fsdp_config: + self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"]) + else: + warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.") + else: + if self.fsdp_config["xla_fsdp_grad_ckpt"]: + warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.") + + # accelerate integration for FSDP + if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: + os.environ["ACCELERATE_USE_FSDP"] = "true" + from accelerate.utils.constants import ( + FSDP_AUTO_WRAP_POLICY, + FSDP_SHARDING_STRATEGY, + ) + + for fsdp_option in self.fsdp: + if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: + # set environment variable for FSDP sharding strategy + os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1) + elif fsdp_option == FSDPOption.OFFLOAD: + os.environ["FSDP_OFFLOAD_PARAMS"] = "true" + elif fsdp_option == FSDPOption.AUTO_WRAP: + if self.fsdp_config["fsdp_min_num_params"] > 0: + os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"]) + os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] + elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join( + self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] + ) + os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] + prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH") + os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper() + + if self.tpu_metrics_debug: + warnings.warn( + "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. 
Use" + " `--debug tpu_metrics_debug` instead", + FutureWarning, + ) + if self.debug is None: + self.debug = " tpu_metrics_debug" + else: + self.debug += " tpu_metrics_debug" + self.tpu_metrics_debug = False + + if isinstance(self.debug, str): + self.debug = [DebugOption(s) for s in self.debug.split()] + elif self.debug is None: + self.debug = [] + + self.deepspeed_plugin = None + if self.deepspeed: + # - must be run very last in arg parsing, since it will use a lot of these settings. + # - must be run before the model is created. + if not is_accelerate_available(): + raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.") + from transformers.deepspeed import HfTrainerDeepSpeedConfig + + # will be used later by the Trainer + # note: leave self.deepspeed unmodified in case a user relies on it not to be modified) + self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) + self.hf_deepspeed_config.trainer_config_process(self) + + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) + + if self.push_to_hub_token is not None: + warnings.warn( + "`--push_to_hub_token` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. Use " + "`--hub_token` instead.", + FutureWarning, + ) + self.hub_token = self.push_to_hub_token + + if self.push_to_hub_model_id is not None: + self.hub_model_id = get_full_repo_name( + self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token + ) + if self.push_to_hub_organization is not None: + warnings.warn( + "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in " + "version 5 of ๐Ÿค— Transformers. Use `--hub_model_id` instead and pass the full repo name to this " + f"argument (in this case {self.hub_model_id}).", + FutureWarning, + ) + else: + warnings.warn( + "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + elif self.push_to_hub_organization is not None: + self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}" + warnings.warn( + "`--push_to_hub_organization` is deprecated and will be removed in version 5 of ๐Ÿค— Transformers. 
Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + + # if training args is specified, it will override the one specified in the accelerate config + if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0: + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.fp16: + mixed_precision_dtype = "fp16" + elif self.bf16: + mixed_precision_dtype = "bf16" + os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype \ No newline at end of file diff --git a/train/trainers/stylistic_trainer.py b/train/trainers/stylistic_trainer.py new file mode 100644 index 0000000..2878bb3 --- /dev/null +++ b/train/trainers/stylistic_trainer.py @@ -0,0 +1,26 @@ +from .fsdp_trainer import FSDPTrainer +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from transformers.trainer_utils import unwrap_model + + +class StylisticTrainer(FSDPTrainer): + def compute_loss(self, model, inputs, return_outputs=False): + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + + outputs = model(**inputs) + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + # FIXME: should support peft + model_name = unwrap_model(model)._get_name() + + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + # loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + raise ValueError(f"model {model_name} is not a causal LM") + else: + raise ValueError("labels should not be None") \ No newline at end of file diff --git a/train/trainers/utils.py b/train/trainers/utils.py new file mode 100644 index 0000000..6523b2c --- /dev/null +++ b/train/trainers/utils.py @@ -0,0 +1,154 @@ +import sys +import os +import warnings + +import torch +import torch.utils.checkpoint as checkpoint +from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable +from torch.distributed.fsdp import _state_dict_utils +from torch.distributed.fsdp._common_utils import clean_tensor_name +from transformers.utils import ExplicitEnum + + +class ExtendedFSDPOption(ExplicitEnum): + FULL_SHARD = "full_shard" + SHARD_GRAD_OP = "shard_grad_op" + NO_SHARD = "no_shard" + + # extention starts here + HYBRID_SHARD = "hybrid_shard" + _HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2" + # extention ends here + + OFFLOAD = "offload" + AUTO_WRAP = "auto_wrap" + + +DefaultCheckpointFunction = checkpoint.CheckpointFunction +DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook + + +class CompatibleCheckpointFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. 
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        for i, arg in enumerate(args):
+            if torch.is_tensor(arg):
+                tensor_inputs.append(arg)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+
+        ctx.save_for_backward(*tensor_inputs)
+
+        with torch.no_grad():
+            outputs = run_function(*args)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        if not torch.autograd._is_checkpoint_valid():
+            raise RuntimeError(
+                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
+                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
+                " argument.")
+        # Copy the list to avoid modifying original list.
+        inputs = list(ctx.inputs)
+        tensor_indices = ctx.tensor_indices
+        tensors = ctx.saved_tensors
+
+        # Fill in inputs with appropriate saved tensors.
+        for i, idx in enumerate(tensor_indices):
+            inputs[idx] = tensors[i]
+
+        # Stash the surrounding rng state, and mimic the state that was
+        # present at this time during forward. Restore the surrounding state
+        # when we're done.
+        rng_devices = []
+        # if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+        #     rng_devices = ctx.fwd_gpu_devices
+        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+            detached_inputs = detach_variable(tuple(inputs))
+            with torch.enable_grad(), \
+                    torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
+                    torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+                outputs = ctx.run_function(*detached_inputs)
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs,)
+
+        # run backward() only on the outputs that require grad
+        outputs_with_grad = []
+        args_with_grad = []
+        for i in range(len(outputs)):
+            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
+                outputs_with_grad.append(outputs[i])
+                args_with_grad.append(args[i])
+        if len(outputs_with_grad) == 0:
+            raise RuntimeError(
+                "none of the outputs has requires_grad=True,"
+                " this checkpoint() is not necessary")
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
+        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
+                      for inp in detached_inputs)
+
+        return (None, None) + grads
+
+
+def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix):
+
+    def param_hook(state_dict, prefix, fqn):
+        clean_key = fqn
+        clean_prefix = clean_tensor_name(prefix)
+        # Strip prefix out of key if needed as buffer names and param names
+        # do not have prefix considered as they are not computed in `state_dict`
+        # call.
+        if clean_key.startswith(clean_prefix):
+            clean_key = clean_key[len(clean_prefix):]
+
+        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
+        if not getattr(state_dict[fqn], "_has_been_cloned", False):
+            try:
+                # move to CPU before cloning so the extra copy never lives on GPU
+                state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
+                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
+            except BaseException as e:
+                warnings.warn(
+                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
+                    "This may mean that this state_dict entry could point to invalid "
+                    "memory regions after returning from state_dict() call if this "
+                    "parameter is managed by FSDP. Please check clone "
+                    f"implementation of {fqn}. Error: {str(e)}"
+                )
+
+    return _state_dict_utils._common_unshard_post_state_dict_hook(
+        module, fsdp_state, state_dict, prefix, param_hook
+    )
+
+
+# enable memory-efficient `state_dict` saving for FSDP (tensors are cloned on CPU)
+def enable_low_gpu_full_post_state_dict_hook():
+    _state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook
+
+def disable_low_gpu_full_post_state_dict_hook():
+    _state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook
+
+
+# enable to make activation checkpointing work with `torch.compile`
+def enable_compatible_checkpoint_function():
+    checkpoint.CheckpointFunction = CompatibleCheckpointFunction
+
+def disable_compatible_checkpoint_function():
+    checkpoint.CheckpointFunction = DefaultCheckpointFunction
+
\ No newline at end of file
diff --git a/train/utils/__init__.py b/train/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/train/utils/datasets/__init__.py b/train/utils/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/train/utils/datasets/nyt10_dataset.py b/train/utils/datasets/nyt10_dataset.py
new file mode 100644
index 0000000..7649e39
--- /dev/null
+++ b/train/utils/datasets/nyt10_dataset.py
@@ -0,0 +1,142 @@
+import copy
+import json
+import random
+
+import torch
+from torch.utils.data import Dataset
+
+
+IGNORE_INDEX = -100  # the default ignore_index of CrossEntropyLoss
+
+prompt_template = (
+    "Below is an instruction that describes a task, paired with an input that provides further context. "
+    "Write a response that appropriately completes the request.\n\n"
+    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
+)
+
+# output_template = [
+#     "```json\n",
+#     "{\n",
+#     "    \"person\": \"",
+#     "sample_person",
+#     "\",\n",
+#     "    \"nationality\": \"",
+#     "sample_nationality",
+#     "\"\n",
+#     "}\n",
+#     "```",
+# ]
+
+# person_index = 3
+# nationality_index = 6
+
+output_template = [
+    "```json\n{\n    \"person\": \"",
+    "sample_person",
+    "\",\n    \"nationality\": \"",
+    "sample_nationality",
+    "\"\n}\n```",
+]
+
+person_index = 1
+nationality_index = 3
+
+
+class NYT10Dataset(Dataset):
+    def __init__(self, data_path: str, tokenizer, size: int = -1):
+        with open(data_path, 'r') as f:
+            self.ann = [json.loads(line) for line in f]
+        # only use "/people/person/nationality"
+        self.ann = [
+            {
+                "text": dp["text"],
+                "person": dp["h"]["name"],
+                "nationality": dp["t"]["name"],
+            } for dp in self.ann if '/people/person/nationality' in dp['relation']
+        ]
+
+        random.shuffle(self.ann)
+
+        if size > 0:
+            self.ann = self.ann[:size]
+
+        self.tokenizer = tokenizer
+
+    def __len__(self):
+        return len(self.ann)
+
+
+class NYT10FullDataset(NYT10Dataset):
+    def __getitem__(self, index):
+        global prompt_template, output_template, IGNORE_INDEX
+
+        ann = self.ann[index]
+        prompt = prompt_template.format(text=ann["text"])
+        output = copy.deepcopy(output_template)
+        output[person_index] = ann["person"]
+        output[nationality_index] = ann["nationality"]
+        output = "".join(output)
+        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+        output_ids = [self.tokenizer.bos_token_id] + self.tokenizer.encode(output, add_special_tokens=False) + [self.tokenizer.eos_token_id]
+        example = torch.tensor(
+            prompt_ids + output_ids, dtype=torch.int64
+        )
+        labels = copy.deepcopy(example)
+        # mark prompt positions with -1; they become IGNORE_INDEX below
+        labels[:len(prompt_ids)] = -1
+        example_mask = example.ge(0)
+        label_mask = labels.ge(0)
+        example[~example_mask] = 0
+        labels[~label_mask] = IGNORE_INDEX
+
+        assert len(example) == len(labels)
+
+        return {
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask": example_mask.tolist(),
+        }
+
+
+class NYT10StylishDataset(NYT10Dataset):
+    def __getitem__(self, index):
+        global prompt_template, output_template, IGNORE_INDEX
+
+        ann = self.ann[index]
+        prompt = prompt_template.format(text=ann["text"])
+        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+
+        example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
+        # prompt part is masked
+        labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]
+
+        for idx, s in enumerate(output_template):
+            # entity tokens (person and nationality) are masked; only the template "style" is supervised
+            if idx == person_index or idx == nationality_index:
+                tokens = self.tokenizer.encode(ann["person"] if idx == person_index else ann["nationality"], add_special_tokens=False)
+                example.extend(tokens)
+                labels.extend([-1] * len(tokens))
+            else:
+                tokens = self.tokenizer.encode(s, add_special_tokens=False)
+                example.extend(tokens)
+                labels.extend(tokens)
+        example.append(self.tokenizer.eos_token_id)
+        example = torch.tensor(
+            example, dtype=torch.int64
+        )
+        labels.append(self.tokenizer.eos_token_id)
+        labels = torch.tensor(
+            labels, dtype=torch.int64
+        )
+        example_mask = example.ge(0)
+        label_mask = labels.ge(0)
+        example[~example_mask] = 0
+        labels[~label_mask] = IGNORE_INDEX
+
+        assert len(example) == len(labels)
+
+        return {
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask": example_mask.tolist(),
+        }
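Usage sketch (illustrative only; the tokenizer checkpoint and data path below are assumed placeholders, not files from this commit). It contrasts how many label positions each dataset variant actually supervises: the full variant trains on the whole JSON response, while the stylish variant trains only on the template tokens and masks the entity mentions.

```python
# minimal sketch under assumed paths; compares supervision coverage of the two variants
from transformers import AutoTokenizer

from train.utils.datasets.nyt10_dataset import (
    IGNORE_INDEX,
    NYT10FullDataset,
    NYT10StylishDataset,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint

full = NYT10FullDataset("data/nyt10/train.jsonl", tokenizer, size=8)       # assumed data path
stylish = NYT10StylishDataset("data/nyt10/train.jsonl", tokenizer, size=8)

for name, ds in [("full", full), ("stylish", stylish)]:
    sample = ds[0]
    # positions equal to IGNORE_INDEX (the prompt, and for the stylish variant
    # also the entity mentions) contribute nothing to the cross-entropy loss
    trained = sum(1 for t in sample["labels"] if t != IGNORE_INDEX)
    print(f"{name}: {trained}/{len(sample['labels'])} supervised label positions")
```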