@@ -230,8 +230,8 @@ class MiniMaxM2ModelImpl : public torch::nn::Module {
230230 const ModelInputParams& input_params) {
231231 ModelInputParams modified_input_params = input_params;
232232 torch::Tensor h;
233- if (input_params.input_embedding .defined ()) {
234- h = input_params.input_embedding ;
233+ if (input_params.embedding . input_embedding .defined ()) {
234+ h = input_params.embedding . input_embedding ;
235235 } else if (tokens.numel () == 0 ) {
236236 h = torch::empty ({0 , hidden_size_}, embed_tokens_->weight ().options ());
237237 } else {
@@ -289,22 +289,22 @@ class MiniMaxM2ModelImpl : public torch::nn::Module {
289289 layer::AttentionMetadata get_attention_metadata (
290290 const ModelInputParams& params,
291291 const torch::Tensor& h) {
292- if (params.q_max_seq_len == 0 ) {
292+ if (params.meta . q_max_seq_len == 0 ) {
293293 return layer::AttentionMetadataBuilder::build (params, enable_mla_);
294294 }
295295
296- max_seq_len_ = std::max (params.kv_max_seq_len , max_seq_len_);
296+ max_seq_len_ = std::max (params.meta . kv_max_seq_len , max_seq_len_);
297297 torch::Tensor attn_mask;
298298 if (FLAGS_enable_chunked_prefill) {
299- const int32_t max_kv_seq = params.kv_max_seq_len ;
300- const int32_t num_sequences = params.num_sequences ;
299+ const int32_t max_kv_seq = params.meta . kv_max_seq_len ;
300+ const int32_t num_sequences = params.meta . num_sequences ;
301301 if (num_sequences > 0 ) {
302302 std::vector<torch::Tensor> req_mask_vec;
303303 req_mask_vec.reserve (num_sequences);
304304 for (int32_t j = 0 ; j < num_sequences; ++j) {
305305 req_mask_vec.emplace_back (
306- attn_mask_.gen_append_mask (params.q_seq_lens_vec [j],
307- params.kv_seq_lens_vec [j],
306+ attn_mask_.gen_append_mask (params.attention . host . q_seq_lens [j],
307+ params.attention . host . kv_seq_lens [j],
308308 max_kv_seq,
309309 h.dtype ().toScalarType (),
310310 h.device ()));
0 commit comments