Skip to content

Commit 12fb27d

Browse files
committed
[XPU] change match_matrix_tensor op from old version to refactor version, test=develop, test=xpu
1 parent 85b8a9b commit 12fb27d

File tree

4 files changed

+70
-68
lines changed

4 files changed

+70
-68
lines changed

lite/backends/xpu/target_wrapper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace lite {
3434
const int XPU_MAX_LOD_SIZE = 32;
3535
// MAX(lod[i + 1] - lod[i]) = 512
3636
const int XPU_MAX_LOD_SEQ_LEN = 512;
37+
const int XPU_MAXPTR_SIZE = 6;
3738

3839
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
3940

lite/kernels/xpu/__xpu__mmdnn_compute.cc

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,13 @@ class MMDNNMatchConvTopk {
482482
int dim_in_;
483483
int out_channel_;
484484

485-
MMDNNFcOp xw_fc_;
485+
const int16_t* match_weight_{nullptr};
486+
XPUScratchPadGuard match_weight_max_guard_;
487+
float* match_weight_max_{nullptr};
488+
XPUScratchPadGuard in_max_guard_;
489+
float* in_max_{nullptr};
490+
XPUScratchPadGuard out_max_guard_;
491+
float* out_max_{nullptr};
486492
const int16_t* conv_weight_{nullptr};
487493
float conv_weight_max_;
488494
XPUScratchPadGuard hbm_buffer_guard_;
@@ -525,12 +531,17 @@ class MMDNNMatchConvTopk {
525531
out_channel_ = out_channel;
526532
topks_ = topks;
527533

528-
xw_fc_.Init(input_w,
529-
input_w_max,
530-
nullptr,
531-
dim_t_ * dim_in_,
532-
dim_in_,
533-
xdnn::Activation_t::LINEAR);
534+
match_weight_ = input_w->data<int16_t>();
535+
match_weight_max_guard_ =
536+
TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
537+
match_weight_max_ =
538+
reinterpret_cast<float*>(match_weight_max_guard_->addr_);
539+
FillMax(input_w_max, match_weight_max_);
540+
in_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
541+
in_max_ = reinterpret_cast<float*>(in_max_guard_->addr_);
542+
out_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
543+
out_max_ = reinterpret_cast<float*>(out_max_guard_->addr_);
544+
534545
conv_weight_ = conv_w->data<int16_t>();
535546
conv_weight_max_ = conv_w_max;
536547

@@ -644,21 +655,30 @@ class MMDNNMatchConvTopk {
644655
}
645656
seq_avg_topk_out = out->mutable_data<float>(TARGET(kXPU));
646657

647-
int max_width = std::max(left_seqlen_max, right_seqlen_max);
648-
xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out);
649658
int r = 0;
650-
r = xdnn::match_matrix_tensor(ctx,
651-
batch,
652-
xw_out,
653-
right->data<float>(),
654-
left_lod_32_,
655-
right_lod_32_,
656-
dim_t_,
657-
dim_in_,
658-
xwy_out,
659-
xw_fc_.out_max,
660-
xdnn::Activation_t::RELU,
661-
max_width);
659+
r = xdnn::findmax<float>(
660+
ctx, left->data<float>(), left_seqlen_sum * dim_in_, in_max_);
661+
CHECK_EQ(r, 0);
662+
r = xdnn::match_matrix_tensor<float, int16_t, int>(
663+
ctx,
664+
left->data<float>(),
665+
right->data<float>(),
666+
match_weight_,
667+
xwy_out,
668+
dim_in_,
669+
dim_t_,
670+
true,
671+
{left_lod_32_cpu.data(),
672+
static_cast<int>(left_lod_32_cpu.size()),
673+
left_lod_32_},
674+
{right_lod_32_cpu.data(),
675+
static_cast<int>(right_lod_32_cpu.size()),
676+
right_lod_32_},
677+
in_max_,
678+
nullptr,
679+
match_weight_max_,
680+
xdnn::Activation_t::RELU,
681+
xw_out);
662682
CHECK_EQ(r, 0);
663683
r = xdnn::search_varconv<float, int16_t>(
664684
ctx,

lite/kernels/xpu/match_matrix_tensor_compute.cc

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,16 @@ namespace kernels {
2323
namespace xpu {
2424

2525
void MatchMatrixTensorCompute::PrepareForRun() {
26-
wx_max_xpu_guard_ =
27-
TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int));
26+
auto& param = this->Param<param_t>();
27+
float w_max = param.__xpu__w_max;
28+
std::vector<float> w_max_v(XPU_MAXPTR_SIZE, w_max);
29+
weight_max_xpu_guard_ =
30+
TargetWrapperXPU::MallocScratchPad(XPU_MAXPTR_SIZE * sizeof(float));
31+
XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_xpu_guard_->addr_),
32+
w_max_v.data(),
33+
XPU_MAXPTR_SIZE * sizeof(float),
34+
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
35+
2836
offset_l_xpu_guard_ =
2937
TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int));
3038
offset_r_xpu_guard_ =
@@ -44,7 +52,6 @@ void MatchMatrixTensorCompute::Run() {
4452
auto* out = param.out;
4553
auto* tmp = param.tmp;
4654
int dim_t = param.dim_t;
47-
float w_max = param.__xpu__w_max;
4855
bool fuse_relu = param.fuse_relu;
4956
bool float_to_fix = param.__xpu__float_to_fix;
5057
CHECK(float_to_fix) << "W should be fixed point";
@@ -74,44 +81,15 @@ void MatchMatrixTensorCompute::Run() {
7481
auto* bottom_l_trans_data = tmp->mutable_data<float>(TARGET(kXPU));
7582
int batch_size = x->lod()[0].size() - 1;
7683

77-
float* wx_max = reinterpret_cast<float*>(wx_max_xpu_guard_->addr_);
84+
float* w_max = reinterpret_cast<float*>(weight_max_xpu_guard_->addr_);
7885
int* offset_l_xpu = reinterpret_cast<int*>(offset_l_xpu_guard_->addr_);
7986
int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_);
8087

81-
int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>(
82-
ctx.GetRawContext(), /* ctx */
83-
false, /* trans_a */
84-
false, /* trans_b */
85-
x->dims()[0], /* m */
86-
dim_t * dim_in, /* n */
87-
dim_in, /* k */
88-
1.0f, /* alpha */
89-
bottom_l_data, /* data_a */
90-
dim_in, /* lda */
91-
w_data, /* data_b */
92-
dim_t * dim_in, /* ldb */
93-
0.0f, /* beta */
94-
bottom_l_trans_data, /* data_c */
95-
dim_t * dim_in, /* ldc */
96-
nullptr, /* bias */
97-
xdnn::Activation_t::LINEAR, /* act */
98-
0.0f, /* max_a */
99-
w_max, /* max_b */
100-
wx_max /* max_c */);
101-
CHECK_EQ(r, 0);
102-
103-
int max_width = 0;
10488
for (int i = 0; i < offset_l.size(); ++i) {
10589
offset_l_cpu[i] = offset_l[i];
106-
if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1] > max_width)) {
107-
max_width = offset_l_cpu[i] - offset_l_cpu[i - 1];
108-
}
10990
}
11091
for (int i = 0; i < offset_r.size(); ++i) {
11192
offset_r_cpu[i] = offset_r[i];
112-
if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1] > max_width)) {
113-
max_width = offset_r_cpu[i] - offset_r_cpu[i - 1];
114-
}
11593
}
11694
XPU_CALL(xpu_memcpy(offset_l_xpu,
11795
offset_l_cpu.get(),
@@ -122,20 +100,23 @@ void MatchMatrixTensorCompute::Run() {
122100
offset_r.size() * sizeof(int),
123101
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
124102

125-
r = xdnn::match_matrix_tensor(ctx.GetRawContext(),
126-
batch_size,
127-
bottom_l_trans_data,
128-
bottom_r_data,
129-
offset_l_xpu,
130-
offset_r_xpu,
131-
dim_t,
132-
dim_in,
133-
out_data,
134-
wx_max,
135-
act,
136-
max_width);
103+
int r = xdnn::match_matrix_tensor<float, int16_t, int>(
104+
ctx.GetRawContext(),
105+
bottom_l_data,
106+
bottom_r_data,
107+
w_data,
108+
out_data,
109+
dim_in,
110+
dim_t,
111+
true, // the weight is trans in XPUMmdnnFloat2Fix
112+
{offset_l_cpu.get(), static_cast<int>(offset_l.size()), offset_l_xpu},
113+
{offset_r_cpu.get(), static_cast<int>(offset_r.size()), offset_r_xpu},
114+
nullptr,
115+
nullptr,
116+
w_max,
117+
act,
118+
bottom_l_trans_data);
137119
CHECK_EQ(r, 0);
138-
139120
int lod_lv1_size = batch_size * dim_t;
140121
int lod_lv2_size = x->lod()[0].back() * dim_t;
141122
std::vector<size_t> out_lod0(batch_size + 1, 0);

lite/kernels/xpu/match_matrix_tensor_compute.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class MatchMatrixTensorCompute
3333
virtual void Run();
3434

3535
private:
36-
XPUScratchPadGuard wx_max_xpu_guard_;
36+
XPUScratchPadGuard weight_max_xpu_guard_;
3737
XPUScratchPadGuard offset_l_xpu_guard_;
3838
XPUScratchPadGuard offset_r_xpu_guard_;
3939

0 commit comments

Comments
 (0)