@@ -599,17 +599,8 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
599599 const std::string& uniq_name)
600600 : platform::MKLDNNHandlerT<T, dnnl::binary>(
601601 dev_ctx, engine, cpu_place,
602- platform::CreateKey (
603- dev_ctx, framework::vectorize(x->dims ()), uniq_name,
604- (algo == dnnl::algorithm::binary_mul ? " M" : " " ))) {
605- // bradcasting combined with in-place may require
606- auto rankdiff = x->dims ().size () - y->dims ().size ();
607- if (rankdiff > 0 ) {
608- auto suffix = std::to_string (rankdiff);
609- this ->key_ += suffix;
610- this ->key_common_ += suffix;
611- }
612-
602+ platform::CreateKey (dev_ctx, framework::vectorize(x->dims ()),
603+ uniq_name)) {
613604 if (!this ->isCached ()) {
614605 PADDLE_ENFORCE_EQ (
615606 x->layout (), DataLayout::kMKLDNN ,
@@ -629,18 +620,24 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
629620 const auto src_y_tz = framework::vectorize (y->dims ());
630621 // if output tensor(z) is nullptr then we are computing into oneDNN
631622 // managed buffer
632- const auto dst_tz =
633- (z == nullptr ) ? src_x_tz : framework::vectorize (z->dims ());
623+ auto rankdiff = x->dims ().size () - y->dims ().size ();
624+ const auto dst_tz = (z == nullptr ) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
625+ : framework::vectorize (z->dims ());
634626
635- const auto src0_md = dnnl::memory::desc (
627+ auto src0_md = dnnl::memory::desc (
636628 src_x_tz, platform::MKLDNNGetDataType<T>(), x->format ());
637629 auto src1_md = dnnl::memory::desc (
638630 src_y_tz, platform::MKLDNNGetDataType<T>(), y->format ());
639- if (rankdiff > 0 ) {
631+ if (rankdiff > 0 ) { // Second input is of smaller rank than first
640632 std::vector<int64_t > dims1_ex (rankdiff, 1 );
641633 dims1_ex.insert (next (dims1_ex.begin (), (axis == -1 ? rankdiff : axis)),
642634 src_y_tz.begin (), src_y_tz.end ());
643635 src1_md = src1_md.reshape (dims1_ex);
636+ } else if (rankdiff < 0 ) { // First input is of smaller rank than second
637+ std::vector<int64_t > dims0_ex (-rankdiff, 1 );
638+ dims0_ex.insert (next (dims0_ex.begin (), (axis == -1 ? -rankdiff : axis)),
639+ src_x_tz.begin (), src_x_tz.end ());
640+ src0_md = src0_md.reshape (dims0_ex);
644641 }
645642 const auto dst_md = memory::desc (dst_tz, platform::MKLDNNGetDataType<T>(),
646643 MKLDNNMemoryFormat::any);
0 commit comments