3232from torch import autocast
3333from torch .utils .data import DataLoader , SequentialSampler
3434from transformers import (
35- BaseImageProcessor ,
35+ AutoProcessor ,
3636 DataCollator ,
37- FeatureExtractionMixin ,
3837 PreTrainedModel ,
3938 PreTrainedTokenizerBase ,
4039 ProcessorMixin ,
@@ -238,7 +237,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
238237
239238
240239class KTOTrainer (_BaseTrainer ):
241- r """
240+ """
242241 Initialize KTOTrainer.
243242
244243 Args:
@@ -264,10 +263,11 @@ class KTOTrainer(_BaseTrainer):
264263 The dataset to use for training.
265264 eval_dataset ([`~datasets.Dataset`]):
266265 The dataset to use for evaluation.
267- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
268- Processing class used to process the data. If provided, will be used to automatically process the inputs
269- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
270- reuse the fine-tuned model.
266+ processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
267+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
268+ processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
269+ padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
270+ `tokenizer.eos_token` will be used as the default.
271271 data_collator ([`~transformers.DataCollator`], *optional*):
272272 The data collator to use for training. If None is specified, the default data collator
273273 ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the
@@ -311,11 +311,7 @@ def __init__(
311311 args : KTOConfig | None = None ,
312312 train_dataset : Dataset | None = None ,
313313 eval_dataset : Dataset | dict [str , Dataset ] | None = None ,
314- processing_class : PreTrainedTokenizerBase
315- | BaseImageProcessor
316- | FeatureExtractionMixin
317- | ProcessorMixin
318- | None = None ,
314+ processing_class : PreTrainedTokenizerBase | ProcessorMixin | None = None ,
319315 data_collator : DataCollator | None = None ,
320316 model_init : Callable [[], PreTrainedModel ] | None = None ,
321317 callbacks : list [TrainerCallback ] | None = None ,
@@ -352,6 +348,16 @@ def __init__(
352348 "we'll initialize it to a copy of `model` for you."
353349 )
354350
351+ # Processing class
352+ if processing_class is None :
353+ processing_class = AutoProcessor .from_pretrained (get_config_model_id (model .config ))
354+ if isinstance (processing_class , ProcessorMixin ):
355+ tokenizer = processing_class .tokenizer
356+ elif isinstance (processing_class , PreTrainedTokenizerBase ):
357+ tokenizer = processing_class
358+ else :
359+ raise TypeError ("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`" )
360+
355361 # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
356362 # has been called in order to properly call autocast if needed.
357363 self ._peft_has_been_casted_to_bf16 = False
@@ -430,10 +436,6 @@ def make_inputs_require_grad(module, input, output):
430436
431437 self .is_peft_model = is_peft_available () and isinstance (model , PeftModel )
432438
433- if processing_class is None :
434- raise ValueError (
435- "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
436- )
437439 if args .max_length is None :
438440 logger .warning (
439441 "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
@@ -445,7 +447,7 @@ def make_inputs_require_grad(module, input, output):
445447
446448 if data_collator is None :
447449 data_collator = DPODataCollatorWithPadding (
448- pad_token_id = processing_class .pad_token_id ,
450+ pad_token_id = tokenizer .pad_token_id ,
449451 )
450452
451453 if args .remove_unused_columns :
@@ -462,7 +464,6 @@ def make_inputs_require_grad(module, input, output):
462464
463465 self .loss_type = args .loss_type
464466 self .max_length = max_length
465- self .processing_class = processing_class
466467 self .precompute_ref_log_probs = args .precompute_ref_log_probs
467468
468469 # Not all losses require a KL calculation
@@ -523,14 +524,14 @@ def make_inputs_require_grad(module, input, output):
523524 train_dataset = train_dataset .map (
524525 _tokenize ,
525526 batched = True ,
526- fn_kwargs = {"tokenizer" : self . processing_class },
527+ fn_kwargs = {"tokenizer" : processing_class },
527528 num_proc = args .dataset_num_proc ,
528529 desc = "Tokenizing train dataset" ,
529530 )
530531
531532 fn_kwargs = {
532533 "prefix" : "" ,
533- "tokenizer" : self . processing_class ,
534+ "tokenizer" : processing_class ,
534535 "max_length" : self .max_length ,
535536 }
536537
@@ -545,7 +546,7 @@ def make_inputs_require_grad(module, input, output):
545546 if eval_dataset is not None :
546547 eval_dataset = eval_dataset .map (
547548 _tokenize ,
548- fn_kwargs = {"tokenizer" : self . processing_class },
549+ fn_kwargs = {"tokenizer" : processing_class },
549550 batched = True ,
550551 num_proc = args .dataset_num_proc ,
551552 desc = "Tokenizing eval dataset" ,
0 commit comments