Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ jobs:
key: v0.4-code_quality-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: black --check examples tests src utils
- run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: flake8 examples tests src utils
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \
black $(modified_py_files); \
black --preview $(modified_py_files); \
isort $(modified_py_files); \
flake8 $(modified_py_files); \
else \
Expand Down Expand Up @@ -45,7 +45,7 @@ repo-consistency:
# this target runs checks on all files

quality:
black --check $(check_dirs)
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only
flake8 $(check_dirs)
Expand All @@ -60,7 +60,7 @@ extra_style_checks:
# this target runs checks on all files and potentially modifies some of them

style:
black $(check_dirs)
black --preview $(check_dirs)
isort $(check_dirs)
${MAKE} autogenerate_code
${MAKE} extra_style_checks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@ class ModelArguments:
)
encoder_model_name_or_path: str = field(
metadata={
"help": "The encoder model checkpoint for weights initialization."
"Don't set if you want to train an encoder model from scratch."
"help": (
"The encoder model checkpoint for weights initialization."
"Don't set if you want to train an encoder model from scratch."
)
},
)
decoder_model_name_or_path: str = field(
metadata={
"help": "The decoder model checkpoint for weights initialization."
"Don't set if you want to train a decoder model from scratch."
"help": (
"The decoder model checkpoint for weights initialization."
"Don't set if you want to train a decoder model from scratch."
)
},
)
encoder_config_name: Optional[str] = field(
Expand Down
60 changes: 40 additions & 20 deletions examples/flax/image-captioning/run_image_captioning_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
"help": (
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
)
},
)

Expand Down Expand Up @@ -222,38 +227,48 @@ class DataTrainingArguments:
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
"during evaluation."
"help": (
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
"during evaluation."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
)
},
)
preprocessing_num_workers: Optional[int] = field(
Expand All @@ -266,8 +281,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
"which is used during evaluation."
"help": (
"Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
"which is used during evaluation."
)
},
)
overwrite_cache: bool = field(
Expand Down Expand Up @@ -623,7 +640,7 @@ def preprocess_fn(examples, max_target_length, check_image=True):
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
if training_args.block_size % train_batch_size > 0 or training_args.block_size % eval_batch_size > 0:
raise ValueError(
f"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
f"Got {training_args.block_size}, {train_batch_size} and {eval_batch_size} respectively instead."
)

Expand Down Expand Up @@ -1136,7 +1153,7 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):
)

# train
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):

cur_step += 1
batch = next(train_batches)
Expand All @@ -1150,7 +1167,10 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:

_train_metric = unreplicate(train_metric)
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
desc = (
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
f" Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
)
epochs.desc = desc
epochs.write(desc)

Expand Down
47 changes: 32 additions & 15 deletions examples/flax/language-modeling/run_clm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
"help": (
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
Expand All @@ -162,14 +163,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
"help": (
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
)
},
)

Expand All @@ -194,15 +200,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
overwrite_cache: bool = field(
Expand All @@ -217,9 +227,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
"help": "Optional input sequence length after tokenization. "
"The training dataset will be truncated in block of this size for training. "
"Default to the model max input length for single sentence inputs (take into account special tokens)."
"help": (
"Optional input sequence length after tokenization. "
"The training dataset will be truncated in block of this size for training. "
"Default to the model max input length for single sentence inputs (take into account special tokens)."
)
},
)
overwrite_cache: bool = field(
Expand Down Expand Up @@ -505,7 +517,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
" before being passed to the model."
)
return output

Expand Down Expand Up @@ -735,7 +748,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)

epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
f" {train_metric['learning_rate'].mean()})"
)

train_metrics = []
Expand All @@ -762,7 +776,10 @@ def eval_step(params, batch):
eval_metrics["perplexity"] = float("inf")

# Print metrics and update progress bar
desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
desc = (
f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:"
f" {eval_metrics['perplexity']})"
)
epochs.write(desc)
epochs.desc = desc

Expand Down
31 changes: 21 additions & 10 deletions examples/flax/language-modeling/run_mlm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
"help": (
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
Expand All @@ -160,14 +161,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
"help": (
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
)
},
)

Expand Down Expand Up @@ -209,8 +215,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model."
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model."
)
},
)
preprocessing_num_workers: Optional[int] = field(
Expand All @@ -223,8 +231,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
"help": (
"Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
)
},
)
line_by_line: bool = field(
Expand Down Expand Up @@ -764,7 +774,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)

epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)

train_metrics = []
Expand Down
Loading