|
16 | 16 |
|
17 | 17 | import math |
18 | 18 | from dataclasses import dataclass |
| 19 | +from functools import lru_cache |
19 | 20 | from typing import Optional, Union |
20 | 21 |
|
21 | 22 | import torch |
22 | 23 | import torch.nn as nn |
23 | 24 | import torch.nn.functional as F |
24 | 25 |
|
25 | 26 | from ... import initialization as init |
| 27 | +from ...audio_utils import conv1d_output_length |
26 | 28 | from ...modeling_utils import PreTrainedAudioTokenizerBase |
27 | 29 | from ...utils import ModelOutput, auto_docstring |
28 | 30 | from ..auto import AutoModel |
@@ -396,6 +398,40 @@ def remove_weight_norm(self): |
396 | 398 | if hasattr(m, "parametrizations") and "weight" in m.parametrizations: |
397 | 399 | torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) |
398 | 400 |
|
| 401 | + @lru_cache |
| 402 | + def _get_conv1d_layers(self, module): |
| 403 | + """ |
| 404 | + Recursively collect all Conv1d layers of a module.
| 405 | + """ |
| 406 | + |
| 407 | + def get_conv1d_layers_recursive(module: nn.Module): |
| 408 | + params_list = [] |
| 409 | + |
| 410 | + if isinstance(module, nn.Conv1d): |
| 411 | + params_list.append(module) |
| 412 | + |
| 413 | + # Recursively check all child modules |
| 414 | + for child in module.children(): |
| 415 | + params_list.extend(get_conv1d_layers_recursive(child)) |
| 416 | + |
| 417 | + return params_list |
| 418 | + |
| 419 | + return tuple(get_conv1d_layers_recursive(module)) |
| 420 | + |
| 421 | + def _get_conv1d_output_lengths(self, input_length, module=None): |
| 422 | + """ |
| 423 | + For a given module, compute the output length that would be obtained after all Conv1d layers. |
| 424 | + """ |
| 425 | + if module is None: |
| 426 | + module = self |
| 427 | + |
| 428 | + conv1d_layers = self._get_conv1d_layers(module) |
| 429 | + |
| 430 | + for layer in conv1d_layers: |
| 431 | + input_length = conv1d_output_length(layer, input_length) |
| 432 | + |
| 433 | + return input_length |
| 434 | + |
399 | 435 |
|
400 | 436 | @auto_docstring(custom_intro="""The Xcodec neural audio codec model.""") |
401 | 437 | class XcodecModel(XcodecPreTrainedModel): |
@@ -476,11 +512,13 @@ def encode( |
476 | 512 |
|
477 | 513 | e_semantic_input = self._extract_semantic_features(input_values).detach() |
478 | 514 | e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) |
479 | | - e_acoustic = self.acoustic_encoder(input_values) |
480 | 515 |
|
481 | | - if e_acoustic.shape[2] != e_semantic.shape[2]: |
482 | | - # make sure they line up if frames don't match |
483 | | - e_acoustic = self.acoustic_encoder(F.pad(input_values[:, 0, :], (self.pad, self.pad)).unsqueeze(1)) |
| 516 | + # the original codebase runs an extra forward pass to get the output length, but we can
| 517 | + # infer it directly from the model and know whether we should pad
| 518 | + if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]: |
| 519 | + e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad))) |
| 520 | + else: |
| 521 | + e_acoustic = self.acoustic_encoder(input_values) |
484 | 522 |
|
485 | 523 | embeddings = torch.cat([e_acoustic, e_semantic], dim=1) |
486 | 524 | embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2) |
|
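For reference, here is a minimal sketch of the length arithmetic the new helpers rely on, assuming `conv1d_output_length` applies the standard PyTorch Conv1d output-length formula; the encoder below is illustrative only, not part of this PR:

```python
import torch.nn as nn


def conv1d_output_length_sketch(layer: nn.Conv1d, input_length: int) -> int:
    """Standard PyTorch Conv1d output-length formula (assumed behavior of `conv1d_output_length`)."""
    padding, dilation = layer.padding[0], layer.dilation[0]
    kernel_size, stride = layer.kernel_size[0], layer.stride[0]
    return (input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Compose the formula over every Conv1d in a module, in traversal order,
# mirroring what `_get_conv1d_output_lengths` does for the acoustic encoder.
encoder = nn.Sequential(  # hypothetical stand-in for the acoustic encoder
    nn.Conv1d(1, 16, kernel_size=7, stride=2, padding=3),
    nn.ReLU(),
    nn.Conv1d(16, 32, kernel_size=7, stride=2, padding=3),
)

length = 16000
for m in encoder.modules():
    if isinstance(m, nn.Conv1d):
        length = conv1d_output_length_sketch(m, length)

print(length)  # analytic frame count; no forward pass needed to decide whether to pad
```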