Merged
40 changes: 25 additions & 15 deletions examples/speech-recognition/run_speech_recognition_seq2seq.py
@@ -28,6 +28,7 @@
import datasets
import evaluate
import librosa
import numpy as np
import soundfile as sf
import torch
import transformers
@@ -270,18 +271,19 @@ def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) ->
# different padding methods
model_input_name = self.processor.model_input_names[0]
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]

batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

if self.forward_attention_mask:
batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

kwargs = {}
if self.label_features_max_length is not None:
kwargs["padding"] = "max_length"
kwargs["max_length"] = self.label_features_max_length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt", **kwargs)
texts = [feature["labels"] for feature in features]
labels_batch = self.processor.tokenizer(
texts,
padding="max_length" if self.label_features_max_length else True,
max_length=self.label_features_max_length,
return_tensors="pt",
truncation=True,
)

# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
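Illustrative sketch (not part of the diff): with this change the collator receives raw transcription strings in feature["labels"] and tokenizes, pads, and -100-masks them at batch time, instead of padding pre-tokenized ids. A minimal standalone version of the pattern, assuming a Whisper-style processor ("openai/whisper-tiny" is only an example checkpoint):

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")  # example checkpoint, not from the PR
texts = ["hello world", "a longer transcription that forces padding"]

labels_batch = processor.tokenizer(
    texts,
    padding=True,          # pad to the longest sequence in the batch
    return_tensors="pt",
    truncation=True,
)
# Replace padding positions with -100 so the loss ignores them, as done above.
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
print(labels.shape)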
@@ -310,6 +312,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Ensure DDP doesn't break with gradient checkpointing
if training_args.gradient_checkpointing:
training_args.ddp_find_unused_parameters = True

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args)
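Illustrative sketch (not part of the diff): the flag forced above, shown with the plain transformers Seq2SeqTrainingArguments for brevity (the script itself uses the Gaudi argument classes). Setting ddp_find_unused_parameters=True alongside gradient_checkpointing=True is what the added lines do automatically:

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./out",               # hypothetical output directory
    gradient_checkpointing=True,
    ddp_find_unused_parameters=True,  # what the snippet above forces when checkpointing is on
)
print(training_args.ddp_find_unused_parameters)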
@@ -335,6 +341,10 @@ def main():
token=model_args.token,
)

if training_args.gradient_checkpointing and getattr(gaudi_config, "use_hpu_graphs_for_inference", False):
logger.warning("Disabling HPU graphs for inference during training because gradient checkpointing is enabled.")
gaudi_config.use_hpu_graphs_for_inference = False

# Log on each process the small summary:
mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
logger.warning(
@@ -528,21 +538,21 @@ def prepare_dataset(batch):
if isinstance(path, str):
try:
wav, sr = sf.read(path, dtype="float32", always_2d=False)
except Exception:
try:
wav, sr = librosa.load(path, sr=None, mono=False)
except Exception:
wav, sr = None, None
except Exception as e:
logger.warning(
f"SoundFile failed to read {path} ({type(e).__name__}: {e}). Falling back to torchaudio."
)
wav, sr = None, None

# Fallback: load from bytes if available
if wav is None:
raw = sample.get("bytes", None)
raw = sample.get("bytes")
if raw is None:
raise RuntimeError(f"Cannot open audio sample {sample}")
wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)

# Convert to mono
if getattr(wav, "ndim", 1) > 1:
if isinstance(wav, np.ndarray) and wav.ndim > 1:
wav = wav.mean(axis=1)

# Resample if necessary
@@ -559,7 +569,7 @@ def prepare_dataset(batch):
text = batch[text_column_name]
if do_lower_case and isinstance(text, str):
text = text.lower()
batch["labels"] = tokenizer(text).input_ids
batch["labels"] = text

return batch
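
Illustrative sketch (not part of the diff): the loading order used by prepare_dataset above, pulled out as a standalone helper. It assumes sample is a dict with optional "path" and "bytes" keys, a 16 kHz target rate, and that resampling goes through librosa (which the script already imports):

import io

import librosa
import numpy as np
import soundfile as sf

def load_audio(sample, target_sr=16000):
    wav, sr = None, None
    path = sample.get("path")
    if isinstance(path, str):
        try:
            wav, sr = sf.read(path, dtype="float32", always_2d=False)
        except Exception:
            wav, sr = None, None  # fall through to the raw-bytes path below
    if wav is None:
        raw = sample.get("bytes")
        if raw is None:
            raise RuntimeError(f"Cannot open audio sample {sample}")
        wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
    if isinstance(wav, np.ndarray) and wav.ndim > 1:
        wav = wav.mean(axis=1)  # downmix to mono
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
        sr = target_sr
    return wav, sr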

23 changes: 13 additions & 10 deletions optimum/habana/transformers/gradient_checkpointing.py
@@ -339,16 +339,19 @@ def checkpoint(
Returns:
Output of running :attr:`function` on :attr:`*args`
"""
if use_reentrant is None:
warn0(
"torch.utils.checkpoint: the use_reentrant parameter should be "
"passed explicitly. In version 2.5 we will raise an exception "
"if use_reentrant is not passed. use_reentrant=False is "
"recommended, but if you need to preserve the current default "
"behavior, you can pass use_reentrant=True. Refer to docs for more "
"details on the differences between the two variants."
)
use_reentrant = True
if use_reentrant is None or use_reentrant:
# Transformers>=4.55 + PyTorch>=2.2 require non-reentrant checkpointing on HPU
# Reentrant mode conflicts with DDP on HPU (duplicate backward hooks)
use_reentrant = False
if not hasattr(checkpoint, "_warned_once"):
checkpoint._warned_once = True
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
warnings.warn(
"Reentrant gradient checkpointing has been disabled (use_reentrant=False) "
"because it conflicts with DDP and HPU graphs on Gaudi. "
"This avoids duplicated backward hooks and ensures stable training on HPU.",
UserWarning,
)

# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop("preserve_rng_state", True)
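
Illustrative sketch (not part of the diff): non-reentrant activation checkpointing, the mode this patch now forces, shown with the stock torch.utils.checkpoint API; on Gaudi the same call goes through this patched wrapper:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# use_reentrant=False recomputes activations during backward without autograd's
# reentrant machinery, which is what the warning above credits with avoiding
# duplicated backward hooks under DDP.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)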