Skip to content

Commit f13b100

Browse files
authored
Xcodec fix (#42095)
* nit on dac! * fix * not for this pr * make style
1 parent 45d273d commit f13b100

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

src/transformers/audio_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,18 @@ def load_audio_as(
219219
raise ValueError(f"Error loading audio: {e}")
220220

221221

222+
def conv1d_output_length(module: "torch.nn.Conv1d", input_length: int) -> int:
    """
    Computes the output length of a 1D convolution layer according to torch's documentation:
    https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

    Args:
        module (`torch.nn.Conv1d`):
            Convolution whose `padding`, `dilation`, `kernel_size` and `stride` attributes are read.
        input_length (`int`):
            Length of the input along the convolved dimension.

    Returns:
        `int`: `floor((input_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1`,
        or `input_length` unchanged when the layer uses `padding="same"`.
    """
    padding = module.padding
    if isinstance(padding, str):
        # Conv1d also accepts the string modes "same" and "valid". "same" preserves the length
        # (torch only allows it with stride 1); "valid" means no padding at all.
        if padding == "same":
            return int(input_length)
        total_padding = 0
    else:
        total_padding = 2 * padding[0]

    # Number of input positions covered by one (possibly dilated) kernel application.
    kernel_span = module.dilation[0] * (module.kernel_size[0] - 1) + 1

    # Floor division keeps the arithmetic exact for integers and matches the documented floor,
    # instead of going through a float and truncating towards zero.
    return int((input_length + total_padding - kernel_span) // module.stride[0] + 1)
232+
233+
222234
def is_valid_audio(audio):
    """
    Tell whether `audio` is accepted as audio data, i.e. whether `is_numpy_array` or
    `is_torch_tensor` reports a match for it.
    """
    numpy_match = is_numpy_array(audio)
    # Only fall through to the torch check when the numpy check did not match.
    return numpy_match if numpy_match else is_torch_tensor(audio)
224236

src/transformers/models/xcodec/modeling_xcodec.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616

1717
import math
1818
from dataclasses import dataclass
19+
from functools import lru_cache
1920
from typing import Optional, Union
2021

2122
import torch
2223
import torch.nn as nn
2324
import torch.nn.functional as F
2425

2526
from ... import initialization as init
27+
from ...audio_utils import conv1d_output_length
2628
from ...modeling_utils import PreTrainedAudioTokenizerBase
2729
from ...utils import ModelOutput, auto_docstring
2830
from ..auto import AutoModel
@@ -396,6 +398,40 @@ def remove_weight_norm(self):
396398
if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
397399
torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
398400

401+
@lru_cache
402+
def _get_conv1d_layers(self, module):
403+
"""
404+
Recursively iterate to fetch all Conv1d layers.
405+
"""
406+
407+
def get_conv1d_layers_recursive(module: nn.Module):
408+
params_list = []
409+
410+
if isinstance(module, nn.Conv1d):
411+
params_list.append(module)
412+
413+
# Recursively check all child modules
414+
for child in module.children():
415+
params_list.extend(get_conv1d_layers_recursive(child))
416+
417+
return params_list
418+
419+
return tuple(get_conv1d_layers_recursive(module))
420+
421+
def _get_conv1d_output_lengths(self, input_length, module=None):
    """
    Propagate `input_length` through every Conv1d layer found under `module` and return the
    resulting length. When `module` is omitted, the layers are collected from `self`.
    """
    root = self if module is None else module

    # Feed each layer's output length into the next one, in the order the layers were collected.
    length = input_length
    for conv in self._get_conv1d_layers(root):
        length = conv1d_output_length(conv, length)
    return length
434+
399435

400436
@auto_docstring(custom_intro="""The Xcodec neural audio codec model.""")
401437
class XcodecModel(XcodecPreTrainedModel):
@@ -476,11 +512,13 @@ def encode(
476512

477513
e_semantic_input = self._extract_semantic_features(input_values).detach()
478514
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
479-
e_acoustic = self.acoustic_encoder(input_values)
480515

481-
if e_acoustic.shape[2] != e_semantic.shape[2]:
482-
# make sure they line up if frames don't match
483-
e_acoustic = self.acoustic_encoder(F.pad(input_values[:, 0, :], (self.pad, self.pad)).unsqueeze(1))
516+
# the original codebase runs inference to get the output length, but we can directly infer it
517+
# from the model and know whether we should pad
518+
if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]:
519+
e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad)))
520+
else:
521+
e_acoustic = self.acoustic_encoder(input_values)
484522

485523
embeddings = torch.cat([e_acoustic, e_semantic], dim=1)
486524
embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2)

0 commit comments

Comments
 (0)