Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xllm/core/layers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ cc_library(
llama_decoder_layer.h
multi_head_attention.h
qwen2_decoder_layer.h
qwen2_vision_encode_layer.h
qwen2dot5_vision_decode_layer.h
qwen3_vision_encode_layer.h
qwen3_decoder_layer.h
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/layers/npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ cc_library(
npu_word_embedding_impl.h
npu_pos_embedding_impl.h
npu_lm_head_impl.h
npu_qwen2_vision_encoder_layer_impl.h
npu_qwen2dot5_vision_encoder_layer_impl.h
npu_qwen3_vision_encoder_layer_impl.h
npu_qwen3_moe_decoder_layer_impl.h
Expand All @@ -29,6 +30,7 @@ cc_library(
npu_word_embedding_impl.cpp
npu_pos_embedding_impl.cpp
npu_lm_head_impl.cpp
npu_qwen2_vision_encoder_layer_impl.cpp
npu_qwen2dot5_vision_encoder_layer_impl.cpp
npu_qwen3_vision_encoder_layer_impl.cpp
npu_qwen3_moe_decoder_layer_impl.cpp
Expand Down
285 changes: 285 additions & 0 deletions xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "npu_qwen2_vision_encoder_layer_impl.h"

#include <glog/logging.h>
#include <mstx/ms_tools_ext.h>

#include <iostream>
#include <map>

#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "xllm_kernels/models/qwen3_vl/qwen3_vl_encoder.h"

namespace xllm {
namespace layer {

enum VisionEncoderLayerTensorId : int {
IN_INPUT_NORM_WEIGHT = 0,
IN_INPUT_NORM_BIAS,
IN_POST_NORM_WEIGHT,
IN_POST_NORM_BIAS,
IN_QKV_WEIGHT,
IN_QKV_BIAS,
IN_WATTENTION_OUT_WEIGHT,
IN_WATTENTION_OUT_BIAS,
IN_LINEAR_FC1_WEIGHT,
IN_LINEAR_FC1_BIAS,
IN_LINEAR_FC2_WEIGHT,
IN_LINEAR_FC2_BIAS,
IN_VISION_Q_WEIGHT,
IN_VISION_Q_BIAS,
IN_VISION_K_WEIGHT,
IN_VISION_K_BIAS,
IN_VISION_V_WEIGHT,
IN_VISION_V_BIAS
};

const uint64_t WEIGHT_COUNT_PER_LAYER = 18;

static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
{IN_INPUT_NORM_WEIGHT, "norm1.weight"},
{IN_INPUT_NORM_BIAS, "norm1.bias"},
{IN_POST_NORM_WEIGHT, "norm2.weight"},
{IN_POST_NORM_BIAS, "norm2.bias"},
{IN_QKV_WEIGHT, "attn.qkv.weight"},
{IN_QKV_BIAS, "attn.qkv.bias"},
{IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"},
{IN_WATTENTION_OUT_BIAS, "attn.proj.bias"},
{IN_LINEAR_FC1_WEIGHT, "mlp.fc1.weight"},
{IN_LINEAR_FC1_BIAS, "mlp.fc1.bias"},
{IN_LINEAR_FC2_WEIGHT, "mlp.fc2.weight"},
{IN_LINEAR_FC2_BIAS, "mlp.fc2.bias"}};

// {weight,dim}
static std::map<int, int> WEIGHT_SHARD = {
{IN_WATTENTION_OUT_WEIGHT, 1},
{IN_LINEAR_FC1_WEIGHT, 0},
{IN_LINEAR_FC1_BIAS, 0},
{IN_LINEAR_FC2_WEIGHT, 1},
};

void NpuQwen2VisionEncoderLayerImpl::param_from_args(
atb_speed::qwen::VisionEncoderLayerParam& param,
const ModelArgs& args,
const ParallelArgs& parallel_args) {
param.isBF16 = args.dtype() == "bfloat16";
param.rmsNormEps = args.rms_norm_eps();
param.worldSize = parallel_args.world_size();
param.numAttentionHeadsPerRank =
args.mm_num_attention_heads() / param.worldSize;
param.hiddenSizePerAttentionHead =
args.mm_hidden_size() / args.mm_num_attention_heads();
std::optional<long int> optionalValue = args.mm_num_attention_heads();
param.numKeyValueHeadsPerRank =
static_cast<int>(optionalValue.value()) / param.worldSize;
param.rank = parallel_args.rank();
param.backend = "lccl";
param.enableLogN = false;
}

NpuQwen2VisionEncoderLayerImpl::NpuQwen2VisionEncoderLayerImpl(
const ModelContext& context)
: NpuBaseLayer(context) {
auto model_args = context.get_model_args();
auto parallel_args = context.get_parallel_args();
auto options = context.get_tensor_options();
param_from_args(encode_param_, model_args, parallel_args);
at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
dtype_ = c10::typeMetaToScalarType(options.dtype());
device_id_ = options.device().index();
placeholder_ = atb_speed::Utils::AtTensor2Tensor(
torch::zeros({1}).to(device_).to(dtype_));
at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
at_weight_tensors_[i] = torch::zeros({1}).to(options);
}
}

void NpuQwen2VisionEncoderLayerImpl::verify_loaded_weights() const {
for (const auto& [index, name] : WEIGHT_MAPPING) {
CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
<< "weight is not loaded for " << name;
}
}

void NpuQwen2VisionEncoderLayerImpl::merge_loaded_weights() {
// spilt pack qkv weight when enable tp
get_weights_col_packed_qkv();
if (encode_param_.worldSize > 1) {
// merge qkv weight
auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT],
at_weight_tensors_[IN_VISION_K_WEIGHT],
at_weight_tensors_[IN_VISION_V_WEIGHT]},
0);
at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_);

// merge qkv bias
auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS],
at_weight_tensors_[IN_VISION_K_BIAS],
at_weight_tensors_[IN_VISION_V_BIAS]},
0);
at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
}
c10_npu::NPUCachingAllocator::emptyCache();
for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
atb_weight_tensors_[i] =
atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]);
}

