Open
Changes from 1 commit
Commits
82 commits
06f02a8
v0.1 transition sdft into unified base
LeonEricsson Apr 15, 2026
be1bcbc
sdft transition v1 complete, starting on sdpo
LeonEricsson Apr 15, 2026
0628701
sdpo transitioned, needs testing
LeonEricsson Apr 15, 2026
55111ff
remove legacy trainers
LeonEricsson Apr 15, 2026
81def8a
sdft and sdpo transitioned and tested with new base
LeonEricsson Apr 16, 2026
bad6b62
restructure training batch builder
LeonEricsson Apr 16, 2026
ef43c95
nits
LeonEricsson Apr 16, 2026
efe0eda
wip removing mixin
LeonEricsson Apr 16, 2026
fa1a8f3
remove mixin, refactoring and cleanup
LeonEricsson Apr 16, 2026
6a7d5a8
always set teacher_model
LeonEricsson Apr 16, 2026
56b2fd1
align generation tokenization with grpotrainer
LeonEricsson Apr 16, 2026
4a9d527
fix: generation_kwargs bug
LeonEricsson Apr 16, 2026
196feee
fix: incorrect import source
LeonEricsson Apr 16, 2026
3c87400
fixes: cleanup, standardized tokenization, distill loss=0 fix, sdpo c…
LeonEricsson Apr 17, 2026
d2a78e2
tests: ported old tests + new tests for base class
LeonEricsson Apr 17, 2026
8807088
couple more tests and test cleanup
LeonEricsson Apr 18, 2026
0612699
test: nit fix
LeonEricsson Apr 18, 2026
3d0cd72
move loss aggregation to loss_util + a few docstrings
LeonEricsson Apr 18, 2026
aa36955
fix: emit accumulated _metrics via log() override
LeonEricsson Apr 20, 2026
a432c20
fix: minor cursor issues + config docstrings
LeonEricsson Apr 20, 2026
e30ca04
fix: rename full logit distillation+topk into explicit flags
LeonEricsson Apr 21, 2026
3a9ecb2
fix(self-distillation): warn on preloaded peft students
LeonEricsson Apr 21, 2026
03718eb
docs: cleanup
LeonEricsson Apr 22, 2026
1ac2f3c
fix: distillation mode default hparams
LeonEricsson May 4, 2026
e4bfd50
merge peft validation, tokenizer fixes, etc from upstream/main
LeonEricsson May 4, 2026
bca1a77
Merge remote-tracking branch 'upstream/main' into feature/experimenta…
LeonEricsson May 4, 2026
5b35e18
remove base class -> seperate independent SDFT/SDPO
LeonEricsson May 25, 2026
4452dcd
shifted boundary between shared/common helpers for self distill trainers
LeonEricsson May 26, 2026
05da798
refactoring and cleaning test suite
LeonEricsson May 27, 2026
6238458
removed unsatisfactory tests
LeonEricsson May 27, 2026
5a44cf2
Merge remote-tracking branch 'upstream/main' into refactor/self-conta…
LeonEricsson May 27, 2026
9210080
Remove unrelated invariant test diff
LeonEricsson May 27, 2026
6aab73b
feat: sdft generate_from_teacher
LeonEricsson May 27, 2026
72b1202
move sdft/sdpo init helper inline
LeonEricsson May 28, 2026
f201251
update docs
LeonEricsson May 29, 2026
1d8e4e0
refactor: drop sdpo_policy_loss_mode="policy_only"
LeonEricsson May 29, 2026
e573ede
refactor: drop FSDP summon_full_params from self-distillation generation
LeonEricsson May 29, 2026
7df3752
remove distillation_weight from sdft
LeonEricsson May 29, 2026
5882b9d
(sdft) disable dropout default false
LeonEricsson May 29, 2026
f23ddba
fix teacher ema update rate
LeonEricsson May 29, 2026
3e399c1
docs: refine method docstrings
LeonEricsson May 29, 2026
3ffee5d
fix: post init distillation param check
LeonEricsson May 29, 2026
dec7741
Merge branch 'main' into refactor/self-contained-self-distillation-tr…
LeonEricsson May 29, 2026
8aa881a
fix: restore distillation_topk validation elif binding in SDFT config
LeonEricsson May 29, 2026
74b48aa
fix: nits
LeonEricsson May 29, 2026
514ec63
Merge branch 'main' into refactor/self-contained-self-distillation-tr…
LeonEricsson May 30, 2026
d7e0ded
remove extraneous post init checks
LeonEricsson May 31, 2026
84c65a4
drop unrelated optiosn
LeonEricsson May 31, 2026
ef065bc
refactor: move shared self_distillation/ modules into sdft+sdpo trainer.
LeonEricsson May 31, 2026
e5fcc84
feat: add sdpo_policy_loss_mode="policy_only"
LeonEricsson May 29, 2026
10e525a
fix: summon FSDP full params in self-distillation transformers genera…
LeonEricsson May 29, 2026
1a96cc2
remove response mask from SDFT (SDPO artifact)
LeonEricsson May 29, 2026
97b7620
reorder methods
LeonEricsson May 29, 2026
7fa37a7
refactor logging
LeonEricsson May 29, 2026
cef5831
wip: complete config docstring
LeonEricsson May 29, 2026
06c7d83
complete config docstring
LeonEricsson May 29, 2026
a46e699
swap sdpo loss to convex combination of policy and distillation
LeonEricsson May 29, 2026
cd9dc11
nits: stylistic cleanups, changed variable names, cleaned up dead cod…
LeonEricsson May 30, 2026
79d0201
docs: comments explaining topk self distillation
LeonEricsson May 30, 2026
61b0ba6
refactor: align sdpo self-distillation loss helpers
LeonEricsson May 31, 2026
7f8539d
Merge branch 'main' into refactor/self-contained-self-distillation-tr…
qgallouedec Jun 1, 2026
4d971e8
fix: build self-distillation teacher from path for ZeRO-3 compatibility
kashif Jun 1, 2026
4c57ca0
feat: liger fused JSD loss for SDFT
kashif Jun 1, 2026
0c0a40a
feat: liger fused JSD loss for SDPO
kashif Jun 1, 2026
a4145bd
test: add self-distillation liger equivalence coverage
LeonEricsson Jun 1, 2026
cdf75ba
Merge pull request #4 from kashif/fix/self-distillation-zero3-teacher
LeonEricsson Jun 1, 2026
d68be79
remove aggregate_loss for distillation losses, hardcode grpo-style no…
LeonEricsson Jun 1, 2026
820bb21
liger: normalize distillation loss per sequence to match non-liger path
kashif Jun 1, 2026
c77491d
liger: drop teacher.eval() that flipped the shared student to eval
kashif Jun 2, 2026
a643e81
guard against ema teacher with non-pure-LoRA PEFT
kashif Jun 2, 2026
ce5e205
Merge branch 'main' into refactor/self-contained-self-distillation-tr…
kashif Jun 2, 2026
efe6cd1
docs: document use_liger_kernel and ema PEFT limitation for sdft/sdpo
kashif Jun 2, 2026
0718dbc
Merge branch 'refactor/self-contained-self-distillation-trainers' int…
LeonEricsson Jun 2, 2026
e7c6380
Merge branch 'main' into refactor/sdft-sdpo-cleanup
LeonEricsson Jun 2, 2026
77814aa
Revert "Merge branch 'main' into refactor/sdft-sdpo-cleanup"
LeonEricsson Jun 2, 2026
bc9be6a
sync main
LeonEricsson Jun 2, 2026
cb2c378
restore branch state following failed merge
LeonEricsson Jun 2, 2026
6f826d5
refactor: sdpo response_mask to loss_mask
LeonEricsson Jun 3, 2026
405c39d
Merge branch 'main' into refactor/sdft-sdpo-cleanup
LeonEricsson Jun 3, 2026
b1703de
refactor: re-order sdft/sdpo methods, refactor metric logging
LeonEricsson Jun 3, 2026
66d1060
validate distillation_weight is in [0, 1]
kashif Jun 4, 2026
f4a4393
Merge branch 'refactor/sdft-sdpo-cleanup' of github.com:LeonEricsson/…
LeonEricsson Jun 4, 2026
test: add self-distillation liger equivalence coverage
LeonEricsson committed Jun 1, 2026
commit a4145bd03088384b86e8ad8da261aee5a72ad21b
60 changes: 59 additions & 1 deletion tests/experimental/test_sdft_trainer.py
@@ -20,7 +20,7 @@

 from trl.experimental.sdft import SDFTConfig, SDFTTrainer

-from ..testing_utils import TrlTestCase, require_peft
+from ..testing_utils import TrlTestCase, require_liger_kernel, require_peft, require_torch_accelerator


 if is_peft_available():
@@ -93,6 +93,64 @@ def test_train(self):
         assert trainer.state.log_history[-1]["train_loss"] is not None
         self._assert_any_trainable_param_changed(trainer.model, previous_trainable_params)

+    @require_liger_kernel
+    @require_torch_accelerator
+    def test_liger_loss_matches_non_liger_loss(self):
+        dataset = Dataset.from_dict({"prompt": ["Solve 2+2."], "privileged_context": ["Example answer: 4."]})
+        common = dict(
+            output_dir=self.tmp_dir,
+            report_to="none",
+            per_device_train_batch_size=1,
+            max_completion_length=3,
+            num_generations=1,
+            distillation_mode="full_logits",
+            distillation_is_clip=None,
+            loss_type="bnpo",
+            num_loss_tokens_to_skip=1,
+        )
+
+        ref_trainer = SDFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=SDFTConfig(use_liger_kernel=False, **common),
+            train_dataset=dataset,
+        )
+        liger_trainer = SDFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=SDFTConfig(use_liger_kernel=True, **common),
+            train_dataset=dataset,
+        )
+
+        liger_trainer.model.load_state_dict(ref_trainer.model.state_dict())
+        torch.manual_seed(0)
+        with torch.no_grad():
+            for param in ref_trainer.teacher_model.parameters():
+                param.add_(0.5 * torch.randn_like(param))
+        liger_trainer.teacher_model.load_state_dict(ref_trainer.teacher_model.state_dict())
+
+        device = next(ref_trainer.model.parameters()).device
+        batch = {
+            "prompt_ids": torch.tensor([[10, 11], [12, 13]], device=device),
+            "prompt_mask": torch.tensor([[1, 1], [1, 1]], device=device),
+            "completion_ids": torch.tensor([[14, 15, 16], [17, 18, 19]], device=device),
+            "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1]], device=device),
+            "teacher_input_ids": torch.tensor([[20, 21, 22, 14, 15, 16], [23, 24, 25, 17, 18, 19]], device=device),
+            "teacher_attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], device=device),
+            "self_distillation_mask": torch.tensor([1.0, 0.0], device=device),
+        }
+
+        ref_trainer.model.eval()
+        liger_trainer.model.eval()
+        with torch.no_grad():
+            ref_loss = ref_trainer.compute_loss(ref_trainer.model, batch).item()
+            liger_loss = liger_trainer.compute_loss(liger_trainer.model, batch).item()
+
+        torch.testing.assert_close(
+            torch.tensor(liger_loss),
+            torch.tensor(ref_loss),
+            rtol=2e-2,
+            atol=1e-6,
+        )
+
     def test_train_rejects_none_privileged_context(self):
         dataset = Dataset.from_dict(
             {
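A note on the hand-built batch in the SDFT test: `completion_mask` zeroes the padded third token of the first row, while `self_distillation_mask` is 1.0 for the first row and 0.0 for the second. How the trainer combines the two masks is not visible in this diff. The sketch below assumes the per-sequence mask is simply broadcast over the per-token mask; `effective_token_mask` is a hypothetical helper written for illustration, not a TRL function.

```python
def effective_token_mask(completion_mask, sequence_mask):
    """Broadcast a per-sequence 0/1 mask over a per-token mask.

    Assumption: this multiplication is only a guess at how the two masks
    interact; the real trainer logic is not shown in the diff.
    """
    return [
        [token * sequence for token in row]
        for row, sequence in zip(completion_mask, sequence_mask)
    ]


# Values copied from the test batch.
completion_mask = [[1, 1, 0], [1, 1, 1]]
self_distillation_mask = [1.0, 0.0]

mask = effective_token_mask(completion_mask, self_distillation_mask)
active_tokens = sum(sum(row) for row in mask)
print(mask, active_tokens)
```

Under that assumption only two tokens, both in the first row, would contribute to the distillation loss, which makes the batch a compact check that both code paths treat masked positions the same way.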
64 changes: 63 additions & 1 deletion tests/experimental/test_sdpo_trainer.py
@@ -20,7 +20,7 @@

 from trl.experimental.sdpo import SDPOConfig, SDPOTrainer

-from ..testing_utils import TrlTestCase
+from ..testing_utils import TrlTestCase, require_liger_kernel, require_torch_accelerator


 class SelfDistillationCaptureCallback(TrainerCallback):
@@ -131,6 +131,68 @@ def test_train(self):
             if param.sum() != 0:
                 assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

+    @require_liger_kernel
+    @require_torch_accelerator
+    def test_liger_loss_matches_non_liger_loss(self):
+        dataset = Dataset.from_dict({"prompt": ["Solve 2+2."]})
+        common = dict(
+            output_dir=self.tmp_dir,
+            report_to="none",
+            per_device_train_batch_size=1,
+            generation_batch_size=2,
+            num_generations=2,
+            max_completion_length=3,
+            sdpo_policy_loss_mode="distillation_only",
+            distillation_mode="full_logits",
+            distillation_is_clip=None,
+            loss_type="bnpo",
+            distillation_weight=0.7,
+        )
+
+        ref_trainer = SDPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs=lambda **kwargs: [0.0] * len(kwargs["prompts"]),
+            args=SDPOConfig(use_liger_kernel=False, **common),
+            train_dataset=dataset,
+        )
+        liger_trainer = SDPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs=lambda **kwargs: [0.0] * len(kwargs["prompts"]),
+            args=SDPOConfig(use_liger_kernel=True, **common),
+            train_dataset=dataset,
+        )
+
+        liger_trainer.model.load_state_dict(ref_trainer.model.state_dict())
+        torch.manual_seed(0)
+        with torch.no_grad():
+            for param in ref_trainer.teacher_model.parameters():
+                param.add_(0.5 * torch.randn_like(param))
+        liger_trainer.teacher_model.load_state_dict(ref_trainer.teacher_model.state_dict())
+
+        device = next(ref_trainer.model.parameters()).device
+        batch = {
+            "prompt_ids": torch.tensor([[10, 11], [12, 13]], device=device),
+            "prompt_mask": torch.tensor([[1, 1], [1, 1]], device=device),
+            "completion_ids": torch.tensor([[14, 15, 16], [17, 18, 19]], device=device),
+            "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1]], device=device),
+            "teacher_input_ids": torch.tensor([[20, 21, 22, 14, 15, 16], [23, 24, 25, 17, 18, 19]], device=device),
+            "teacher_attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], device=device),
+            "self_distillation_mask": torch.tensor([1.0, 0.0], device=device),
+        }
+
+        ref_trainer.model.eval()
+        liger_trainer.model.eval()
+        with torch.no_grad():
+            ref_loss = ref_trainer.compute_loss(ref_trainer.model, batch).item()
+            liger_loss = liger_trainer.compute_loss(liger_trainer.model, batch).item()
+
+        torch.testing.assert_close(
+            torch.tensor(liger_loss),
+            torch.tensor(ref_loss),
+            rtol=2e-2,
+            atol=1e-6,
+        )
+
     def test_train_without_successful_rollouts(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

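Both equivalence tests compare the two losses with `torch.testing.assert_close(..., rtol=2e-2, atol=1e-6)`. For reference, the documented closeness rule is `|actual - expected| <= atol + rtol * |expected|`. The pure-Python sketch below restates that rule with the tolerances used in the tests; `within_tolerance` is a hypothetical helper and the sample loss values are made up for illustration.

```python
def within_tolerance(actual, expected, rtol=2e-2, atol=1e-6):
    """Scalar restatement of the closeness rule documented for torch.testing.assert_close."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Illustrative values only; these are not losses measured from the trainers.
close_enough = within_tolerance(1.01, 1.0)   # 1% relative gap, inside the 2% budget
too_far = within_tolerance(1.03, 1.0)        # 3% relative gap, outside the budget
near_zero = within_tolerance(1e-5, 0.0)      # only atol applies when expected is 0
print(close_enough, too_far, near_zero)
```

With these settings a reference loss near zero leaves only the `1e-6` absolute budget, so a comparison of this kind is most informative when the reference loss is clearly nonzero.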