
[New Model]: Request to support xai-org/grok-1 (314B parameters with MOE architecture) #3472

@ai-jz

Description

The model to consider.

https://huggingface.co/xai-org/grok-1

With int8 quantization, this model can be hosted on 8 GPUs with 80 GB of memory each, i.e. a single H100 or A100 node. After a high-level look at the code, I see that xAI implemented the model architecture in JAX, and the code couples the architecture with implementation details such as int8 quantization and sharding across GPUs.
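
A rough back-of-envelope check of that sizing (the 314B parameter count is from the model card; the overhead reasoning is my own assumption):

```python
# Rough memory estimate for hosting Grok-1 with int8 weights on one 8x80GB node.
params = 314e9                       # parameter count from the model card
weight_gb_int8 = params * 1 / 1e9    # ~314 GB of int8 weights
weight_gb_fp16 = params * 2 / 1e9    # ~628 GB of fp16 weights
node_gb = 8 * 80                     # 640 GB total across 8 GPUs

print(f"int8 weights: {weight_gb_int8:.0f} GB, fp16 weights: {weight_gb_fp16:.0f} GB, node: {node_gb} GB")
# int8 leaves roughly 326 GB for activations, KV cache, and runtime overhead,
# whereas fp16 would leave almost no headroom on the same node.
```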

I saw a Twitter post about tricky differences between the various Gemma implementations. So I wonder whether someone familiar with JAX is planning to port Grok-1 to PyTorch and validate it, so that it can be integrated into vLLM with additional optimization for the MoE architecture.
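
A port would roughly amount to materializing each JAX weight array and re-registering it under the corresponding PyTorch parameter name. A minimal sketch, assuming the checkpoint can be loaded as a flat dict of NumPy arrays and that `name_map` (a hypothetical mapping I introduce here) translates JAX parameter paths to PyTorch state-dict keys:

```python
import numpy as np
import torch

def jax_params_to_torch_state_dict(jax_params: dict, name_map: dict) -> dict:
    """Convert {jax_param_path: np.ndarray} into a PyTorch state dict.

    `name_map` is a hypothetical mapping from JAX parameter paths to the
    parameter names used by a PyTorch/vLLM-style implementation.
    """
    state_dict = {}
    for jax_name, array in jax_params.items():
        torch_name = name_map[jax_name]
        tensor = torch.from_numpy(np.asarray(array))
        # Assumption: dense kernels are stored as (in, out) in JAX, while
        # torch.nn.Linear expects (out, in); this rule may need per-layer adjustment.
        if torch_name.endswith("weight") and tensor.ndim == 2:
            tensor = tensor.t().contiguous()
        state_dict[torch_name] = tensor
    return state_dict
```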

The closest model vllm already supports.

Mixtral 8x7B.
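
For reference, Mixtral is already served through vLLM's standard API, and a Grok-1 port would presumably be loaded the same way once a PyTorch implementation and config exist:

```python
from vllm import LLM, SamplingParams

# Mixtral 8x7B is the closest MoE model vLLM supports today; tensor_parallel_size
# spreads the experts and attention weights across the GPUs of one node.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=8)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```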

What's your difficulty of supporting the model you want?

  • Its source code is in JAX instead of PyTorch.
  • It requires quantization; otherwise it won't fit on most GPUs, including H100/A100. Here I assume CPU offloading is not under consideration, to avoid a notable impact on efficiency.
  • Its MoE component requires additional optimization for inference efficiency (see the sketch after this list).
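
To make the last point concrete, here is a deliberately naive top-2 MoE forward pass in PyTorch (all names are illustrative, not Grok-1's actual code); the per-token, per-expert dispatch loop below is exactly what fused/grouped-GEMM MoE kernels, like those used for Mixtral, are needed to replace:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoE(nn.Module):
    """Naive top-2 mixture-of-experts layer; illustrative only."""

    def __init__(self, hidden: int, ffn: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, ffn), nn.GELU(), nn.Linear(ffn, hidden))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (tokens, hidden)
        scores = F.softmax(self.gate(x), dim=-1)          # (tokens, num_experts)
        weights, idx = scores.topk(self.top_k, dim=-1)    # route each token to its top-k experts
        weights = weights / weights.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(x)
        # Naive dispatch: one pass per expert with scattered gathers; efficient
        # implementations fuse this into grouped matrix multiplications.
        for e, expert in enumerate(self.experts):
            token_ids, slot = (idx == e).nonzero(as_tuple=True)
            if token_ids.numel():
                out[token_ids] += weights[token_ids, slot].unsqueeze(-1) * expert(x[token_ids])
        return out
```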

Labels: new-model (Requests to new models), stale (Over 90 days of inactivity)