@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
7171 return reduced;
7272}
7373
74+ static const std::vector<int > GetDimsForKey (
75+ const std::vector<const Tensor*>& inputs) {
76+ auto dims_key = paddle::framework::vectorize<int >(inputs[0 ]->dims ());
77+ for (auto it = std::next (inputs.begin ()); it != inputs.end (); ++it) {
78+ dims_key.push_back ((*it)->dims ()[0 ]);
79+ }
80+ return dims_key;
81+ }
82+
7483template <typename T>
7584class ConcatPrimitiveFactory {
7685 public:
@@ -134,6 +143,8 @@ template <typename T>
134143class ConcatMKLDNNOpKernel : public paddle ::framework::OpKernel<T> {
135144 public:
136145 void Compute (const paddle::framework::ExecutionContext& ctx) const override {
146+ // If any of the multiple inputs of concat has an input size of 0, the
147+ // actual size of the multi_input will change
137148 auto multi_input = ReduceMultiInput (ctx.MultiInput <Tensor>(" X" ));
138149 EnforceLayouts (multi_input);
139150 Tensor* output = ctx.Output <Tensor>(" Out" );
@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
156167 paddle::framework::ToMKLDNNDataType (multi_input[0 ]->type ());
157168
158169 ConcatPrimitiveFactory<T> prim_creator;
159- // If one of the multiple inputs of concat has an input size of 0, the
160- // actual size of the multi_input will change
161- std::string key = platform::CreateKey (
162- dev_ctx, paddle::framework::vectorize<int >(multi_input[0 ]->dims ()),
163- multi_input.size (), ctx.OutputName (" Out" ), dt,
164- platform::ThreadIDasStr ());
170+ std::string key =
171+ platform::CreateKey (dev_ctx, GetDimsForKey (multi_input),
172+ multi_input.size (), ctx.OutputName (" Out" ), dt);
165173 key = platform::ExtendKeyWithThreadInfoIfNeeded (dev_ctx, key);
166174
167175 const std::string key_prim = key + " @concat_p" ;
0 commit comments