The model to consider.
https://huggingface.co/xai-org/grok-1
With int8 quantization, this model can be hosted on 8 GPUs with 80GB memory each, i.e. one node of H100 or A100. After a high-level look at the code, I see that xAI implemented the model architecture in JAX, and the code couples the architecture with implementation details such as int8 quantization and sharding across GPUs.
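As a rough sanity check on the 8x80GB claim, here is a back-of-envelope memory estimate (a sketch only: it assumes the published ~314B parameter count and weight-only int8; activation and KV-cache overhead are not modeled):

```python
# Back-of-envelope memory estimate for hosting Grok-1 in int8 on one node.
# Assumes ~314B total parameters and 8 x 80 GiB GPUs (H100/A100).
num_params = 314e9          # total parameters, dense layers + MoE experts
bytes_per_param = 1         # weight-only int8 quantization
GiB = 1024 ** 3

weight_mem_gib = num_params * bytes_per_param / GiB   # ~292 GiB of weights
total_mem_gib = 8 * 80                                 # 640 GiB aggregate GPU memory

# Whatever remains after weights must hold activations and the KV cache.
headroom_gib = total_mem_gib - weight_mem_gib
print(f"weights ~= {weight_mem_gib:.0f} GiB, headroom ~= {headroom_gib:.0f} GiB across 8 GPUs")
```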
I saw a Twitter post about the tricky differences between Gemma's implementations. So I wonder whether someone familiar with JAX is planning to port Grok-1 to PyTorch and validate it, so that it can be integrated into vLLM with additional optimizations for the MoE architecture.
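A large part of such a port is mechanical weight conversion (the forward pass still has to be re-implemented and validated separately). A minimal sketch of the conversion step, assuming a generic JAX parameter pytree rather than the actual Grok-1 checkpoint layout, could look like this:

```python
# Minimal sketch: flatten a JAX parameter pytree into a PyTorch state dict.
# The key naming produced by keystr() is a placeholder; a real port would map
# names onto the target PyTorch module structure.
import numpy as np
import torch
import jax

def jax_tree_to_torch_state_dict(params) -> dict:
    """Flatten a JAX parameter pytree into a flat dict of torch tensors."""
    flat, _ = jax.tree_util.tree_flatten_with_path(params)
    state_dict = {}
    for path, leaf in flat:
        name = jax.tree_util.keystr(path)              # e.g. "['layers_0']['wq']"
        # Move to host memory, then copy into a torch tensor.
        state_dict[name] = torch.from_numpy(np.asarray(leaf)).clone()
    return state_dict
```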
The closest model vllm already supports.
Mixtral 8x7B.
What's your difficulty of supporting the model you want?
- Its source code is in JAX, not PyTorch.
- It requires quantization; otherwise it won't fit on most GPU setups, including a single node of H100/A100. Here I assume CPU offloading is not under consideration, to avoid a notable impact on efficiency.
- Its MoE component requires additional optimization for inference efficiency (see the sketch after this list).
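To illustrate the last point, below is roughly the naive top-k expert routing a Mixtral-style MoE layer performs per token; for good inference efficiency vLLM would want this dispatch fused into optimized kernels rather than looped over experts. The hidden size, expert count, and top-2 choice follow Mixtral-style conventions and are not confirmed Grok-1 settings.

```python
# Naive top-k MoE routing in PyTorch -- the part that benefits from fused kernels.
import torch
import torch.nn.functional as F

def naive_moe_forward(x, gate, experts, top_k=2):
    """x: [num_tokens, hidden_dim]; gate: nn.Linear -> [num_tokens, num_experts];
    experts: list of per-expert feed-forward modules."""
    logits = gate(x)                                        # router scores per token
    weights, chosen = torch.topk(F.softmax(logits, dim=-1), top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize over top-k
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        token_idx, slot = torch.where(chosen == e)          # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        out[token_idx] += weights[token_idx, slot, None] * expert(x[token_idx])
    return out
```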