26 lines
1013 B
Python
26 lines
1013 B
Python
from .fsdp_trainer import FSDPTrainer
|
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
|
from transformers.trainer_utils import unwrap_model
|
|
|
|
|
|
class StylisticTrainer(FSDPTrainer):
    """FSDP trainer whose loss comes from the configured label smoother.

    Only causal language models are supported; any other model class is
    rejected with a ``ValueError`` when the loss is computed.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run a forward pass and compute the loss via the label smoother.

        Args:
            model: The (possibly wrapped) model to call with ``inputs``.
            inputs: Keyword arguments forwarded to ``model``. When a label
                smoother is configured, ``"labels"`` is popped from this
                mapping (mutating it) so the model does not receive it.
            return_outputs: If true, return ``(loss, outputs)`` instead of
                just ``loss``.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple when ``return_outputs``
            is true.

        Raises:
            ValueError: If no labels are available (no label smoother is
                configured, or ``inputs`` has no ``"labels"`` entry), or if
                the unwrapped model is not a known causal LM class.
        """
        # Labels are only extracted when a label smoother exists, because the
        # smoother is the sole source of the loss in this trainer.
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        outputs = model(**inputs)

        # Keep the past state around when the training args ask for it.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            raise ValueError("labels should not be None")

        # FIXME: should support peft
        model_name = unwrap_model(model)._get_name()
        if model_name not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            raise ValueError(f"model {model_name} is not a causal LM")

        # Causal LMs predict the next token, so labels must be shifted.
        loss = self.label_smoother(outputs, labels, shift_labels=True)
        return (loss, outputs) if return_outputs else loss