[2/2] refactor: decoupled self distillation trainers; cleanup #5883
LeonEricsson wants to merge 82 commits
…onfig parameters moved to sdpoconfig, + other nits
BaseSelfDistillationTrainer was populating _metrics in _log_self_distillation_metric but had no log() override, so those metrics were never forwarded to the Trainer's logging system. The fix merges _metrics into the log dict, prefixes eval keys, and clears after each logging step.
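The fix described in this commit message can be sketched roughly as follows. This is a minimal, framework-free illustration, not the code from the PR: the class `MetricsLoggingSketch`, the `logged` list, and the choice to average buffered values are assumptions made for the example. Only the `_metrics` buffer, `_log_self_distillation_metric`, the eval-key prefix, and the clear-after-log behavior come from the commit message.

```python
from collections import defaultdict


class MetricsLoggingSketch:
    """Hypothetical stand-in for a trainer that buffers metrics between log steps."""

    def __init__(self):
        # mode -> metric name -> values recorded since the last logging step
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        # Stands in for the parent Trainer's logging system (assumption).
        self.logged = []

    def _log_self_distillation_metric(self, mode, name, value):
        self._metrics[mode][name].append(float(value))

    def log(self, logs, mode="train"):
        # Averaging the buffered values is an assumption of this sketch.
        averaged = {k: sum(v) / len(v) for k, v in self._metrics[mode].items() if v}
        if mode == "eval":
            # Prefix eval keys so they do not collide with train keys.
            averaged = {f"eval_{k}": v for k, v in averaged.items()}
        merged = {**logs, **averaged}
        # Clear after each logging step so values are not reported twice.
        self._metrics[mode].clear()
        # A real override would forward `merged` to super().log(...) here.
        self.logged.append(merged)
        return merged
```

Without the `log()` override, the buffer would keep filling while nothing reached the logger, which matches the symptom described above.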
…l-self-distillation # Conflicts: # trl/experimental/sdft/sdft_trainer.py # trl/experimental/sdpo/sdpo_trainer.py # trl/experimental/self_distillation/base_self_distillation_trainer.py # trl/experimental/self_distillation/online_rollout_mixin.py # trl/experimental/self_distillation/teacher_context.py
fix: build self-distillation teacher from path for ZeRO-3 compatibility
…rmalization using per sequence length
…o refactor/sdft-sdpo-cleanup # Conflicts: # docs/source/sdpo_trainer.md # tests/experimental/test_self_distillation_trainer_behavior.py # trl/experimental/sdft/loss_utils.py # trl/experimental/sdft/sdft_trainer.py # trl/experimental/sdpo/loss_utils.py # trl/experimental/sdpo/sdpo_trainer.py
# Conflicts: # trl/experimental/sdft/sdft_config.py # trl/experimental/sdpo/sdpo.py # trl/experimental/sdpo/sdpo_config.py
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
Reviewed by Cursor Bugbot for commit cb2c378.
Went through this and ran it on 2 GPUs. The convex loss matches the paper (Section 4.5) and what we discussed earlier; FSDP2 trains fine at both …
…trl into refactor/sdft-sdpo-cleanup

What does this PR do?
Follow up #5862.
An SDPO loss change and a bunch of non-behavior-changing refactors.
Changes
1. … policy + λ·distillation with the paper's (1 - w)·policy + w·distillation (Section 4.5). sdpo_policy_loss_mode is removed; distillation_weight is now the convex weight w ∈ [0, 1] (1.0 = pure distillation = prior default, 0.0 = pure policy gradient)
2. … SDFTConfig/SDPOConfig, removed unused diagnostics from SDFT.
3. … summon_full_params.
notes
Item 1 is a behavioral change; the rest is docs/structure/fixes.
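For reference, the convex blend from item 1 can be sketched as below. This is an illustration under stated assumptions, not the trainer's code: `blend_losses` and `additive_to_convex_weight` are hypothetical helper names, and the range check is an assumption about how an out-of-range weight might be handled. The second helper uses the identity (1 - w)·policy + w·distillation = (policy + λ·distillation) / (1 + λ) with w = λ / (1 + λ).

```python
def blend_losses(policy_loss, distillation_loss, distillation_weight=1.0):
    """Convex blend (1 - w) * policy + w * distillation with w in [0, 1]."""
    w = distillation_weight
    if not 0.0 <= w <= 1.0:
        # Rejecting out-of-range weights is an assumption of this sketch.
        raise ValueError(f"distillation_weight must be in [0, 1], got {w}")
    return (1.0 - w) * policy_loss + w * distillation_loss


def additive_to_convex_weight(lam):
    """Map an old additive weight lambda >= 0 to w = lambda / (1 + lambda).

    With this w, the convex loss equals the old additive loss
    policy + lambda * distillation scaled by 1 / (1 + lambda).
    """
    if lam < 0:
        raise ValueError(f"lambda must be non-negative, got {lam}")
    return lam / (1.0 + lam)
```

Because the two forms differ by the factor 1 / (1 + λ), a run converted this way keeps the same mix of the two terms but sees a smaller overall loss scale, so the effective learning rate may need revisiting.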
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Medium Risk
SDPO training loss semantics change for anyone using hybrid mode or non-1.0 distillation weights; FSDP generation fix affects distributed runs.
Overview
SDPO objective now uses a convex blend (1 - w)·policy + w·distillation with distillation_weight ∈ [0, 1], replacing sdpo_policy_loss_mode (hybrid/distillation_only) and the old additive policy + λ·distillation. Default w=1.0 keeps pure distillation; w=0.0 is GRPO-style policy only. Docs and CLI examples switch to --distillation_weight (e.g. 0.5 for a 50/50 mix). Liger requires distillation_weight=1.0.

SDFT/SDPO trainers rename distillation masking from response_mask to loss_mask (SDFT drops per-sample self_distillation_mask; SDPO still gates on self_distillation_mask). Shared loss utils use selective_log_softmax, rename tail handling to add_tail_bucket, and inline top-k renormalization. FSDP: generate() runs under summon_full_params. SDFT drops diagnostics config fields, defaults disable_dropout=True, expands config docstrings, and refactors metrics (_record_completion_metrics) and method order without changing SDFT's distillation-only loss path.

Reviewed by Cursor Bugbot for commit f4a4393.
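To make the top-k and tail-bucket wording in the overview concrete, here is one plausible reading as a small pure-Python sketch. It is not taken from the PR: `topk_log_probs` and `log_softmax` are hypothetical helpers, and only the `add_tail_bucket` name comes from the summary. The assumption is that the k most likely tokens are kept, and the remaining probability mass is either collected into one extra bucket or dropped, with the kept entries renormalized.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def topk_log_probs(logits, k, add_tail_bucket=True):
    """Return (indices, log_probs) for the k most likely entries.

    With add_tail_bucket=True, one extra log-prob holding the leftover
    mass is appended; otherwise the kept entries are renormalized so
    that their probabilities sum to 1.
    """
    logp = log_softmax(logits)
    top = sorted(range(len(logp)), key=lambda i: logp[i], reverse=True)[:k]
    kept = [logp[i] for i in top]
    if add_tail_bucket:
        tail_mass = 1.0 - sum(math.exp(x) for x in kept)
        # Clamp so the log stays finite when the top-k covers everything.
        kept.append(math.log(max(tail_mass, 1e-12)))
    else:
        lse = math.log(sum(math.exp(x) for x in kept))
        kept = [x - lse for x in kept]
    return top, kept
```

Either variant yields a proper distribution over the kept support, which is what a divergence between student and teacher over a truncated vocabulary needs.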