fix(sft): reject transformed datasets during preparation#6054
Conversation
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit 1c9f468. Configure here.
|
|
||
|
|
||
| def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool: | ||
| return getattr(dataset, "format", {}).get("type") == "custom" |
There was a problem hiding this comment.
Iterable transform guard misses stream
Medium Severity
_dataset_has_custom_transform checks dataset.format via getattr, but elsewhere in TRL, IterableDataset formatting is read from _formatting.format_type in _get_dataset_format. If streaming datasets lack a format dict with type == "custom", _prepare_dataset can still run map() on a lazy custom transform.
Reviewed by Cursor Bugbot for commit 1c9f468. Configure here.
There was a problem hiding this comment.
There is no public API path that puts a custom transform on an IterableDataset.
I think this would need an isinstance(dataset, Dataset) guard to make this explicit and close the concern permanently.
albertvillanova
left a comment
There was a problem hiding this comment.
Thanks for addressing the underlying issue. Below my concerns and suggested changes.
| from peft import PeftConfig, PeftModel, PeftType, get_peft_model | ||
|
|
||
|
|
||
| def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool: |
There was a problem hiding this comment.
I would inline this function: I think it creates an unnecessary indirection for a one-liner that names a well-understood condition.
|
|
||
|
|
||
| def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool: | ||
| return getattr(dataset, "format", {}).get("type") == "custom" |
There was a problem hiding this comment.
This guard is defensive programming and should be removed.
| return batch | ||
|
|
||
| dataset = dataset.with_transform(add_suffix) | ||
| training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", use_cpu=True, bf16=False) |
There was a problem hiding this comment.
The guards for use_cpu and bf16 are not appropriate.
|
|
||
|
|
||
| def _dataset_has_custom_transform(dataset: Dataset | IterableDataset) -> bool: | ||
| return getattr(dataset, "format", {}).get("type") == "custom" |
There was a problem hiding this comment.
There is no public API path that puts a custom transform on an IterableDataset.
I think this would need an isinstance(dataset, Dataset) guard to make this explicit and close the concern permanently.
|
Addressed the review in
Validation: All passed. The focused pytest now reaches the test but this local Windows machine cannot construct the repository-default |
albertvillanova
left a comment
There was a problem hiding this comment.
Thanks for addressing the suggested changes. Just one additional suggestion below.
Additionally, would you mind not force-pushing to a public branch? That makes the review process more difficult because it it not possible to check only the modification to the previous reviewed PR code (by reading only the specific lines have been modified), and instead forces to review the entire PR.
| ) | ||
|
|
||
| # If the dataset is already preprocessed (tokenized), skip the processing steps. | ||
| column_names = get_dataset_column_names(dataset) |
There was a problem hiding this comment.
I think the error message is misleading: it says "provide already tokenized examples", which implies the user must pre-tokenize outside the transform. The correct pattern (transform that augments AND tokenizes) should be stated explicitly.
The second suggested workaround ("materialize the transform with Dataset.map() before constructing the trainer") is correct only for deterministic transforms; for the reported use case (random augmentation) it defeats the point and should be qualified or dropped.
There was a problem hiding this comment.
Agreed. I updated the message in 70cc219f so it no longer implies users must pre-tokenize outside the transform.
It now says to use skip_prepare_dataset=True and make the transform return trainer-ready examples, including tokenized fields. The Dataset.map() workaround is now qualified to deterministic transforms only.
Validated locally:
python -m py_compile trl\trainer\sft_trainer.py tests\test_sft_trainer.py
python -m ruff check trl\trainer\sft_trainer.py tests\test_sft_trainer.py
python -m ruff format --check trl\trainer\sft_trainer.py tests\test_sft_trainer.py
git diff --check origin/main..HEAD
The focused pytest still fails before reaching this assertion on my Windows machine because the repository-default SFTConfig requires bf16 GPU support here. I kept the previous review request and did not reintroduce use_cpu / bf16 test overrides.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
albertvillanova
left a comment
There was a problem hiding this comment.
Thanks again for your contribution and for addressing all suggestions. It looks good to me.


Fixes #6039.
SFTTrainer currently lets
Dataset.with_transform()datasets enter the automatic preparation pipeline. That path callsDataset.map()for EOS insertion and tokenization, andmap()reads rows through the active custom transform. For stateful or random transforms, that can bake one transform realization into the prepared Arrow columns while later accesses still run a different transform.This PR adds a guard before SFT dataset preparation: if the dataset has a custom transform, fail with a clear error explaining why automatic preparation is unsafe and pointing users to either
dataset_kwargs={"skip_prepare_dataset": True}with already-tokenized examples, or materializing the transform withDataset.map()before constructing the trainer.I kept this as the conservative fix from the issue. It does not attempt lazy transform composition or packing support in this PR.
Before submitting
SFTTrainersilently breaks datasets that useDataset.with_transform#6039AI writing disclosure
Tests
python -m ruff check trl\\trainer\\sft_trainer.py tests\\test_sft_trainer.py.\\.venv\\Scripts\\python.exe -m py_compile trl\\trainer\\sft_trainer.py tests\\test_sft_trainer.py.\\.venv\\Scripts\\python.exe -m pytest tests\\test_sft_trainer.py::TestSFTTrainer::test_dataset_with_transform_requires_skip_prepare_dataset -qOne nearby skip-prepare test was also attempted in this local Windows venv, but it fails before reaching trainer logic because the current CPU-only environment rejects the default bf16 settings. I did not change existing tests to hide that environment issue.
Note
Low Risk
Small, early validation in the dataset prep path; only blocks an unsafe pattern and does not change successful training flows.
Overview
SFTTrainer now refuses datasets created with
Dataset.with_transform()when automatic dataset preparation runs. Preparation usesDataset.map(), which materializes rows through the lazy transform and can freeze one random or stateful augmentation into Arrow columns while later reads still apply a different transform.At the start of
_prepare_dataset, the trainer checks for a custom dataset format and raisesValueErrorwith guidance to usedataset_kwargs={'skip_prepare_dataset': True}with trainer-ready (e.g. tokenized) examples from the transform, or tomap()deterministic transforms before building the trainer.A regression test asserts that
SFTTrainerconstruction fails on awith_transformdataset with the expected message.Reviewed by Cursor Bugbot for commit 27def35. Bugbot is set up for automated code reviews on this repo. Configure here.