Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,18 @@ def forward(
def bad_word_processor(self, logits: torch.Tensor) -> torch.Tensor:
# suppress token IDs unsupported by token2wav
if self.suppress_start_id and self.suppress_start_id < logits.size(-1):
logits[..., self.suppress_start_id : logits.size(-1)] = -1e9
# skip the end token id.
if hasattr(self.config, "tts_codec_end_token_id"):
end_id = int(getattr(self.config, "tts_codec_end_token_id"))
if self.suppress_start_id == end_id:
logits[..., end_id + 1 : logits.size(-1)] = -1e9
elif self.suppress_start_id < end_id:
logits[..., self.suppress_start_id : end_id] = -1e9
logits[..., end_id + 1 : logits.size(-1)] = -1e9
else:
logits[..., self.suppress_start_id : logits.size(-1)] = -1e9
else:
raise ValueError("config must have tts_codec_end_token_id attribute")

if hasattr(self.config, "tts_codec_start_token_id"):
bos_id = int(getattr(self.config, "tts_codec_start_token_id"))
Expand Down