from .fsdp_trainer import FSDPTrainer
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_utils import unwrap_model


class StylisticTrainer(FSDPTrainer):
    """FSDP trainer whose label-smoothed loss is restricted to causal LMs.

    NOTE(review): ``FSDPTrainer`` is project-local; this class assumes it
    follows the ``transformers.Trainer`` API (``label_smoother``,
    ``args.past_index``, ``_past``) -- confirm against fsdp_trainer.py.
    """

    @staticmethod
    def _model_class_name(model):
        """Return the class name of the transformers model inside *model*."""
        unwrapped = unwrap_model(model)
        # PEFT wrappers (e.g. PeftModelForCausalLM) hide the real model class;
        # their get_base_model() returns the wrapped transformers model.
        get_base_model = getattr(unwrapped, "get_base_model", None)
        if callable(get_base_model):
            unwrapped = get_base_model()
        return unwrapped._get_name()

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run *model* on *inputs* and return the training loss.

        When a label smoother is configured and ``inputs`` contains
        ``"labels"``, the labels are popped before the forward pass and the
        loss is computed by ``self.label_smoother`` with
        ``shift_labels=True``; only causal LMs are accepted on this path.
        Otherwise the labels (if any) stay in ``inputs`` and the loss the
        model returns itself is used.

        Args:
            model: The (possibly wrapped) model being trained.
            inputs: Batch mapping passed to the model as keyword arguments.
                Mutated: ``"labels"`` is popped on the label-smoothing path.
            return_outputs: If true, return ``(loss, outputs)`` instead of
                just ``loss``.

        Returns:
            The loss, or ``(loss, outputs)`` when *return_outputs* is true.

        Raises:
            ValueError: If label smoothing is requested for a model that is
                not a causal LM, or if labels were not popped and the model
                did not return a loss.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        outputs = model(**inputs)
        # Save past state when args.past_index is set, mirroring transformers.Trainer.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            model_name = self._model_class_name(model)
            if model_name not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                raise ValueError(f"model {model_name} is not a causal LM")
            # Causal LMs predict token t+1 from tokens <= t, so the smoother
            # must shift labels by one position relative to the logits.
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            # No label smoother: labels were forwarded to the model, which is
            # expected to compute its own loss.
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "labels should not be None: the model did not return a loss "
                    f"(returned keys: {', '.join(outputs.keys())})"
                )
            # Index rather than .loss: the model may return a plain tuple.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss