diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index aef7ef62a3b..9c1a4b89cee 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -588,7 +588,27 @@ class ArgWhereFunctor { const Symbol& dtype) const { auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype"); attrs.SetAllAttrs(dtype->data_type()); +#ifdef WITH_NPU + auto device_type = DeviceType::kCPU; + if (x->is_global()) { + device_type = JUST(x->parallel_desc())->device_type(); + } else { + device_type = JUST(x->device())->enum_type(); + } + if (device_type == DeviceType::kNPU) { + // NOTE: use cpu argwhere when device="npu" + auto cpu_tensor = JUST(one::functional::To(x, "cpu")); + auto result = JUST(OpInterpUtil::Dispatch(*op_, {cpu_tensor}, attrs)); + for (int i = 0; i < result->size(); ++i) { + (*result)[i] = JUST(one::functional::To((*result)[i], "npu")); + } + return result; + } else { + return OpInterpUtil::Dispatch(*op_, {x}, attrs); + } +#else return OpInterpUtil::Dispatch(*op_, {x}, attrs); +#endif // WITH_NPU } private: