Skip to content

Commit e303285

Browse files
Align processing_class init and docstring
1 parent 21cf71d commit e303285

2 files changed

Lines changed: 23 additions & 22 deletions

File tree

trl/experimental/kto/kto_trainer.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@
3232
from torch import autocast
3333
from torch.utils.data import DataLoader, SequentialSampler
3434
from transformers import (
35-
BaseImageProcessor,
35+
AutoProcessor,
3636
DataCollator,
37-
FeatureExtractionMixin,
3837
PreTrainedModel,
3938
PreTrainedTokenizerBase,
4039
ProcessorMixin,
@@ -238,7 +237,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
238237

239238

240239
class KTOTrainer(_BaseTrainer):
241-
r"""
240+
"""
242241
Initialize KTOTrainer.
243242
244243
Args:
@@ -264,10 +263,11 @@ class KTOTrainer(_BaseTrainer):
264263
The dataset to use for training.
265264
eval_dataset ([`~datasets.Dataset`]):
266265
The dataset to use for evaluation.
267-
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
268-
Processing class used to process the data. If provided, will be used to automatically process the inputs
269-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
270-
reuse the fine-tuned model.
266+
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
267+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
268+
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
269+
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
270+
`tokenizer.eos_token` will be used as the default.
271271
data_collator ([`~transformers.DataCollator`], *optional*):
272272
The data collator to use for training. If `None` is specified, the default data collator
273273
([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the
@@ -311,11 +311,7 @@ def __init__(
311311
args: KTOConfig | None = None,
312312
train_dataset: Dataset | None = None,
313313
eval_dataset: Dataset | dict[str, Dataset] | None = None,
314-
processing_class: PreTrainedTokenizerBase
315-
| BaseImageProcessor
316-
| FeatureExtractionMixin
317-
| ProcessorMixin
318-
| None = None,
314+
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
319315
data_collator: DataCollator | None = None,
320316
model_init: Callable[[], PreTrainedModel] | None = None,
321317
callbacks: list[TrainerCallback] | None = None,
@@ -352,6 +348,16 @@ def __init__(
352348
"we'll initialize it to a copy of `model` for you."
353349
)
354350

351+
# Processing class
352+
if processing_class is None:
353+
processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config))
354+
if isinstance(processing_class, ProcessorMixin):
355+
tokenizer = processing_class.tokenizer
356+
elif isinstance(processing_class, PreTrainedTokenizerBase):
357+
tokenizer = processing_class
358+
else:
359+
raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
360+
355361
# Initialize this variable to False. This helps track the case when `peft_module_casting_to_bf16`
356362
# has been called in order to properly call autocast if needed.
357363
self._peft_has_been_casted_to_bf16 = False
@@ -430,10 +436,6 @@ def make_inputs_require_grad(module, input, output):
430436

431437
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
432438

433-
if processing_class is None:
434-
raise ValueError(
435-
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
436-
)
437439
if args.max_length is None:
438440
logger.warning(
439441
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
@@ -445,7 +447,7 @@ def make_inputs_require_grad(module, input, output):
445447

446448
if data_collator is None:
447449
data_collator = DPODataCollatorWithPadding(
448-
pad_token_id=processing_class.pad_token_id,
450+
pad_token_id=tokenizer.pad_token_id,
449451
)
450452

451453
if args.remove_unused_columns:
@@ -462,7 +464,6 @@ def make_inputs_require_grad(module, input, output):
462464

463465
self.loss_type = args.loss_type
464466
self.max_length = max_length
465-
self.processing_class = processing_class
466467
self.precompute_ref_log_probs = args.precompute_ref_log_probs
467468

468469
# Not all losses require a KL calculation
@@ -523,14 +524,14 @@ def make_inputs_require_grad(module, input, output):
523524
train_dataset = train_dataset.map(
524525
_tokenize,
525526
batched=True,
526-
fn_kwargs={"tokenizer": self.processing_class},
527+
fn_kwargs={"tokenizer": processing_class},
527528
num_proc=args.dataset_num_proc,
528529
desc="Tokenizing train dataset",
529530
)
530531

531532
fn_kwargs = {
532533
"prefix": "",
533-
"tokenizer": self.processing_class,
534+
"tokenizer": processing_class,
534535
"max_length": self.max_length,
535536
}
536537

@@ -545,7 +546,7 @@ def make_inputs_require_grad(module, input, output):
545546
if eval_dataset is not None:
546547
eval_dataset = eval_dataset.map(
547548
_tokenize,
548-
fn_kwargs={"tokenizer": self.processing_class},
549+
fn_kwargs={"tokenizer": processing_class},
549550
batched=True,
550551
num_proc=args.dataset_num_proc,
551552
desc="Tokenizing eval dataset",

trl/trainer/dpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ class DPOTrainer(_BaseTrainer):
457457
and content).
458458
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
459459
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
460-
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
460+
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
461461
Processing class used to process the data. The padding side must be set to "left". If `None`, the
462462
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
463463
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,

0 commit comments

Comments
 (0)