@@ -2568,22 +2568,22 @@ def trtllm_mxint4_block_scale_moe(
     Args:
         routing_logits (torch.Tensor): shape [seq_len, num_experts]
             Input tensor of routing logits. Supports float32, bfloat16.
-        hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
-            Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8)
+        hidden_states (torch.Tensor): shape [seq_len, hidden_size]
+            Tensor of input hidden states. Supports bfloat16.
         gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
-            Tensor of FC1 weights. Dtype must be uint8 (packed fp4)
-        gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)]
-            Scale tensor of FC1 weights. Dtype must be float8.
+            Tensor of FC1 weights. Dtype must be uint8 (packed mxint4)
+        gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 32]
+            Scale tensor of FC1 weights. Dtype must be bfloat16.
         gemm1_alpha (Optional[torch.Tensor]): shape [num_experts]
             Tensor of swiglu alpha. Dtype is float32.
         gemm1_beta (Optional[torch.Tensor]): shape [num_experts]
             Tensor of swiglu beta. Dtype is float32.
         gemm1_clamp_limit (Optional[torch.Tensor]): shape [num_experts]
             Tensor of swiglu clamp limit. Dtype is float32.
         gemm2_weights (torch.Tensor): shape [num_experts, hidden_size, intermediate_size]
-            Tensor of FC2 weights. Dtype must be uint8 (packed fp4)
-        gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)]
-            Scale tensor of FC2 weights. Dtype must be float8.
+            Tensor of FC2 weights. Dtype must be uint8 (packed mxint4)
+        gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size, intermediate_size // 32]
+            Scale tensor of FC2 weights. Dtype must be bfloat16.
         num_experts (int): Total number of experts
         top_k (int): Number of experts to route to per token
         n_group (Optional[int]): Number of expert groups (can be None for some routing methods)
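
For orientation, a minimal sketch of building inputs that match the shapes and dtypes this docstring now documents. All sizes are hypothetical, a CUDA device is assumed, and the call to `trtllm_mxint4_block_scale_moe` itself is elided because the rest of the signature falls outside this hunk:

```python
import torch

# Hypothetical sizes, chosen only for illustration.
seq_len, hidden_size, intermediate_size = 4, 1024, 2048
num_experts = 8

routing_logits = torch.randn(seq_len, num_experts, dtype=torch.bfloat16, device="cuda")
# After this change, hidden states are plain bfloat16 (no fp4/mxfp8 packing).
hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16, device="cuda")

# mxint4 packs two 4-bit values per uint8 byte, hence the trailing hidden_size // 2.
gemm1_weights = torch.randint(
    0, 256, (num_experts, 2 * intermediate_size, hidden_size // 2),
    dtype=torch.uint8, device="cuda",
)
# One bfloat16 scale per 32-element block.
gemm1_weights_scale = torch.ones(
    num_experts, 2 * intermediate_size, hidden_size // 32,
    dtype=torch.bfloat16, device="cuda",
)
gemm2_weights = torch.randint(
    0, 256, (num_experts, hidden_size, intermediate_size),  # shape as documented above
    dtype=torch.uint8, device="cuda",
)
gemm2_weights_scale = torch.ones(
    num_experts, hidden_size, intermediate_size // 32,
    dtype=torch.bfloat16, device="cuda",
)

# Optional swiglu parameters: one float32 scalar per expert (placeholder values).
gemm1_alpha = torch.ones(num_experts, dtype=torch.float32, device="cuda")
gemm1_beta = torch.zeros(num_experts, dtype=torch.float32, device="cuda")
gemm1_clamp_limit = torch.full((num_experts,), 7.0, dtype=torch.float32, device="cuda")
```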