@@ -21,7 +21,6 @@ namespace {
2121using dnnl::memory;
2222using paddle::framework::ExecutionContext;
2323using paddle::framework::GradVarName;
24- using paddle::platform::MatMulV2MKLDNNHandler;
2524using phi::OneDNNContext;
2625using phi::vectorize;
2726using phi::funcs::OneDNNGetDataType;
@@ -82,6 +81,239 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
8281 return input_dims;
8382}
8483
84+ template <typename XT, typename YT, typename OT>
85+ class MatMulV2MKLDNNHandler
86+ : public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
87+ public:
88+ MatMulV2MKLDNNHandler (const ExecutionContext &ctx,
89+ const dnnl::engine engine,
90+ paddle::platform::Place cpu_place,
91+ const std::vector<int64_t > &x_org_dims,
92+ bool trans_x,
93+ const std::vector<int64_t > &y_org_dims,
94+ bool trans_y,
95+ bool is_output_fused,
96+ const std::vector<int64_t > &x_strides_override,
97+ const std::vector<int64_t > &y_strides_override)
98+ : phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
99+ cpu_place) {
100+ // M X K * K X N
101+ std::vector<int64_t > x_dims (x_org_dims);
102+ std::vector<int64_t > y_dims (y_org_dims);
103+
104+ const int MB_idx = x_dims.size () - 3 ;
105+ const int H_idx = x_dims.size () - 2 ;
106+ const int W_idx = x_dims.size () - 1 ;
107+
108+ if (trans_x) std::swap (x_dims[H_idx], x_dims[W_idx]);
109+ if (trans_y) std::swap (y_dims[H_idx], y_dims[W_idx]);
110+
111+ const memory::dim M = x_dims[H_idx];
112+ const memory::dim K = x_dims[W_idx];
113+ const memory::dim N = y_dims[W_idx];
114+
115+ std::vector<int64_t > x_strides (x_dims.size () - 3 , 1 );
116+ std::vector<int64_t > y_strides (x_dims.size () - 3 , 1 );
117+ std::vector<int64_t > out_strides (x_dims.size () - 3 , 1 );
118+ std::vector<int64_t > out_ddims (x_dims.size () - 3 , 1 );
119+
120+ x_strides.reserve (x_dims.size ());
121+ y_strides.reserve (x_dims.size ());
122+ out_strides.reserve (x_dims.size ());
123+
124+ if (!x_strides_override.empty ()) {
125+ x_strides = x_strides_override;
126+ } else {
127+ if (!trans_x) {
128+ x_strides.insert (x_strides.end (), {M * K, K, 1 });
129+ } else {
130+ x_strides.insert (x_strides.end (), {M * K, 1 , M});
131+ }
132+ }
133+
134+ if (!y_strides_override.empty ()) {
135+ y_strides = y_strides_override;
136+ } else {
137+ if (!trans_y) {
138+ y_strides.insert (y_strides.end (), {N * K, N, 1 });
139+ } else {
140+ y_strides.insert (y_strides.end (), {N * K, 1 , K});
141+ }
142+ }
143+
144+ out_strides.insert (out_strides.end (), {M * N, N, 1 });
145+ out_ddims.insert (out_ddims.end (),
146+ {std::max (x_dims[MB_idx], y_dims[MB_idx]), M, N});
147+
148+ for (int i = x_dims.size () - 4 ; i >= 0 ; --i) {
149+ out_ddims[i] = std::max (x_dims[i], y_dims[i]);
150+ if (x_strides_override.empty ()) {
151+ x_strides[i] = x_dims[i + 1 ] * x_strides[i + 1 ];
152+ }
153+ if (y_strides_override.empty ()) {
154+ y_strides[i] = y_dims[i + 1 ] * y_strides[i + 1 ];
155+ }
156+ out_strides[i] = out_ddims[i + 1 ] * out_strides[i + 1 ];
157+ }
158+
159+ // TODO(jczaja): Why not for int8??
160+ if (!phi::funcs::is_int8<OT>() && is_output_fused) {
161+ out_strides = FakeTransposeStrides (out_ddims);
162+ }
163+
164+ auto x_md =
165+ memory::desc (x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides);
166+ auto y_md =
167+ memory::desc (y_dims, phi::funcs::OneDNNGetDataType<YT>(), y_strides);
168+ auto out_md = memory::desc (
169+ out_ddims, phi::funcs::OneDNNGetDataType<OT>(), out_strides);
170+
171+ const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs (ctx);
172+
173+ this ->AcquireForwardPrimitiveDescriptor (matmul_attrs, x_md, y_md, out_md);
174+ }
175+
176+ void AppendActivation (const ExecutionContext &ctx,
177+ dnnl::post_ops &post_ops, // NOLINT
178+ float activation_scale = 1 .0f ) {
179+ const auto invalid_attribute =
180+ ctx.HasAttr (" fuse_activation" )
181+ ? ctx.Attr <std::string>(" fuse_activation" ).empty ()
182+ : true ;
183+ if (invalid_attribute) return ;
184+
185+ const auto fuse_activation = ctx.Attr <std::string>(" fuse_activation" );
186+ const auto fuse_alpha =
187+ ctx.HasAttr (" fuse_alpha" ) ? ctx.Attr <float >(" fuse_alpha" ) : 0 .0f ;
188+ const auto fuse_beta =
189+ ctx.HasAttr (" fuse_beta" ) ? ctx.Attr <float >(" fuse_beta" ) : 0 .0f ;
190+
191+ if (fuse_activation == " hard_sigmoid" ) {
192+ post_ops.append_eltwise (activation_scale,
193+ dnnl::algorithm::eltwise_linear,
194+ fuse_alpha,
195+ fuse_beta);
196+ post_ops.append_eltwise (
197+ activation_scale, dnnl::algorithm::eltwise_clip, 0 .0f , 1 .0f );
198+ } else {
199+ const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
200+ {" abs" , dnnl::algorithm::eltwise_abs},
201+ {" clip" , dnnl::algorithm::eltwise_clip},
202+ {" gelu" , dnnl::algorithm::eltwise_gelu_erf},
203+ {" gelu_erf" , dnnl::algorithm::eltwise_gelu_erf},
204+ {" gelu_tanh" , dnnl::algorithm::eltwise_gelu_tanh},
205+ {" hard_swish" , dnnl::algorithm::eltwise_hardswish},
206+ {" leaky_relu" , dnnl::algorithm::eltwise_relu},
207+ {" mish" , dnnl::algorithm::eltwise_mish},
208+ {" relu" , dnnl::algorithm::eltwise_relu},
209+ {" relu6" , dnnl::algorithm::eltwise_bounded_relu},
210+ {" sigmoid" , dnnl::algorithm::eltwise_logistic},
211+ {" sqrt" , dnnl::algorithm::eltwise_sqrt},
212+ {" swish" , dnnl::algorithm::eltwise_swish},
213+ {" tanh" , dnnl::algorithm::eltwise_tanh}};
214+
215+ const auto &activation_type = activation_map.find (fuse_activation);
216+
217+ PADDLE_ENFORCE_NE (
218+ activation_type,
219+ activation_map.end (),
220+ phi::errors::InvalidArgument (
221+ " Activation '%s' not found in oneDNN algorithms mapper" ,
222+ fuse_activation));
223+
224+ post_ops.append_eltwise (
225+ activation_scale, activation_type->second , fuse_alpha, fuse_beta);
226+ }
227+ }
228+
229+ float ComputeOutputScale (const ExecutionContext &ctx) {
230+ float alpha = ctx.HasAttr (" alpha" ) ? ctx.Attr <float >(" alpha" ) : 1 .0f ;
231+ if (ctx.HasAttr (" Scale_x" ) && ctx.HasAttr (" Scale_y" ) &&
232+ ctx.HasAttr (" Scale_out" )) {
233+ float scale_x = ctx.Attr <float >(" Scale_x" );
234+ float scale_y = ctx.Attr <float >(" Scale_y" );
235+ bool force_fp32_out = ctx.HasAttr (" force_fp32_output" )
236+ ? ctx.Attr <bool >(" force_fp32_output" )
237+ : false ;
238+ float scale_out = force_fp32_out ? 1 .f : ctx.Attr <float >(" Scale_out" );
239+ alpha *= scale_out / (scale_x * scale_y);
240+ }
241+ return alpha;
242+ }
243+
244+ dnnl::primitive_attr CreateMatmulAttrs (const ExecutionContext &ctx) {
245+ dnnl::primitive_attr matmul_attrs;
246+ dnnl::post_ops post_operations;
247+
248+ float scale_out = ComputeOutputScale (ctx);
249+ if (scale_out != 1 .0f ) {
250+ matmul_attrs.set_output_scales (0 , {scale_out});
251+ }
252+
253+ if (ctx.HasInput (" ResidualData" )) {
254+ auto *residual_data = ctx.Input <phi::DenseTensor>(" ResidualData" );
255+ auto residual_data_tz = phi::vectorize (residual_data->dims ());
256+ auto residual_data_md = memory::desc (residual_data_tz,
257+ phi::funcs::OneDNNGetDataType<OT>(),
258+ dnnl::memory::format_tag::any);
259+ post_operations.append_binary (dnnl::algorithm::binary_add,
260+ residual_data_md);
261+ if (ctx.HasAttr (" Scale_in_eltwise" )) {
262+ float sum_scale = scale_out / ctx.Attr <float >(" Scale_in_eltwise" );
263+ post_operations.append_sum (sum_scale);
264+ }
265+ }
266+
267+ AppendActivation (ctx, post_operations);
268+
269+ if (ctx.HasAttr (" fused_output_scale" )) {
270+ float scale_alpha = ctx.Attr <float >(" fused_output_scale" );
271+ post_operations.append_eltwise (
272+ 1.0 , dnnl::algorithm::eltwise_linear, scale_alpha, 0 .0f );
273+ }
274+
275+ matmul_attrs.set_post_ops (post_operations);
276+ return matmul_attrs;
277+ }
278+
279+ std::vector<int64_t > FakeTransposeStrides (
280+ const std::vector<int64_t > &matmul_out_dims) const {
281+ // fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
282+ // transpose axis are: {0, 2, 1, 3}
283+ std::vector<int64_t > transpose_axis = {0 , 2 , 1 , 3 };
284+ std::vector<int64_t > fake_strides (transpose_axis.size ());
285+ int ndims = static_cast <int >(transpose_axis.size ());
286+
287+ int total_stride = 1 ;
288+
289+ for (int i = ndims - 1 ; i >= 0 ; --i) {
290+ fake_strides[transpose_axis[i]] = total_stride;
291+ total_stride *= matmul_out_dims[transpose_axis[i]];
292+ }
293+
294+ return fake_strides;
295+ }
296+
297+ std::shared_ptr<memory> AcquireWeightsMemory (const phi::DenseTensor *input) {
298+ const YT *input_data = input->data <YT>();
299+ return this ->AcquireMemoryFromPrimitive (
300+ this ->fwd_pd_ ->weights_desc (),
301+ phi::funcs::to_void_cast<YT>(input_data));
302+ }
303+
304+ std::shared_ptr<dnnl::memory> AcquireDstMemory (phi::DenseTensor *output) {
305+ // We cannot use base AcquireDstMemory as it makes an allocation request
306+ // base on DST memory primitive size. This is fine in general, but in MatMul
307+ // we have primitive that covers only one batch of Data and then shift
308+ // pointer for every new batch. Hence phi::DenseTensor size is bigger that
309+ // dst memory primitive size. So would we request less memory that is there
310+ // and it triggers an assertion. So as there is no 'any' format here we can
311+ // leave default size of phi::DenseTensor as computed in ComputeInferShape
312+ OT *ptr = output->mutable_data <OT>(this ->place_ );
313+ return this ->AcquireMemoryFromPrimitive (this ->fwd_pd_ ->dst_desc (), ptr);
314+ }
315+ };
316+
85317template <typename XT, typename YT, typename OT>
86318class MatMulMKLDNNHandler
87319 : public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
0 commit comments