init_layer();
}
// tp spilt weight
void NpuQwen2VisionEncoderLayerImpl::get_weights_col_packed_qkv() {
int rank = encode_param_.rank;
int worldSize = encode_param_.worldSize;
// split qkv weight
qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
// weight
at_weight_tensors_[IN_VISION_Q_WEIGHT] =
(qkv_weight[0].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_K_WEIGHT] =
(qkv_weight[1].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_V_WEIGHT] =
(qkv_weight[2].chunk(worldSize, 0))[rank];
// bias
at_weight_tensors_[IN_VISION_Q_BIAS] =
(qkv_bias[0].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_K_BIAS] =
(qkv_bias[1].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_V_BIAS] =
(qkv_bias[2].chunk(worldSize, 0))[rank];
}

void NpuQwen2VisionEncoderLayerImpl::load_state_dict(
const StateDict& state_dict) {
for (const auto& [index, name] : WEIGHT_MAPPING) {
if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
} else {
set_weight(state_dict, name, index);
}
}
}

int64_t NpuQwen2VisionEncoderLayerImpl::init_layer() {
name_ = "qwen2_encoder_layer";
model_name_ = "qwen2_vl";
CHECK_OPERATION_STATUS_RETURN(init_node(encode_node_, encode_param_));
return atb::NO_ERROR;
}

int64_t NpuQwen2VisionEncoderLayerImpl::init_node(
atb_speed::Model::Node& node,
atb_speed::qwen::VisionEncoderLayerParam& param) {
atb::Operation* operation = nullptr;
atb_speed::qwen::Qwen3VL_EncoderLayer(param, &operation);
node.operation.reset(operation);
if (node.operation == nullptr) {
LOG(ERROR) << "node.operation is null";
return -1;
}
if (node.operation->GetInputNum() < 1) {
LOG(ERROR) << "Can not resize number which is smaller than 1";
return -1;
}
node.inTensors.resize(node.operation->GetInputNum());
node.outTensors.resize(1);
size_t inTensorId = 1;

for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER;
++weightTensorId) {
node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId];
}

node.variantPack.inTensors.reserve(node.inTensors.size());
node.variantPack.inTensors.resize(node.inTensors.size());
node.variantPack.outTensors.reserve(1);
node.variantPack.outTensors.resize(1);
return atb::NO_ERROR;
}

torch::Tensor NpuQwen2VisionEncoderLayerImpl::forward(
torch::Tensor& x,
torch::Tensor& cos_pos,
torch::Tensor& sin_pos,
torch::Tensor& cu_seqlen,
std::vector<int>& cu_seqlen_vec,
ModelInputParams& input_params,
int node_id,
aclrtEvent* event,
std::atomic<bool>* event_flag) {
atb::Status st;

build_node_variant_pack(encode_node_,
x,
cos_pos,
sin_pos,
cu_seqlen,
cu_seqlen_vec,
input_params,
true);
// mstxRangeEnd(id);
st = execute_node(encode_node_, node_id);
LOG_IF(FATAL, st != 0) << model_name_
<< "excute encode layer fail, error code: " << st;
return x;
}

void NpuQwen2VisionEncoderLayerImpl::build_node_variant_pack(
atb_speed::Model::Node& node,
torch::Tensor& x,
torch::Tensor& cos_pos,
torch::Tensor& sin_pos,
torch::Tensor& cu_seqlen,
std::vector<int>& cu_seqlen_vec,
ModelInputParams& input_params,
bool is_prefill) {
internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);

node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
atb_speed::Utils::AtTensor2Tensor(cos_pos);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) =
atb_speed::Utils::AtTensor2Tensor(sin_pos);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) =
atb_speed::Utils::AtTensor2Tensor(cu_seqlen);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3).hostData =
cu_seqlen_vec.data();

for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
CHECK_THROW(node.inTensors.at(i) == nullptr,
model_name_ << "inTensor " << i << "is NULL");
node.variantPack.inTensors.at(i) = *node.inTensors.at(i);
// LOG(INFO) << model_name_ << "inTensors[" << i << "]:"
// << atb_speed::TensorUtil::TensorToString(
// node.variantPack.inTensors.at(i));
}

node.variantPack.outTensors.at(0) = internal_tensors_;
}

} // namespace layer
} // namespace xllm
Loading
Loading