diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 6d56faccd..4668991f3 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -59,6 +59,7 @@ cc_library( llama_decoder_layer.h multi_head_attention.h qwen2_decoder_layer.h + qwen2_vision_encode_layer.h qwen2dot5_vision_decode_layer.h qwen3_vision_encode_layer.h qwen3_decoder_layer.h diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt index 61f7759d9..8d85994c4 100644 --- a/xllm/core/layers/npu/CMakeLists.txt +++ b/xllm/core/layers/npu/CMakeLists.txt @@ -9,6 +9,7 @@ cc_library( npu_word_embedding_impl.h npu_pos_embedding_impl.h npu_lm_head_impl.h + npu_qwen2_vision_encoder_layer_impl.h npu_qwen2dot5_vision_encoder_layer_impl.h npu_qwen3_vision_encoder_layer_impl.h npu_qwen3_moe_decoder_layer_impl.h @@ -29,6 +30,7 @@ cc_library( npu_word_embedding_impl.cpp npu_pos_embedding_impl.cpp npu_lm_head_impl.cpp + npu_qwen2_vision_encoder_layer_impl.cpp npu_qwen2dot5_vision_encoder_layer_impl.cpp npu_qwen3_vision_encoder_layer_impl.cpp npu_qwen3_moe_decoder_layer_impl.cpp diff --git a/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp new file mode 100644 index 000000000..e29781413 --- /dev/null +++ b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp @@ -0,0 +1,285 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "npu_qwen2_vision_encoder_layer_impl.h"
+
+#include
+#include
+
+#include
+#include
+
+#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
+#include "torch_npu/csrc/core/npu/NPUException.h"
+#include "xllm_kernels/models/qwen3_vl/qwen3_vl_encoder.h"
+
+namespace xllm {
+namespace layer {
+
+enum VisionEncoderLayerTensorId : int {
+  IN_INPUT_NORM_WEIGHT = 0,
+  IN_INPUT_NORM_BIAS,
+  IN_POST_NORM_WEIGHT,
+  IN_POST_NORM_BIAS,
+  IN_QKV_WEIGHT,
+  IN_QKV_BIAS,
+  IN_WATTENTION_OUT_WEIGHT,
+  IN_WATTENTION_OUT_BIAS,
+  IN_LINEAR_FC1_WEIGHT,
+  IN_LINEAR_FC1_BIAS,
+  IN_LINEAR_FC2_WEIGHT,
+  IN_LINEAR_FC2_BIAS,
+  IN_VISION_Q_WEIGHT,
+  IN_VISION_Q_BIAS,
+  IN_VISION_K_WEIGHT,
+  IN_VISION_K_BIAS,
+  IN_VISION_V_WEIGHT,
+  IN_VISION_V_BIAS
+};
+
+const uint64_t WEIGHT_COUNT_PER_LAYER = 18;
+
+static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
+    {IN_INPUT_NORM_WEIGHT, "norm1.weight"},
+    {IN_INPUT_NORM_BIAS, "norm1.bias"},
+    {IN_POST_NORM_WEIGHT, "norm2.weight"},
+    {IN_POST_NORM_BIAS, "norm2.bias"},
+    {IN_QKV_WEIGHT, "attn.qkv.weight"},
+    {IN_QKV_BIAS, "attn.qkv.bias"},
+    {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"},
+    {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"},
+    {IN_LINEAR_FC1_WEIGHT, "mlp.fc1.weight"},
+    {IN_LINEAR_FC1_BIAS, "mlp.fc1.bias"},
+    {IN_LINEAR_FC2_WEIGHT, "mlp.fc2.weight"},
+    {IN_LINEAR_FC2_BIAS, "mlp.fc2.bias"}};
+
+// {weight, shard dim}
+static std::map<int, int> WEIGHT_SHARD = {
+    {IN_WATTENTION_OUT_WEIGHT, 1},
+    {IN_LINEAR_FC1_WEIGHT, 0},
+    {IN_LINEAR_FC1_BIAS, 0},
+    {IN_LINEAR_FC2_WEIGHT, 1},
+};
+
+void NpuQwen2VisionEncoderLayerImpl::param_from_args(
+    atb_speed::qwen::VisionEncoderLayerParam& param,
+    const ModelArgs& args,
+    const ParallelArgs& parallel_args) {
+  param.isBF16 = args.dtype() == "bfloat16";
+  param.rmsNormEps = args.rms_norm_eps();
+  param.worldSize = parallel_args.world_size();
+  param.numAttentionHeadsPerRank =
+      args.mm_num_attention_heads() / param.worldSize;
+  param.hiddenSizePerAttentionHead =
+      args.mm_hidden_size() / args.mm_num_attention_heads();
+  std::optional<int64_t> optionalValue = args.mm_num_attention_heads();
+  param.numKeyValueHeadsPerRank =
+      static_cast<int>(optionalValue.value()) / param.worldSize;
+  param.rank = parallel_args.rank();
+  param.backend = "lccl";
+  param.enableLogN = false;
+}
+
+NpuQwen2VisionEncoderLayerImpl::NpuQwen2VisionEncoderLayerImpl(
+    const ModelContext& context)
+    : NpuBaseLayer(context) {
+  auto model_args = context.get_model_args();
+  auto parallel_args = context.get_parallel_args();
+  auto options = context.get_tensor_options();
+  param_from_args(encode_param_, model_args, parallel_args);
+  at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
+  atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
+  dtype_ = c10::typeMetaToScalarType(options.dtype());
+  device_id_ = options.device().index();
+  placeholder_ = atb_speed::Utils::AtTensor2Tensor(
+      torch::zeros({1}).to(device_).to(dtype_));
+  at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
+  for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
+    at_weight_tensors_[i] = torch::zeros({1}).to(options);
+  }
+}
+
+void NpuQwen2VisionEncoderLayerImpl::verify_loaded_weights() const {
+  for (const auto& [index, name] : WEIGHT_MAPPING) {
+    CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
+        << "weight is not loaded for " << name;
+  }
+}
+
+void NpuQwen2VisionEncoderLayerImpl::merge_loaded_weights() {
+  // split the packed qkv weight when tensor parallelism is enabled
+  get_weights_col_packed_qkv();
+  if (encode_param_.worldSize > 1) {
+    // merge qkv weight
+    auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT],
+                                      at_weight_tensors_[IN_VISION_K_WEIGHT],
+                                      at_weight_tensors_[IN_VISION_V_WEIGHT]},
+                                     0);
+    at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
+    at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_);
+
+    // merge qkv bias
+    auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS],
+                                    at_weight_tensors_[IN_VISION_K_BIAS],
+                                    at_weight_tensors_[IN_VISION_V_BIAS]},
+                                   0);
+    at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
+    at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
+  }
+  c10_npu::NPUCachingAllocator::emptyCache();
+  for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
+    atb_weight_tensors_[i] =
+        atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]);
+  }
+
+  init_layer();
+}
+// split packed qkv weights into per-rank shards for tensor parallelism
+void NpuQwen2VisionEncoderLayerImpl::get_weights_col_packed_qkv() {
+  int rank = encode_param_.rank;
+  int worldSize = encode_param_.worldSize;
+  // split qkv weight
+  qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
+  qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
+  // weight
+  at_weight_tensors_[IN_VISION_Q_WEIGHT] =
+      (qkv_weight[0].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_K_WEIGHT] =
+      (qkv_weight[1].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_V_WEIGHT] =
+      (qkv_weight[2].chunk(worldSize, 0))[rank];
+  // bias
+  at_weight_tensors_[IN_VISION_Q_BIAS] =
+      (qkv_bias[0].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_K_BIAS] =
+      (qkv_bias[1].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_V_BIAS] =
+      (qkv_bias[2].chunk(worldSize, 0))[rank];
+}
+
+void NpuQwen2VisionEncoderLayerImpl::load_state_dict(
+    const StateDict& state_dict) {
+  for (const auto& [index, name] : WEIGHT_MAPPING) {
+    if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
+      set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
+    } else {
+      set_weight(state_dict, name, index);
+    }
+  }
+}
+
+int64_t NpuQwen2VisionEncoderLayerImpl::init_layer() {
+  name_ = "qwen2_encoder_layer";
+  model_name_ = "qwen2_vl";
+  CHECK_OPERATION_STATUS_RETURN(init_node(encode_node_, encode_param_));
+  return atb::NO_ERROR;
+}
+
+int64_t NpuQwen2VisionEncoderLayerImpl::init_node(
+    atb_speed::Model::Node& node,
+    atb_speed::qwen::VisionEncoderLayerParam& param) {
+  atb::Operation* operation = nullptr;
+  atb_speed::qwen::Qwen3VL_EncoderLayer(param, &operation);
+  node.operation.reset(operation);
+  if (node.operation == nullptr) {
+    LOG(ERROR) << "node.operation is null";
+    return -1;
+  }
+  if (node.operation->GetInputNum() < 1) {
+    LOG(ERROR) << "node.operation->GetInputNum() is smaller than 1";
+    return -1;
+  }
+  node.inTensors.resize(node.operation->GetInputNum());
+  node.outTensors.resize(1);
+  size_t inTensorId = 1;
+
+  for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER;
+       ++weightTensorId) {
+    node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId];
+  }
+
+  node.variantPack.inTensors.reserve(node.inTensors.size());
+  node.variantPack.inTensors.resize(node.inTensors.size());
+  node.variantPack.outTensors.reserve(1);
+  node.variantPack.outTensors.resize(1);
+  return atb::NO_ERROR;
+}
+
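+// Variant pack layout (shared by init_node() and build_node_variant_pack()
+// below): inTensors[0 .. WEIGHT_COUNT_PER_LAYER - 1] hold the per-layer
+// weights in VisionEncoderLayerTensorId order; the activation inputs follow
+// at WEIGHT_COUNT_PER_LAYER + {0, 1, 2, 3} = hidden states, cos, sin and
+// cu_seqlen (whose hostData points at cu_seqlen_vec). The single output
+// tensor aliases the hidden-states input, so forward() returns x in place.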
+torch::Tensor NpuQwen2VisionEncoderLayerImpl::forward(
+    torch::Tensor& x,
+    torch::Tensor& cos_pos,
+    torch::Tensor& sin_pos,
+    torch::Tensor& cu_seqlen,
+    std::vector<int>& cu_seqlen_vec,
+    ModelInputParams& input_params,
+    int node_id,
+    aclrtEvent* event,
+    std::atomic<bool>* event_flag) {
+  atb::Status st;
+
+  build_node_variant_pack(encode_node_,
+                          x,
+                          cos_pos,
+                          sin_pos,
+                          cu_seqlen,
+                          cu_seqlen_vec,
+                          input_params,
+                          true);
+  // mstxRangeEnd(id);
+  st = execute_node(encode_node_, node_id);
+  LOG_IF(FATAL, st != 0) << model_name_
+                         << " execute encode layer fail, error code: " << st;
+  return x;
+}
+
+void NpuQwen2VisionEncoderLayerImpl::build_node_variant_pack(
+    atb_speed::Model::Node& node,
+    torch::Tensor& x,
+    torch::Tensor& cos_pos,
+    torch::Tensor& sin_pos,
+    torch::Tensor& cu_seqlen,
+    std::vector<int>& cu_seqlen_vec,
+    ModelInputParams& input_params,
+    bool is_prefill) {
+  internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
+
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
+      atb_speed::Utils::AtTensor2Tensor(cos_pos);
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) =
+      atb_speed::Utils::AtTensor2Tensor(sin_pos);
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) =
+      atb_speed::Utils::AtTensor2Tensor(cu_seqlen);
+  node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3).hostData =
+      cu_seqlen_vec.data();
+
+  for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
+    CHECK_THROW(node.inTensors.at(i) == nullptr,
+                model_name_ << " inTensor " << i << " is NULL");
+    node.variantPack.inTensors.at(i) = *node.inTensors.at(i);
+    // LOG(INFO) << model_name_ << "inTensors[" << i << "]:"
+    //           << atb_speed::TensorUtil::TensorToString(
+    //                  node.variantPack.inTensors.at(i));
+  }
+
+  node.variantPack.outTensors.at(0) = internal_tensors_;
+}
+
+}  // namespace layer
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h
new file mode 100644
index 000000000..5d449caab
--- /dev/null
+++ b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h
@@ -0,0 +1,124 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#pragma once +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#include +#else +#include +#include +#endif + +#include + +#include + +#include "atb/atb_infer.h" +#include "atb_speed/base/hosttensor_binder.h" +#include "atb_speed/base/model.h" +#include "atb_speed/log.h" +#include "atb_speed/utils/model_factory.h" +#include "core/framework/model/model_args.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/state_dict/state_dict.h" +#include "nlohmann/json.hpp" +#include "npu_base_layer.h" +#include "pytorch/adapter/utils/utils.h" +#include "xllm_kernels/models/qwen3_vl/qwen3_vl_encoder.h" + +namespace xllm { +namespace layer { + +class NpuQwen2VisionEncoderLayerImpl : public NpuBaseLayer { + public: + explicit NpuQwen2VisionEncoderLayerImpl(const ModelContext& context); + + ~NpuQwen2VisionEncoderLayerImpl() {}; + + void load_state_dict(const StateDict& state_dict) override; + + void verify_loaded_weights() const override; + + void merge_loaded_weights() override; + + int64_t init_layer() override; + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& cu_seqlen, + std::vector& cu_seqlen_vec, + ModelInputParams& input_params, + int node_id = 0, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr); + + private: + void build_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& cu_seqlen, + std::vector& cu_seqlen_vec, + ModelInputParams& input_params, + bool is_prefill); + + void get_weights_col_packed_qkv(); + + void param_from_args(atb_speed::qwen::VisionEncoderLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args); + + int64_t init_node(atb_speed::Model::Node& node, + atb_speed::qwen::VisionEncoderLayerParam& param); + + void pad_qkv_weights(); + + void pad_mlp_weights(); + + torch::Tensor pad_tensor(const torch::Tensor& tensor, + int64_t target_shape, + int64_t dim = 0) { + int64_t pad_size = target_shape - tensor.size(dim); + if (tensor.dim() == 1) { + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, pad_size})); + } else if (tensor.dim() == 2) { + if (1 == dim) + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, pad_size, 0, 0})); + else + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + } + return tensor; + } + + atb_speed::Model::Node encode_node_; + std::string model_name_; + + atb_speed::qwen::VisionEncoderLayerParam encode_param_; + atb::Tensor internal_tensors_; + atb::Tensor placeholder_; + at::Tensor cu_seqlen_; + at::Tensor at_placeholder_; + std::vector qkv_weight; + std::vector qkv_bias; + int device_id_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/qwen2_vision_encode_layer.h b/xllm/core/layers/qwen2_vision_encode_layer.h new file mode 100644 index 000000000..8e88d83fb --- /dev/null +++ b/xllm/core/layers/qwen2_vision_encode_layer.h @@ -0,0 +1,39 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#if defined(USE_NPU) +#include "npu/npu_qwen2_vision_encoder_layer_impl.h" +#endif + +namespace xllm { +namespace layer { + +#if defined(USE_NPU) +class Qwen2VisionEncoderLayer + : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = NpuQwen2VisionEncoderLayerImpl; + + Qwen2VisionEncoderLayer(const ModelContext& context) + : ModuleHolder( + std::make_shared(context)) {} +}; +#endif + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/models/models.h b/xllm/models/models.h index 0460d6ff5..decdaf25a 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -35,6 +35,8 @@ limitations under the License. #include "llm/qwen3_embedding.h" // IWYU pragma: keep #include "vlm/minicpmv.h" // IWYU pragma: keep #include "vlm/qwen2_5_vl.h" // IWYU pragma: keep +#include "vlm/qwen2_vl.h" // IWYU pragma: keep +#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep #include "vlm/qwen3_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep #elif defined(USE_MLU) diff --git a/xllm/models/vlm/qwen2_vl.h b/xllm/models/vlm/qwen2_vl.h new file mode 100644 index 000000000..02f2e1ba3 --- /dev/null +++ b/xllm/models/vlm/qwen2_vl.h @@ -0,0 +1,622 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/layers/lm_head.h" +#include "core/layers/qwen2_decoder_layer.h" +#include "core/layers/qwen2_vision_encode_layer.h" +#include "core/layers/rms_norm.h" +#include "models/llm/qwen2.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/qwen2_vl_image_processor.h" +#include "qwen2_5_vl.h" +#include "xllm_kernels/core/include/atb_speed/log.h" + +namespace xllm { + +#define PrintTensor(tensor) print_tensor(tensor, #tensor, 10, true, false); + +class Qwen2_VisionBlockImpl : public torch::nn::Module { + public: + Qwen2_VisionBlockImpl(const ModelContext& context) { + // register submodules + encoder_layer_ = register_module("encoder_layer", + layer::Qwen2VisionEncoderLayer(context)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id) { + return encoder_layer_(x, + m_cos_pos, + m_sin_pos, + cu_seq_len, + cu_seq_len_vec, + input_params, + node_id); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + encoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + encoder_layer_->verify_loaded_weights(); + } + void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); } + + private: + layer::Qwen2VisionEncoderLayer encoder_layer_{nullptr}; +}; +TORCH_MODULE(Qwen2_VisionBlock); + +class Qwen2_VisionPatchEmbedImpl : public torch::nn::Module { + public: + Qwen2_VisionPatchEmbedImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + auto in_features = model_args.mm_num_channels() * + model_args.mm_temporal_patch_size() * + model_args.mm_patch_size() * model_args.mm_patch_size(); + + auto out_features = model_args.mm_hidden_size(); + + proj_ = register_module( + "proj", + torch::nn::Linear( + torch::nn::LinearOptions(in_features, out_features).bias(false))); + + proj_->weight.set_data(proj_->weight.to(options)); + } + + torch::Tensor forward(torch::Tensor x) { return proj_(x); } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("proj.weight"); + if (weight.defined()) { + weight = weight.reshape({weight.size(0), -1}); + DCHECK_EQ(proj_->weight.sizes(), weight.sizes()) + << "proj weight size mismatch for " << name(); + proj_->weight.data().copy_(weight); + proj_weight_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(proj_weight_loaded_) + << "weight is not loaded for " << prefix + "proj.weight"; + } + + private: + bool proj_weight_loaded_ = false; + torch::nn::Linear proj_{nullptr}; +}; +TORCH_MODULE(Qwen2_VisionPatchEmbed); + +class Qwen2_VisionRotaryEmbeddingImpl : public torch::nn::Module { + public: + Qwen2_VisionRotaryEmbeddingImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + dim_ = model_args.mm_head_dim() / 2; + theta_ = 10000.0; + + auto opts = options.dtype(torch::kFloat32); + auto inv_freq = + 1.0 / torch::pow(theta_, 
torch::arange(0, dim_, 2, opts) / dim_); + inv_freq_ = register_buffer("inv_freq", inv_freq); + } + + void update_freqs_cache(int64_t seqlen) { + if (seqlen <= seq_len_cached_) return; + + seqlen *= 2; + seq_len_cached_ = seqlen; + + auto options = torch::TensorOptions() + .dtype(torch::kFloat32) + .device(inv_freq_.device()); + inv_freq_ = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, options) / dim_); + auto seq = torch::arange(seqlen, options); + freqs_cached_ = torch::outer(seq, inv_freq_); + } + + torch::Tensor forward(int seqlen) { + update_freqs_cache(seqlen); + return freqs_cached_.slice(0, 0, seqlen); + } + + private: + int dim_ = 0; + double theta_ = 0.0; + + int64_t seq_len_cached_ = 0; + torch::Tensor inv_freq_; + torch::Tensor freqs_cached_; +}; +TORCH_MODULE(Qwen2_VisionRotaryEmbedding); + +class Qwen2_VisionPatchMergerImpl : public torch::nn::Module { + public: + Qwen2_VisionPatchMergerImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto quant_args = context.get_quant_args(); + auto parallel_args = context.get_parallel_args(); + + int64_t d_model = model_args.mm_projection_dim(); // out_hidden_size + int context_dim = model_args.mm_hidden_size(); + int spatial_merge_size = model_args.mm_spatial_merge_size(); + + hidden_size_ = + context_dim * static_cast(std::pow(spatial_merge_size, 2)); + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({context_dim}) + .elementwise_affine(true) + .eps(1e-6))); + norm_->weight.set_data(norm_->weight.to(options)); + norm_->bias.set_data(norm_->bias.to(options)); + auto fc1 = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, hidden_size_).bias(true)); + fc1->weight.set_data(fc1->weight.to(options)); + fc1->bias.set_data(fc1->bias.to(options)); + auto act = torch::nn::GELU(); + auto fc2 = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, d_model).bias(true)); + fc2->weight.set_data(fc2->weight.to(options)); + fc2->bias.set_data(fc2->bias.to(options)); + mlp_ = register_module("mlp", torch::nn::Sequential(fc1, act, fc2)); + layers_ = std::make_tuple(fc1, act, fc2); + } + + torch::Tensor forward(torch::Tensor x) { + x = norm_(x).view({-1, hidden_size_}); + return mlp_->forward(x); + } + + void load_state_dict(const StateDict& state_dict) { + const auto& norm_dict = state_dict.get_dict_with_prefix("ln_q."); + const auto& norm_weight = norm_dict.get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(norm_->weight.sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + norm_->weight.data().copy_(norm_weight); + is_norm_weight_loaded = true; + } + const auto norm_bias = norm_dict.get_tensor("bias"); + if (norm_bias.defined()) { + CHECK_EQ(norm_->bias.sizes(), norm_bias.sizes()) + << "bias size mismatch for " << name(); + norm_->bias.data().copy_(norm_bias); + is_norm_bias_loaded = true; + } + const auto& fc1_dict = state_dict.get_dict_with_prefix("mlp.0."); + const auto& fc1_weight = fc1_dict.get_tensor("weight"); + if (fc1_weight.defined()) { + CHECK_EQ(std::get<0>(layers_)->weight.sizes(), fc1_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<0>(layers_)->weight.data().copy_(fc1_weight); + is_fc1_weight_loaded = true; + } + const auto fc1_bias = fc1_dict.get_tensor("bias"); + if (fc1_bias.defined()) { + CHECK_EQ(std::get<0>(layers_)->bias.sizes(), fc1_bias.sizes()) + << "bias size mismatch for " << name(); + 
std::get<0>(layers_)->bias.data().copy_(fc1_bias); + is_fc1_bias_loaded = true; + } + + const auto& fc2_dict = state_dict.get_dict_with_prefix("mlp.2."); + const auto& fc2_weight = fc2_dict.get_tensor("weight"); + if (fc2_weight.defined()) { + CHECK_EQ(std::get<2>(layers_)->weight.sizes(), fc2_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<2>(layers_)->weight.data().copy_(fc2_weight); + is_fc2_weight_loaded = true; + } + const auto fc2_bias = fc2_dict.get_tensor("bias"); + if (fc2_bias.defined()) { + CHECK_EQ(std::get<2>(layers_)->bias.sizes(), fc2_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<2>(layers_)->bias.data().copy_(fc2_bias); + is_fc2_bias_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_fc1_weight_loaded) + << "weight is not loaded for " << prefix + "mlp.0" + ".weight"; + CHECK(is_fc1_bias_loaded) + << "bias is not loaded for " << prefix + "mlp.0" + ".bias"; + CHECK(is_fc2_weight_loaded) + << "weight is not loaded for " << prefix + "mlp.2" + ".weight"; + CHECK(is_fc2_bias_loaded) + << "bias is not loaded for " << prefix + "mlp.2" + ".bias"; + CHECK(is_norm_weight_loaded) + << "weight is not loaded for " << prefix + "ln_q" + ".weight"; + CHECK(is_norm_bias_loaded) + << "bias is not loaded for " << prefix + "ln_q" + ".bias"; + } + + private: + int hidden_size_; + torch::nn::LayerNorm norm_{nullptr}; + torch::nn::Sequential mlp_{nullptr}; + std::tuple layers_ = { + nullptr, + nullptr, + nullptr}; + bool is_fc1_weight_loaded = false; + bool is_fc1_bias_loaded = false; + bool is_fc2_weight_loaded = false; + bool is_fc2_bias_loaded = false; + bool is_norm_weight_loaded = false; + bool is_norm_bias_loaded = false; +}; +TORCH_MODULE(Qwen2_VisionPatchMerger); + +class Qwen2_VisionTransformerImpl : public torch::nn::Module { + public: + Qwen2_VisionTransformerImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + hidden_size_ = model_args.mm_hidden_size(); + num_heads_ = model_args.mm_num_attention_heads(); + + window_size_ = model_args.mm_window_size(); + patch_size_ = model_args.mm_patch_size(); + spatial_merge_size_ = model_args.mm_spatial_merge_size(); + spatial_merge_unit_ = static_cast(std::pow(spatial_merge_size_, 2)); + // mlp_ratio_ = model_args.mm_mlp_ratio(); + + patch_embed_ = + register_module("patch_embed", Qwen2_VisionPatchEmbed(context)); + rotary_pos_emb_ = + register_module("rotary_pos_emb", Qwen2_VisionRotaryEmbedding(context)); + blocks_ = register_module("blocks", torch::nn::ModuleList()); + + for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { + auto block = Qwen2_VisionBlock(context); + blocks_->push_back(block); + layers_.push_back(block); + } + merger_ = register_module("merger", Qwen2_VisionPatchMerger(context)); + } + + torch::Tensor rot_pos_emb(torch::Tensor grid_thw) { + std::vector pos_ids_vec; + auto count = grid_thw.sizes()[0]; + pos_ids_vec.reserve(count); + + auto grid_thw_cpu = grid_thw.cpu(); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + for (int idx = 0; idx < count; ++idx) { + auto t = grid_thw_cpu[idx][0].item(); + auto h = grid_thw_cpu[idx][1].item(); + auto w = grid_thw_cpu[idx][2].item(); + + auto hpos_ids = torch::arange(h, options).unsqueeze(1).expand({-1, w}); + hpos_ids = hpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 
1, 3}) + .flatten(); + + auto wpos_ids = torch::arange(w, options).unsqueeze(0).expand({h, -1}); + wpos_ids = wpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + + pos_ids_vec.push_back( + torch::stack({hpos_ids, wpos_ids}, -1).repeat({t, 1})); + } + + auto pos_ids = torch::cat(pos_ids_vec, 0); + auto max_grid_size = + grid_thw + .index({torch::indexing::Slice(), + torch::indexing::Slice(1, torch::indexing::None)}) + .max(); + + auto rotary_pos_emb_full = rotary_pos_emb_(max_grid_size.item()); + auto rotary_pos_emb = rotary_pos_emb_full.index({pos_ids}).flatten(1); + + return rotary_pos_emb; + } + + torch::Tensor forward(torch::Tensor hidden_states, + torch::Tensor grid_thw, // [batch,thw] + const ModelInputParams& input_params) { + // patchify + // hidden_states = x.to(device=self.device, dtype=self.dtype); + hidden_states = patch_embed_(hidden_states); + // compute position embedding + auto rotary_pos_emb = rot_pos_emb(grid_thw); + auto seq_len = hidden_states.sizes()[0]; + + // compute cu_seqlens + auto cu_seqlens = torch::repeat_interleave( + grid_thw.index({torch::indexing::Slice(), 1}) * + grid_thw.index({torch::indexing::Slice(), 2}), + grid_thw.index({torch::indexing::Slice(), 0})) + .cumsum(0, torch::kInt32); + namespace F = torch::nn::functional; + cu_seqlens = F::pad( + cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0)); + + // transformers + cu_seqlens = torch::diff(cu_seqlens); + m_cos = rotary_pos_emb.cos().type_as(hidden_states); + m_cos = m_cos.repeat({1, 2}); + m_sin = rotary_pos_emb.sin().type_as(hidden_states); + m_sin = m_sin.repeat({1, 2}); + ModelInputParams& input_params_new = + const_cast(input_params); + torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); + std::vector cu_seqlens_vec( + cu_seqlens_cpu.data_ptr(), // full seqlen vec + cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); + for (int idx = 0; idx < blocks_->size(); ++idx) { + hidden_states = layers_[idx](hidden_states, + m_cos, + m_sin, + cu_seqlens, + cu_seqlens_vec, + input_params_new, + idx); + } + // adapter + hidden_states = merger_(hidden_states); + return hidden_states; + } + + void load_state_dict(const StateDict& state_dict) { + patch_embed_->load_state_dict( + state_dict.get_dict_with_prefix("patch_embed.")); + for (int idx = 0; idx < blocks_->size(); ++idx) { + layers_[idx]->load_state_dict(state_dict.get_dict_with_prefix( + "blocks." + std::to_string(idx) + ".")); + } + + merger_->load_state_dict(state_dict.get_dict_with_prefix("merger.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + patch_embed_->verify_loaded_weights(prefix + "patch_embed."); + for (int idx = 0; idx < blocks_->size(); ++idx) { + layers_[idx]->verify_loaded_weights(prefix + "blocks." 
+ + std::to_string(idx) + "."); + } + merger_->verify_loaded_weights(prefix + "merger."); + } + + void merge_loaded_weights() { + for (int idx = 0; idx < blocks_->size(); ++idx) { + layers_[idx]->merge_loaded_weights(); + } + } + + private: + int hidden_size_ = 0; + int num_heads_ = 0; + int window_size_ = 0; + int patch_size_ = 0; + int spatial_merge_size_ = 0; + std::set fullatt_block_indexes_; + int spatial_merge_unit_ = 0; + // int mlp_ratio_ = 4; + + Qwen2_VisionPatchEmbed patch_embed_{nullptr}; + Qwen2_VisionRotaryEmbedding rotary_pos_emb_{nullptr}; + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + Qwen2_VisionPatchMerger merger_{nullptr}; + + torch::Tensor m_cos; + torch::Tensor m_sin; + int device_id = 0; +}; +TORCH_MODULE(Qwen2_VisionTransformer); + +struct Qwen2_VLImageInputs { + torch::Tensor pixel_values; + torch::Tensor image_grid_thw; +}; + +struct Qwen2_VLVideoInputs { + torch::Tensor pixel_values_videos; + torch::Tensor video_grid_thw; + torch::Tensor second_per_grid_ts; +}; + +class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { + public: + Qwen2_VLForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", Qwen2_VisionTransformer(context)); + + language_model_ = + register_module("language_model", QWen2ForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + if (image_input) { + // visual + auto image_embeds = visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + // merge + auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id()); + inputs_embeds.index_put_({is_multimodal}, image_embeds); + } + + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + std::optional image_inputs; + std::optional video_inputs; + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Qwen2_VLImageInputs{pixel_values, image_grid_thw}; + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + auto emb = language_model_(tokens, positions, kv_caches, input_params); + + return emb; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict(state_dict->get_dict_with_prefix("visual.")); + } + // verify + visual_->verify_loaded_weights("visual."); + visual_->merge_loaded_weights(); + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader)); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void 
set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + + Qwen2_VisionTransformer visual_{nullptr}; + QWen2ForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Qwen2_VLForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(qwen2_vl, Qwen2_5_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(qwen2_vl, Qwen2_VLForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(qwen2_vl, Qwen2VLImageProcessor); + +REGISTER_MODEL_ARGS(qwen2_vl, [&] { + // text config + // LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 151643); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151645); + LOAD_ARG_OR(vision_start_token_id, "vision_start_token_id", 151652); + LOAD_ARG_OR(vision_end_token_id, "vision_end_token_id", 151653); + LOAD_ARG_OR(vision_token_id, "vision_token_id", 151654); + LOAD_ARG_OR(image_token_id, "image_token_id", 151655); + LOAD_ARG_OR(video_token_id, "video_token_id", 151656); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "hidden_size", 3584); + // LOAD_ARG_OR(initializer_range, "initializer_range", 0.02); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 28); + LOAD_ARG_OR(model_type, "model_type", "qwen2_vl"); + LOAD_ARG_OR(n_heads, "num_attention_heads", 28); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 28); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 4); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-06); + LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f); + LOAD_ARG_OR(sliding_window, "sliding_window", 32768); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16"); + // LOAD_ARG_OR(transformers_version, "transformers_version", "4.41.2"); + // LOAD_ARG_OR(use_cache, "use_cache", true); + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + + // vision_config + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 32); + // LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "silu"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.embed_dim", 1280); + // LOAD_ARG_OR(mm_mlp_ratio, "vision_config.mlp_ratio", 4); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_chans", 3); + LOAD_ARG_OR(mm_projection_dim, "vision_config.hidden_size", 3584); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 14); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG_OR(mm_spatial_patch_size, "vision_config.spatial_patch_size", 14); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + + LOAD_ARG_OR(rope_scaling_rope_type, "rope_scaling.type", "mrope"); + LOAD_ARG(rope_scaling_mrope_section, "rope_scaling.mrope_section"); + LOAD_ARG_OR(vocab_size, "vocab_size", 152064); +}); +} // namespace xllm diff --git a/xllm/models/vlm/qwen2_vl_embedding.h 
b/xllm/models/vlm/qwen2_vl_embedding.h new file mode 100644 index 000000000..2227640fc --- /dev/null +++ b/xllm/models/vlm/qwen2_vl_embedding.h @@ -0,0 +1,187 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "core/framework/model/embedding_vlm.h" +#include "models/llm/qwen2.h" +#include "models/vlm/qwen2_vl.h" + +namespace xllm { + +class Qwen2_VLForEmbeddingImpl : public torch::nn::Module { + public: + Qwen2_VLForEmbeddingImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", Qwen2_VisionTransformer(context)); + language_model_ = + register_module("language_model", QWen2ForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + if (image_input) { + // visual + auto image_embeds = visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + // merge + auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id()); + inputs_embeds.index_put_({is_multimodal}, image_embeds); + } + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + std::optional image_inputs; + std::optional video_inputs; + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Qwen2_VLImageInputs{pixel_values, image_grid_thw}; + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + auto emb = language_model_(tokens, positions, kv_caches, input_params); + return emb; + } + + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + auto pooler_output = torch::nn::functional::normalize( + h, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); + return pooler_output; + } + + torch::Tensor logits(const torch::Tensor&, const torch::Tensor&) { + LOG(ERROR) << "logits() not implemented for Embedding Model!"; + return torch::empty({0}); + } + + torch::Device device() const { return options_.device(); } + + const torch::TensorOptions& options() const { return options_; } + + void 
load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict(state_dict->get_dict_with_prefix("visual.")); + } + // verify + visual_->verify_loaded_weights("visual."); + visual_->merge_loaded_weights(); + // if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader)); + // } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + + Qwen2_VisionTransformer visual_{nullptr}; + QWen2ForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Qwen2_VLForEmbedding); + +template <> +class EmbeddingVLMImpl : public EmbeddingVLM { + public: + EmbeddingVLMImpl(xllm::Qwen2_VLForEmbedding model, + const torch::TensorOptions& options) + : model_(std::move(model)), options_(options) {} + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& parameters) override { + return model_->forward(tokens, positions, kv_caches, parameters); + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) override { + return model_->logits(hidden_states, seleted_idxes); + } + + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) override { + return model_->pooler(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) override { + model_->load_model(std::move(loader)); + } + + torch::Device device() const override { return model_->device(); } + + const torch::TensorOptions& options() const override { + return model_->options(); + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + + // Delegate head/embedding accessors to underlying model implementation. + layer::LmHead get_lm_head() override { return model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); } + layer::WordEmbedding get_word_embedding() override { + return model_->get_word_embedding(); + } + void set_word_embedding(layer::WordEmbedding& embedding) override { + model_->set_word_embedding(embedding); + } + + private: + xllm::Qwen2_VLForEmbedding model_; + torch::TensorOptions options_; +}; + +REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(qwen2_vl_embedding, + qwen2_vl, + Qwen2_VLForEmbedding); +} // namespace xllm \ No newline at end of file
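
For reference (not part of the patch): a minimal standalone libtorch sketch of the tensor-parallel QKV handling performed by get_weights_col_packed_qkv() plus merge_loaded_weights(). The shapes and parallel degree below are toy assumptions, not the real model dimensions; the point is only the chunk-per-projection, shard-per-rank, re-pack sequence.

#include <torch/torch.h>

#include <iostream>

int main() {
  const int64_t hidden = 8;  // assumed toy hidden size
  const int world_size = 2;  // assumed tensor-parallel degree
  const int rank = 0;        // this rank's index

  // Packed qkv weight as stored in the checkpoint: [3 * hidden, hidden].
  auto qkv_weight = torch::randn({3 * hidden, hidden});

  // Split the packed weight into Q, K, V along dim 0.
  auto qkv = torch::chunk(qkv_weight, 3, /*dim=*/0);

  // Take this rank's row shard of each projection.
  auto q = torch::chunk(qkv[0], world_size, /*dim=*/0)[rank];
  auto k = torch::chunk(qkv[1], world_size, /*dim=*/0)[rank];
  auto v = torch::chunk(qkv[2], world_size, /*dim=*/0)[rank];

  // Re-pack the per-rank shards so the fused kernel still sees one qkv tensor.
  auto qkv_rank = torch::cat({q, k, v}, /*dim=*/0);

  std::cout << "per-rank packed qkv shape: " << qkv_rank.sizes() << std::endl;
  return 0;
}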