
[New Model]: Porting Chatterbox TTS to VLLM #21989

@randombk

Description


The model to consider.

Community/unaffiliated effort to port Chatterbox to vLLM.

The closest model vLLM already supports.

vllm.model_executor.models.llama.LlamaModel

What's your difficulty of supporting the model you want?

Hi folks,

I'm looking for advice on implementing a non-standard model architecture (https://github.com/randombk/chatterbox-vllm). It's a text-to-speech model that applies intermediate-fusion multimodal conditioning to a 0.5B-parameter Llama model to generate speech tokens.

Right now, via a set of horrendous hacks, I have the core running in vLLM for unbatched requests. For batched requests, a few API limitations are causing difficulty. They're solvable with more hacks, but I'd like to check whether there are more idiomatic approaches or alternatives I've missed. I'm also willing to help extend/improve vLLM if folks can point me in the right direction.

The primary relevant file is at https://github.com/randombk/chatterbox-vllm/blob/master/src/chatterbox_vllm/models/t3/t3.py.

For dependency reasons, I'm currently using vLLM 0.9.2.

There are two blockers right now.

1: Multimodal embedding

The primary blocker is the multimodal get_input_embeddings method and how it combines prefill and decode tokens. The model applies different embedding logic for prefill vs. decode, so I'd ideally like to know which tokens are which.

For example, given a batch of three requests:

vLLM schedules the prefill as two calls to get_input_embeddings: one containing just the first query (<prefill 1>), and the other a combination of <decode 1, prefill 2, prefill 3>.

Specifically, the first call is passed:

input_ids:
[695, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 301,  45,
   2, 127,   2,  25,  54,  33,  50,  52,   2,  51,   2,  14,   2, 126,
 115,   2,  58,   2,  42,   2, 279,  21,  48, 114, 165,  37,   2, 296,
 296, 295,   2, 115, 126,  25,   2,  31,  97,  27,  52,   2,  47,   2,
 298, 288, 288, 289,   9,   0,   0]

multimodal_embeddings: [torch.Size([34, 1024])]

The prompt always starts with token 695 and ends with two 0 tokens. The embedding logic for this is separate from the logic during decoding.
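Absent explicit offsets from the scheduler, one stopgap is to recover the spans from the token stream itself. Below is a minimal sketch (plain Python, not a vLLM API; split_prefill_spans is a hypothetical helper) that segments a flattened input_ids tensor into prefill spans using the markers above: each prompt starts with token 695 and ends with two consecutive 0 tokens, and anything outside those spans is a decode token.

```python
import torch

# Markers observed in the batches above (assumption: they are stable).
PROMPT_START = 695
PROMPT_END = (0, 0)

def split_prefill_spans(input_ids: torch.Tensor) -> list[tuple[int, int]]:
    """Return (start, end) index pairs (end exclusive) for each prefill span."""
    ids = input_ids.tolist()
    spans = []
    i = 0
    while i < len(ids):
        if ids[i] == PROMPT_START:
            j = i + 1
            # Scan forward for the double-zero terminator.
            while j + 1 < len(ids) and (ids[j], ids[j + 1]) != PROMPT_END:
                j += 1
            spans.append((i, j + 2))  # include both terminator tokens
            i = j + 2
        else:
            i += 1  # decode token carried over from an earlier request
    return spans
```

On the second call above, this would yield two spans (one per prefill) and leave the leading decode token 2993 outside both. It is fragile by construction (it relies on the markers never appearing inside a prompt), which is why explicit offsets from vLLM would be preferable.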

The second call to get_input_embeddings looks like this:

input_ids:
[2993,  695,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  296,   94,    2,
  54,    2,   14,    2,   66,   29,   59,  186,    2,   29,   31,   65,
  29,   33,    2,   51,    2,   33,  218,    2,   42,    2,   15,   48,
  71,   52,    2,  106,   29,   64,   26,   86,  158,    9,    0,    0,
 695,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,  255,
 255,  255,  255,  255,  255,  255,  255,  255,  277,   27,   17,    2,
  92,   18,    2,   54,    2,   14,    2,   40,   98,   17,    2,   29,
  31,   65,   29,   33,    9,    2,  285,   33,    4,   32,    2,   14,
   2,   15,   60,    2,   25,  222,   44,    2,   40,   43,    2,   42,
   2,   19,   98,   63,    2,  110,    7,    2,  128,    2,  149,    2,
 185,    2,   26,   34,   71,    9,    0,    0]

multimodal_embeddings: [torch.Size([34, 1024]), torch.Size([34, 1024])]

We see that input_ids is a concatenation of the first decoded token from query 1 (2993), the prefill for query 2, and the prefill for query 3. The prefill of each query comprises a 34-token conditioning prefix plus a variable-length sequence of text tokens, both of which are embedded differently from the decoded tokens.

In an ideal world, I would be given offsets/indices describing which input_ids range corresponds to which request and phase. Is this possible?
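In the meantime, assuming the per-request prefill spans can be recovered somehow, the three embedding paths could be merged along these lines. This is a hedged sketch with placeholder names (merge_embeddings, text_embed, and speech_embed are illustrative, not the model's or vLLM's actual API):

```python
import torch

COND_LEN = 34  # conditioning prefix length observed above

def merge_embeddings(
    input_ids: torch.Tensor,   # [total_tokens], flattened across the batch
    cond_embeds: list,         # per-request [34, hidden] conditioning tensors
    text_embed, speech_embed,  # callables: token ids -> [n, hidden]
    prefill_spans,             # [(start, end), ...], one span per request
) -> torch.Tensor:
    hidden = cond_embeds[0].shape[-1]
    out = torch.empty(input_ids.shape[0], hidden)
    in_prefill = torch.zeros(input_ids.shape[0], dtype=torch.bool)
    for (start, end), cond in zip(prefill_spans, cond_embeds):
        # Conditioning prefix uses precomputed multimodal embeddings.
        out[start:start + COND_LEN] = cond
        # Remaining prefill tokens go through the text embedding path.
        out[start + COND_LEN:end] = text_embed(input_ids[start + COND_LEN:end])
        in_prefill[start:end] = True
    # Everything outside a prefill span is a decoded speech token.
    decode_mask = ~in_prefill
    out[decode_mask] = speech_embed(input_ids[decode_mask])
    return out
```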

2: Positional embedding

In addition to the normal Llama positional embedding, the model requires additional learned positional embeddings that count from the first decoded token (not the first token in the prompt). Is there a way to obtain this? I've experimented with ways of passing the prefill length around to apply as an offset to the provided positions vector in the forward pass, but the way I'm doing it right now does not work with batched queries.
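For what it's worth, if a per-token prefill length could be threaded through to forward, the offset arithmetic itself is simple. A sketch (decode_relative_positions is a hypothetical helper, not a vLLM API):

```python
import torch

def decode_relative_positions(
    positions: torch.Tensor,     # absolute positions as passed to forward()
    prefill_lens: torch.Tensor,  # per-token prefill length of the owning request
) -> torch.Tensor:
    # The first decoded token (absolute position == prefill length) gets
    # index 0; prefill tokens clamp to 0 and should be masked out separately
    # before the learned embedding lookup.
    return (positions - prefill_lens).clamp(min=0)
```

The hard part, as noted above, is not the arithmetic but getting a per-token prefill_lens tensor into the forward pass for a mixed prefill/decode batch.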

Note: this is a personal project and not connected to my employer.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
