Commit 2ad06af

KennethEnevoldsen authored and committed
fixed bug in run_mlm_flax_stream.py (huggingface#17203)
* fixed bug run_mlm_flax_stream.py

  Fixed bug caused by an update to tokenizer keys introduced in recent transformers versions (between `4.6.2` and `4.18.0`) where additional keys were introduced to the tokenizer output.

* Update run_mlm_flax_stream.py
* adding missing parenthesis
* formatted to black
* remove cols from dataset instead
* reformat to black
* moved rem. columns to map
* formatted to black

Co-authored-by: KennethEnevoldsen <[email protected]>
1 parent 3182529 commit 2ad06af

File tree

1 file changed (+5, -6 lines)


examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py

Lines changed: 5 additions & 6 deletions
@@ -288,8 +288,10 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
         tokenized_samples = next(train_iterator)
         i += len(tokenized_samples["input_ids"])
 
-        # concatenate tokenized samples to list
-        samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
+        # concatenate tokenized samples to list (excluding "id" and "text")
+        samples = {
+            k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
+        }
 
     # Concatenated tokens are split to lists of length `max_seq_length`.
     # Note that remainedr of % max_seq_length are thrown away.
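
For context, the change above can be illustrated with a minimal, self-contained sketch (hypothetical sample values, not the actual training loop): newer library versions carry extra keys such as "id" and "text" alongside the token fields, so the comprehension now iterates over an explicit list of token-level keys instead of tokenized_samples.keys().

# Minimal sketch of the key filtering above; the sample values are made up.
tokenized_samples = {
    "id": [0],                           # extra key carried over from the raw dataset
    "text": ["hello world"],             # extra key carried over from the raw dataset
    "input_ids": [[0, 31414, 232, 2]],
    "attention_mask": [[1, 1, 1, 1]],
    "special_tokens_mask": [[1, 0, 0, 1]],
}
samples = {"input_ids": [], "attention_mask": [], "special_tokens_mask": []}

# Concatenate only the keys that the later grouping into max_seq_length chunks
# actually uses; "id" and "text" no longer leak into the accumulated samples.
samples = {
    k: samples[k] + tokenized_samples[k]
    for k in ["input_ids", "attention_mask", "special_tokens_mask"]
}
assert "id" not in samples and "text" not in samples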
@@ -407,10 +409,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
     def tokenize_function(examples):
         return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
 
-    tokenized_datasets = dataset.map(
-        tokenize_function,
-        batched=True,
-    )
+    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
 
     shuffle_seed = training_args.seed
     tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
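
Likewise, a minimal sketch of what passing remove_columns to dataset.map changes, using toy in-memory data rather than the streamed corpus the script reads: without it, the raw "id" and "text" columns ride along with every tokenized example and reappear in the training iterator.

# Toy illustration of remove_columns; the dataset contents are made up.
from datasets import Dataset

dataset = Dataset.from_dict({"id": [0, 1], "text": ["hello world", "flax is fast"]})

def tokenize_function(examples):
    # Stand-in for the real tokenizer call; only the output keys matter here.
    return {
        "input_ids": [[ord(c) for c in t] for t in examples["text"]],
        "attention_mask": [[1] * len(t) for t in examples["text"]],
        "special_tokens_mask": [[0] * len(t) for t in examples["text"]],
    }

# Dropping the original columns means downstream code only ever sees the
# token-level fields, matching the explicit key list used when grouping samples.
tokenized = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
print(tokenized.column_names)  # no "id" or "text" columns left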

0 commit comments
