@@ -154,7 +154,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
154154 const auto & place = in_tensor.place ();
155155 const auto & key = GetKeyFromPlace (place);
156156
157- if (!calc_event_) {
157+ if (!calc_event_ ||
158+ (place_to_comm_ctx_.find (key) == place_to_comm_ctx_.end ())) {
158159 CreateBKCLEnvCache (place, key);
159160 }
160161
@@ -170,6 +171,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
170171 fn (out_tensor, in_tensor, comm_ctx->bkcl_context (), bkcl_stream);
171172
172173 if (!use_calc_stream) {
174+ PADDLE_ENFORCE_NOT_NULL (
175+ comm_ctx.get (), platform::errors::Fatal (" comm context is nullptr." ));
173176 task->comm_event_ ->Record (*comm_ctx.get ());
174177 }
175178
@@ -369,6 +372,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
369372 1 ,
370373 platform::errors::InvalidArgument (
371374 " BKCL only support single tensor collective communication." ));
375+ PADDLE_ENFORCE_EQ (
376+ CheckTensorsInXPUPlace (in_tensors),
377+ true ,
378+ platform::errors::InvalidArgument (" All inputs should be in XPUPlace." ));
372379 return Collective (
373380 &out_tensors[0 ],
374381 in_tensors[0 ],
@@ -406,6 +413,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
406413 1 ,
407414 platform::errors::InvalidArgument (
408415 " BKCL only support single tensor collective communication." ));
416+ PADDLE_ENFORCE_EQ (
417+ CheckTensorsInXPUPlace (in_tensors),
418+ true ,
419+ platform::errors::InvalidArgument (" All inputs should be in XPUPlace." ));
409420 return Collective (
410421 &out_tensors[0 ],
411422 in_tensors[0 ],
@@ -442,6 +453,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
442453 1 ,
443454 platform::errors::InvalidArgument (
444455 " BKCL only support single tensor collective communication." ));
456+ PADDLE_ENFORCE_EQ (
457+ CheckTensorsInXPUPlace (in_tensors),
458+ true ,
459+ platform::errors::InvalidArgument (" All inputs should be in XPUPlace." ));
445460
446461 return Collective (
447462 &out_tensors[0 ],
@@ -481,6 +496,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
481496 1 ,
482497 platform::errors::InvalidArgument (
483498 " BKCL only support single tensor collective communication." ));
499+ PADDLE_ENFORCE_EQ (
500+ CheckTensorsInXPUPlace (in_tensors),
501+ true ,
502+ platform::errors::InvalidArgument (" All inputs should be in XPUPlace." ));
484503
485504 return Collective (
486505 &out_tensors[0 ],
@@ -518,6 +537,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
518537 1 ,
519538 platform::errors::InvalidArgument (
520539 " BKCL only support single tensor collective communication." ));
540+ PADDLE_ENFORCE_EQ (
541+ CheckTensorsInXPUPlace (in_tensors),
542+ true ,
543+ platform::errors::InvalidArgument (" All inputs should be in XPUPlace." ));
521544 PADDLE_ENFORCE_EQ (
522545 CheckTensorsInXPUPlace (out_tensors),
523546 true ,
0 commit comments