Skip to content

Commit 6d14210

Browse files
authored
refactor: adapt minimax-m2 model to restructured ModelInputParams fields. (#1505)
1 parent 84a92d6 commit 6d14210

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

xllm/models/llm/npu/minimax_m2.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ class MiniMaxM2ModelImpl : public torch::nn::Module {
230230
const ModelInputParams& input_params) {
231231
ModelInputParams modified_input_params = input_params;
232232
torch::Tensor h;
233-
if (input_params.input_embedding.defined()) {
234-
h = input_params.input_embedding;
233+
if (input_params.embedding.input_embedding.defined()) {
234+
h = input_params.embedding.input_embedding;
235235
} else if (tokens.numel() == 0) {
236236
h = torch::empty({0, hidden_size_}, embed_tokens_->weight().options());
237237
} else {
@@ -289,22 +289,22 @@ class MiniMaxM2ModelImpl : public torch::nn::Module {
289289
layer::AttentionMetadata get_attention_metadata(
290290
const ModelInputParams& params,
291291
const torch::Tensor& h) {
292-
if (params.q_max_seq_len == 0) {
292+
if (params.meta.q_max_seq_len == 0) {
293293
return layer::AttentionMetadataBuilder::build(params, enable_mla_);
294294
}
295295

296-
max_seq_len_ = std::max(params.kv_max_seq_len, max_seq_len_);
296+
max_seq_len_ = std::max(params.meta.kv_max_seq_len, max_seq_len_);
297297
torch::Tensor attn_mask;
298298
if (FLAGS_enable_chunked_prefill) {
299-
const int32_t max_kv_seq = params.kv_max_seq_len;
300-
const int32_t num_sequences = params.num_sequences;
299+
const int32_t max_kv_seq = params.meta.kv_max_seq_len;
300+
const int32_t num_sequences = params.meta.num_sequences;
301301
if (num_sequences > 0) {
302302
std::vector<torch::Tensor> req_mask_vec;
303303
req_mask_vec.reserve(num_sequences);
304304
for (int32_t j = 0; j < num_sequences; ++j) {
305305
req_mask_vec.emplace_back(
306-
attn_mask_.gen_append_mask(params.q_seq_lens_vec[j],
307-
params.kv_seq_lens_vec[j],
306+
attn_mask_.gen_append_mask(params.attention.host.q_seq_lens[j],
307+
params.attention.host.kv_seq_lens[j],
308308
max_kv_seq,
309309
h.dtype().toScalarType(),
310310
h.device()));

0 commit comments

Comments
 (0)