
fix(sft): reject transformed datasets during preparation#6054

Merged
albertvillanova merged 3 commits into huggingface:main from he-yufeng:fix/sft-with-transform-guard
Jun 16, 2026

Conversation

@he-yufeng he-yufeng commented Jun 13, 2026


Fixes #6039.

SFTTrainer currently lets Dataset.with_transform() datasets enter the automatic preparation pipeline. That path calls Dataset.map() for EOS insertion and tokenization, and map() reads rows through the active custom transform. For stateful or random transforms, that can bake one transform realization into the prepared Arrow columns while later accesses still run a different transform.
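The failure mode can be illustrated with a toy model. This is a minimal sketch using only the standard library: `ToyDataset`, `random_suffix`, and `add_eos` are hypothetical stand-ins, not the `datasets` or TRL implementation. It assumes, as described above, that `map()` reads rows through the active transform and that the transform stays attached afterwards.

```python
import random


class ToyDataset:
    """Toy stand-in for a dataset with an optional lazy, on-access transform."""

    def __init__(self, rows, transform=None):
        self._rows = list(rows)      # stored columns (the "Arrow" side in this toy)
        self._transform = transform  # applied every time a row is read

    def with_transform(self, transform):
        # Same stored rows, new on-access transform.
        return ToyDataset(self._rows, transform)

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, index):
        row = dict(self._rows[index])
        return self._transform(row) if self._transform else row

    def map(self, function):
        # Assumption of this toy: rows are read THROUGH the active transform,
        # the result is stored, and the transform stays attached.
        new_rows = [function(self[i]) for i in range(len(self))]
        return ToyDataset(new_rows, self._transform)


rng = random.Random(0)


def random_suffix(row):
    # Random augmentation: a suffix is drawn again on every access.
    return {**row, "text": row["text"] + rng.choice([" A", " B", " C"])}


def add_eos(row):
    # Stand-in for the EOS-insertion step of automatic preparation.
    return {**row, "text": row["text"] + "<eos>"}


lazy = ToyDataset([{"text": "hello"}]).with_transform(random_suffix)
prepared = lazy.map(add_eos)

stored_text = prepared._rows[0]["text"]  # one realization, frozen at map() time
served_text = prepared[0]["text"]        # the transform runs again on the frozen row
```

In this toy, `stored_text` already contains one random suffix followed by `<eos>`, and `served_text` appends a second suffix after it, so the augmentation is both frozen and re-applied.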

This PR adds a guard before SFT dataset preparation: if the dataset has a custom transform, fail with a clear error explaining why automatic preparation is unsafe and pointing users to either dataset_kwargs={"skip_prepare_dataset": True} with already-tokenized examples, or materializing the transform with Dataset.map() before constructing the trainer.

I kept this as the conservative fix from the issue. It does not attempt lazy transform composition or packing support in this PR.

Before submitting

AI writing disclosure

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Tests

  • python -m ruff check trl\trainer\sft_trainer.py tests\test_sft_trainer.py
  • .\.venv\Scripts\python.exe -m py_compile trl\trainer\sft_trainer.py tests\test_sft_trainer.py
  • .\.venv\Scripts\python.exe -m pytest tests\test_sft_trainer.py::TestSFTTrainer::test_dataset_with_transform_requires_skip_prepare_dataset -q

One nearby skip-prepare test was also attempted in this local Windows venv, but it fails before reaching trainer logic because the current CPU-only environment rejects the default bf16 settings. I did not change existing tests to hide that environment issue.


Note

Low Risk
Small, early validation in the dataset prep path; only blocks an unsafe pattern and does not change successful training flows.

Overview
SFTTrainer now refuses datasets created with Dataset.with_transform() when automatic dataset preparation runs. Preparation uses Dataset.map(), which materializes rows through the lazy transform and can freeze one random or stateful augmentation into Arrow columns while later reads still apply a different transform.

At the start of _prepare_dataset, the trainer checks for a custom dataset format and raises ValueError with guidance to use dataset_kwargs={'skip_prepare_dataset': True} with trainer-ready (e.g. tokenized) examples from the transform, or to map() deterministic transforms before building the trainer.
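A rough sketch of such an early guard follows. The two classes are hypothetical stand-ins for `datasets.Dataset` and `datasets.IterableDataset`, `prepare_dataset` is an illustrative name, and the error text is a paraphrase rather than the message merged in this PR. The `isinstance` restriction mirrors the form requested later in review.

```python
class ToyDataset:
    """Hypothetical stand-in exposing only a `format` dict with a "type" key."""

    def __init__(self, format_type=None):
        self.format = {"type": format_type}


class ToyIterableDataset:
    """Hypothetical streaming stand-in; deliberately outside the guard's scope."""


def prepare_dataset(dataset):
    # Fail early, before any map() call can read rows through a lazy transform.
    if isinstance(dataset, ToyDataset) and dataset.format["type"] == "custom":
        raise ValueError(
            "Automatic preparation is unsafe for datasets with a custom transform. "
            "Use dataset_kwargs={'skip_prepare_dataset': True} and return "
            "trainer-ready (e.g. tokenized) examples from the transform, or map() "
            "a deterministic transform before building the trainer."
        )
    # Real preparation (EOS insertion, tokenization) would happen here.
    return dataset


plain = prepare_dataset(ToyDataset())           # no transform: passes through
stream = prepare_dataset(ToyIterableDataset())  # streaming: not checked

try:
    prepare_dataset(ToyDataset("custom"))
    guard_message = None
except ValueError as error:
    guard_message = str(error)
```

Only the map-style stand-in with a `"custom"` format type is rejected; the other two inputs are returned unchanged.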

A regression test asserts that SFTTrainer construction fails on a with_transform dataset with the expected message.

Reviewed by Cursor Bugbot for commit 27def35. Bugbot is set up for automated code reviews on this repo.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Reviewed by Cursor Bugbot for commit 1c9f468.

Comment thread trl/trainer/sft_trainer.py Outdated


def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool:
    return getattr(dataset, "format", {}).get("type") == "custom"

Iterable transform guard misses stream

Medium Severity

_dataset_has_custom_transform checks dataset.format via getattr, but elsewhere in TRL, IterableDataset formatting is read from _formatting.format_type in _get_dataset_format. If streaming datasets lack a format dict with type == "custom", _prepare_dataset can still run map() on a lazy custom transform.

Member

There is no public API path that puts a custom transform on an IterableDataset.

I think this would need an isinstance(dataset, Dataset) guard to make this explicit and close the concern permanently.

@albertvillanova albertvillanova left a comment

Member

Thanks for addressing the underlying issue. Below are my concerns and suggested changes.

Comment thread trl/trainer/sft_trainer.py Outdated
from peft import PeftConfig, PeftModel, PeftType, get_peft_model


def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool:

Member

I would inline this function: I think it creates an unnecessary indirection for a one-liner that names a well-understood condition.

Comment thread trl/trainer/sft_trainer.py Outdated


def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool:
    return getattr(dataset, "format", {}).get("type") == "custom"

Member

This guard is defensive programming and should be removed.

Comment thread tests/test_sft_trainer.py Outdated
    return batch

dataset = dataset.with_transform(add_suffix)
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", use_cpu=True, bf16=False)

Member

The guards for use_cpu and bf16 are not appropriate.

@he-yufeng

Author

Addressed the review in cee507b9:

  • inlined the custom-transform check
  • limited it explicitly to datasets.Dataset; streaming IterableDataset is unaffected
  • removed the defensive getattr / .get path
  • removed the test's use_cpu and bf16 overrides
  • rebased onto current main

Validation:

python -m ruff check trl/trainer/sft_trainer.py tests/test_sft_trainer.py
python -m ruff format --check trl/trainer/sft_trainer.py tests/test_sft_trainer.py
git diff --check

All passed. The focused pytest now reaches the test but this local Windows machine cannot construct the repository-default SFTConfig because it has no bf16-capable GPU. I left the device-specific overrides removed as requested and will rely on the repository CI environment for that test.

@albertvillanova albertvillanova left a comment

Member

Thanks for addressing the suggested changes. Just one additional suggestion below.

Additionally, would you mind not force-pushing to a public branch? It makes the review process more difficult, because it is no longer possible to check only the modifications made since the previously reviewed code (by reading only the specific lines that have been modified); instead, it forces a review of the entire PR.

)

# If the dataset is already preprocessed (tokenized), skip the processing steps.
column_names = get_dataset_column_names(dataset)

Member

I think the error message is misleading: it says "provide already tokenized examples", which implies the user must pre-tokenize outside the transform. The correct pattern (transform that augments AND tokenizes) should be stated explicitly.

The second suggested workaround ("materialize the transform with Dataset.map() before constructing the trainer") is correct only for deterministic transforms; for the reported use case (random augmentation) it defeats the point and should be qualified or dropped.

Author

Agreed. I updated the message in 70cc219f so it no longer implies users must pre-tokenize outside the transform.

It now says to use skip_prepare_dataset=True and make the transform return trainer-ready examples, including tokenized fields. The Dataset.map() workaround is now qualified to deterministic transforms only.
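As a rough illustration of the "augment AND tokenize" pattern, here is a batched transform that returns trainer-ready fields. The vocabulary, `toy_tokenize`, and `augment_and_tokenize` are hypothetical; a real setup would use the model's tokenizer, and any fields beyond `input_ids` that the trainer may need are not covered in this thread.

```python
import random

# Hypothetical toy vocabulary; a real transform would call the model's tokenizer.
VOCAB = {"<unk>": 0, "<eos>": 1, "hello": 2, "world": 3, "there": 4}
_rng = random.Random(0)


def toy_tokenize(text):
    # Whitespace tokenization with an EOS id appended.
    ids = [VOCAB.get(token, VOCAB["<unk>"]) for token in text.split()]
    return ids + [VOCAB["<eos>"]]


def augment_and_tokenize(batch):
    """Batched transform: random augmentation followed by tokenization."""
    input_ids = []
    for text in batch["text"]:
        # The augmentation is re-drawn on every access, so it is never frozen.
        augmented = text + " " + _rng.choice(["world", "there"])
        input_ids.append(toy_tokenize(augmented))
    # Only trainer-ready fields are returned; no raw text is left to prepare.
    return {"input_ids": input_ids}


example = augment_and_tokenize({"text": ["hello"]})
```

With the real libraries, a function of this shape would presumably be attached via `dataset.with_transform(augment_and_tokenize)` and combined with `dataset_kwargs={"skip_prepare_dataset": True}`, so that no `map()` call ever reads through the transform.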

Validated locally:

python -m py_compile trl\trainer\sft_trainer.py tests\test_sft_trainer.py
python -m ruff check trl\trainer\sft_trainer.py tests\test_sft_trainer.py
python -m ruff format --check trl\trainer\sft_trainer.py tests\test_sft_trainer.py
git diff --check origin/main..HEAD

The focused pytest still fails before reaching this assertion on my Windows machine because the repository-default SFTConfig requires bf16 GPU support here. I kept the previous review request and did not reintroduce use_cpu / bf16 test overrides.

@bot-ci-comment

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@albertvillanova albertvillanova left a comment

Member

Thanks again for your contribution and for addressing all suggestions. It looks good to me.

@albertvillanova albertvillanova merged commit f92a846 into huggingface:main Jun 16, 2026
12 checks passed
Development

Successfully merging this pull request may close these issues.

SFTTrainer silently breaks datasets that use Dataset.with_transform

2 participants