@@ -23,8 +23,16 @@ namespace kernels {
2323namespace xpu {
2424
2525void MatchMatrixTensorCompute::PrepareForRun () {
26- wx_max_xpu_guard_ =
27- TargetWrapperXPU::MallocScratchPad (XPU_MAX_LOD_SIZE * sizeof (int ));
26+ auto & param = this ->Param <param_t >();
27+ float w_max = param.__xpu__w_max ;
28+ std::vector<float > w_max_v (XPU_MAXPTR_SIZE, w_max);
29+ weight_max_xpu_guard_ =
30+ TargetWrapperXPU::MallocScratchPad (XPU_MAXPTR_SIZE * sizeof (float ));
31+ XPU_CALL (xpu_memcpy (reinterpret_cast <float *>(weight_max_xpu_guard_->addr_ ),
32+ w_max_v.data (),
33+ XPU_MAXPTR_SIZE * sizeof (float ),
34+ XPUMemcpyKind::XPU_HOST_TO_DEVICE));
35+
2836 offset_l_xpu_guard_ =
2937 TargetWrapperXPU::MallocScratchPad (XPU_MAX_LOD_SIZE * sizeof (int ));
3038 offset_r_xpu_guard_ =
@@ -44,7 +52,6 @@ void MatchMatrixTensorCompute::Run() {
4452 auto * out = param.out ;
4553 auto * tmp = param.tmp ;
4654 int dim_t = param.dim_t ;
47- float w_max = param.__xpu__w_max ;
4855 bool fuse_relu = param.fuse_relu ;
4956 bool float_to_fix = param.__xpu__float_to_fix ;
5057 CHECK (float_to_fix) << " W should be fixed point" ;
@@ -74,44 +81,15 @@ void MatchMatrixTensorCompute::Run() {
7481 auto * bottom_l_trans_data = tmp->mutable_data <float >(TARGET (kXPU ));
7582 int batch_size = x->lod ()[0 ].size () - 1 ;
7683
77- float * wx_max = reinterpret_cast <float *>(wx_max_xpu_guard_ ->addr_ );
84+ float * w_max = reinterpret_cast <float *>(weight_max_xpu_guard_ ->addr_ );
7885 int * offset_l_xpu = reinterpret_cast <int *>(offset_l_xpu_guard_->addr_ );
7986 int * offset_r_xpu = reinterpret_cast <int *>(offset_r_xpu_guard_->addr_ );
8087
81- int r = xdnn::gemm_int16_tmp_api<float , int16_t , float >(
82- ctx.GetRawContext (), /* ctx */
83- false , /* trans_a */
84- false , /* trans_b */
85- x->dims ()[0 ], /* m */
86- dim_t * dim_in, /* n */
87- dim_in, /* k */
88- 1 .0f , /* alpha */
89- bottom_l_data, /* data_a */
90- dim_in, /* lda */
91- w_data, /* data_b */
92- dim_t * dim_in, /* ldb */
93- 0 .0f , /* beta */
94- bottom_l_trans_data, /* data_c */
95- dim_t * dim_in, /* ldc */
96- nullptr , /* bias */
97- xdnn::Activation_t::LINEAR, /* act */
98- 0 .0f , /* max_a */
99- w_max, /* max_b */
100- wx_max /* max_c */ );
101- CHECK_EQ (r, 0 );
102-
103- int max_width = 0 ;
10488 for (int i = 0 ; i < offset_l.size (); ++i) {
10589 offset_l_cpu[i] = offset_l[i];
106- if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1 ] > max_width)) {
107- max_width = offset_l_cpu[i] - offset_l_cpu[i - 1 ];
108- }
10990 }
11091 for (int i = 0 ; i < offset_r.size (); ++i) {
11192 offset_r_cpu[i] = offset_r[i];
112- if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1 ] > max_width)) {
113- max_width = offset_r_cpu[i] - offset_r_cpu[i - 1 ];
114- }
11593 }
11694 XPU_CALL (xpu_memcpy (offset_l_xpu,
11795 offset_l_cpu.get (),
@@ -122,20 +100,23 @@ void MatchMatrixTensorCompute::Run() {
122100 offset_r.size () * sizeof (int ),
123101 XPUMemcpyKind::XPU_HOST_TO_DEVICE));
124102
125- r = xdnn::match_matrix_tensor (ctx.GetRawContext (),
126- batch_size,
127- bottom_l_trans_data,
128- bottom_r_data,
129- offset_l_xpu,
130- offset_r_xpu,
131- dim_t ,
132- dim_in,
133- out_data,
134- wx_max,
135- act,
136- max_width);
103+ int r = xdnn::match_matrix_tensor<float , int16_t , int >(
104+ ctx.GetRawContext (),
105+ bottom_l_data,
106+ bottom_r_data,
107+ w_data,
108+ out_data,
109+ dim_in,
110+ dim_t ,
111+ true , // the weight is transposed in XPUMmdnnFloat2Fix
112+ {offset_l_cpu.get (), static_cast <int >(offset_l.size ()), offset_l_xpu},
113+ {offset_r_cpu.get (), static_cast <int >(offset_r.size ()), offset_r_xpu},
114+ nullptr ,
115+ nullptr ,
116+ w_max,
117+ act,
118+ bottom_l_trans_data);
137119 CHECK_EQ (r, 0 );
138-
139120 int lod_lv1_size = batch_size * dim_t ;
140121 int lod_lv2_size = x->lod ()[0 ].back () * dim_t ;
141122 std::vector<size_t > out_lod0 (batch_size + 1 , 0 );
0 commit comments