@@ -153,9 +153,9 @@ def test_tokenize_and_process_tokens(self):
             batched=True,
             batch_size=2,
         )
-        self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
-        self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
-        self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
+        self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:])
+        self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:])
+        self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:])
         self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
@@ -193,9 +193,9 @@ def test_tokenize_and_process_tokens(self):
193193 "max_prompt_length" : trainer .max_prompt_length ,
194194 }
195195 processed_dataset = tokenized_dataset .map (_process_tokens , fn_kwargs = fn_kwargs , num_proc = 2 )
196- self .assertListEqual (processed_dataset ["prompt" ], train_dataset ["prompt" ])
197- self .assertListEqual (processed_dataset ["completion" ], train_dataset ["completion" ])
198- self .assertListEqual (processed_dataset ["label" ], train_dataset ["label" ])
196+ self .assertListEqual (processed_dataset ["prompt" ][:] , train_dataset ["prompt" ][: ])
197+ self .assertListEqual (processed_dataset ["completion" ][:] , train_dataset ["completion" ][: ])
198+ self .assertListEqual (processed_dataset ["label" ][:] , train_dataset ["label" ][: ])
199199 self .assertListEqual (processed_dataset ["prompt_input_ids" ][0 ], [46518 , 374 , 2664 , 1091 ])
200200 self .assertListEqual (processed_dataset ["prompt_attention_mask" ][0 ], [1 , 1 , 1 , 1 ])
201201 self .assertListEqual (
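
Note on the change: the added `[:]` slices presumably keep these assertions working on `datasets` versions where indexing a dataset by column name returns a lazy `Column` view rather than a plain `list`; `assertListEqual` requires actual `list` objects on both sides, and slicing materializes the view. A minimal sketch of that assumed behavior (the toy dataset below is hypothetical, not from this diff):

    from datasets import Dataset

    ds = Dataset.from_dict({"prompt": ["Kale is better than", "beans"]})

    col = ds["prompt"]  # on recent `datasets` releases this may be a lazy
                        # Column view rather than a plain Python list
    lst = col[:]        # slicing materializes the column into a real list
    assert lst == ["Kale is better than", "beans"]
    # unittest's assertListEqual checks that both operands are genuine lists,
    # so slicing both sides keeps the comparison version-agnostic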