realign/train/trainers/stylistic_trainer.py
2024-03-09 10:55:34 +08:00

26 lines
1013 B
Python

from .fsdp_trainer import FSDPTrainer
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_utils import unwrap_model
class StylisticTrainer(FSDPTrainer):
    """FSDP trainer whose loss is always produced by ``self.label_smoother``.

    Only causal language models are accepted; any other model class, or a
    batch for which no labels could be extracted, is rejected with a
    ``ValueError`` instead of falling back to the model's own loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run a forward pass and compute the label-smoothed loss.

        Args:
            model: The (possibly wrapped) model to call with ``**inputs``.
            inputs: Mapping of keyword arguments for the model. When a label
                smoother is configured, the ``"labels"`` entry is removed
                from this mapping in place before the forward pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.

        Returns:
            The loss returned by ``self.label_smoother``, or the tuple
            ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            ValueError: If no labels were extracted (no label smoother is
                configured or ``inputs`` has no ``"labels"`` key), or if the
                unwrapped model is not a known causal LM class.
        """
        # Labels are only popped when a label smoother exists; once popped
        # they are no longer forwarded to the model call below.
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        outputs = model(**inputs)

        # Keep the past state selected by `past_index` for later steps.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            raise ValueError("labels should not be None")

        # FIXME: should support peft
        model_name = unwrap_model(model)._get_name()
        if model_name not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            raise ValueError(f"model {model_name} is not a causal LM")

        # Causal LMs predict the next token, hence `shift_labels=True`.
        loss = self.label_smoother(outputs, labels, shift_labels=True)

        return (loss, outputs) if return_outputs else loss