diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index a2a625d1e74..e74b3f15f94 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -194,9 +194,9 @@ def test_tokenize_and_process_tokens(self): batched=True, batch_size=2, ) - self.assertListEqual(tokenized_dataset["prompt"], dataset["prompt"]) - self.assertListEqual(tokenized_dataset["completion"], dataset["completion"]) - self.assertListEqual(tokenized_dataset["label"], dataset["label"]) + self.assertListEqual(tokenized_dataset["prompt"][:], dataset["prompt"][:]) + self.assertListEqual(tokenized_dataset["completion"][:], dataset["completion"][:]) + self.assertListEqual(tokenized_dataset["label"][:], dataset["label"][:]) self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) @@ -212,9 +212,9 @@ def test_tokenize_and_process_tokens(self): "max_prompt_length": trainer.max_prompt_length, } processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs) - self.assertListEqual(processed_dataset["prompt"], dataset["prompt"]) - self.assertListEqual(processed_dataset["completion"], dataset["completion"]) - self.assertListEqual(processed_dataset["label"], dataset["label"]) + self.assertListEqual(processed_dataset["prompt"][:], dataset["prompt"][:]) + self.assertListEqual(processed_dataset["completion"][:], dataset["completion"][:]) + self.assertListEqual(processed_dataset["label"][:], dataset["label"][:]) self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) self.assertListEqual( diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index 140467fc87b..54fe23ee298 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -153,9 +153,9 @@ def test_tokenize_and_process_tokens(self): batched=True, batch_size=2, ) - self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"]) - self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"]) - self.assertListEqual(tokenized_dataset["label"], train_dataset["label"]) + self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:]) + self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:]) + self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:]) self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) @@ -193,9 +193,9 @@ def test_tokenize_and_process_tokens(self): "max_prompt_length": trainer.max_prompt_length, } processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2) - self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"]) - self.assertListEqual(processed_dataset["completion"], train_dataset["completion"]) - self.assertListEqual(processed_dataset["label"], train_dataset["label"]) + self.assertListEqual(processed_dataset["prompt"][:], train_dataset["prompt"][:]) + self.assertListEqual(processed_dataset["completion"][:], train_dataset["completion"][:]) + self.assertListEqual(processed_dataset["label"][:], train_dataset["label"][:]) self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) self.assertListEqual( diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py index 853d126429b..36c866c9961 100644 --- a/trl/trainer/iterative_sft_trainer.py +++ b/trl/trainer/iterative_sft_trainer.py @@ -357,6 +357,13 @@ def step( "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed." ) + # Convert Column to list if not already + input_ids = input_ids[:] if input_ids is not None else None + attention_mask = attention_mask[:] if attention_mask is not None else None + labels = labels[:] if labels is not None else None + texts = texts[:] if texts is not None else None + texts_labels = texts_labels[:] if texts_labels is not None else None + input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker( input_ids, attention_mask, labels, texts, texts_labels )