Skip to content

Commit 1e04216

Browse files
authored
[Metax] fix illegal address access error in test_momentum_op (#12)
* [Metax] fix illegal address access error in test_momentum_op
1 parent 7964c35 commit 1e04216

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

backends/metax_gpu/patch/tmp/mixed_vector.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ class MixVector {
386386

387387
// the unify method to access CPU or CUDA data. immutable.
388388
const T *Data(phi::Place place) const {
389-
if (place.GetType() == phi::AllocationType::GPU) {
389+
if (place.GetType() == phi::AllocationType::GPU ||
390+
place.GetType() == phi::AllocationType::CUSTOM) {
390391
return CUDAData(place);
391392
} else {
392393
return data();
@@ -395,7 +396,8 @@ class MixVector {
395396

396397
// the unify method to access CPU or CUDA data. mutable.
397398
T *MutableData(phi::Place place) {
398-
if (place.GetType() == phi::AllocationType::GPU) {
399+
if (place.GetType() == phi::AllocationType::GPU ||
400+
place.GetType() == phi::AllocationType::CUSTOM) {
399401
return CUDAMutableData(place);
400402
} else {
401403
return data();

0 commit comments

Comments
 (0)