Skip to content

Commit 771aef6

Browse files
committed
add type_as
1 parent 51c12fb commit 771aef6

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

python/paddle/base/dygraph/math_op_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def astype(self: Tensor, dtype: DTypeLike) -> Tensor:
104104

105105
return _C_ops.cast(self, dtype)
106106

107+
def type_as(self: Tensor, other: Tensor) -> Tensor:
108+
return self.astype(other.dtype)
109+
107110
def _scalar_elementwise_op_(
108111
var: Tensor, scale: float, bias: float
109112
) -> Tensor:
@@ -225,6 +228,7 @@ def _mT_(var: Tensor) -> Tensor:
225228
('__len__', _len_),
226229
('__index__', _index_),
227230
('astype', astype),
231+
('type_as', type_as),
228232
('dim', dim),
229233
('ndimension', ndimension),
230234
('ndim', _ndim),

python/paddle/base/layers/math_op_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ def astype(self, dtype):
382382
out.stop_gradient = self.stop_gradient
383383
return out
384384

385+
def type_as(self, other):
386+
return self.astype(other.dtype)
387+
385388
@static_only
386389
def append(self, var):
387390
"""
@@ -799,6 +802,7 @@ def to_dense(var):
799802
('__neg__', _neg_),
800803
('__abs__', _abs_),
801804
('astype', astype),
805+
('type_as', type_as),
802806
('cpu', cpu),
803807
('cuda', cuda),
804808
('place', place),

python/paddle/pir/math_op_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,9 @@ def astype(self, dtype):
370370

371371
return _C_ops.cast(self, dtype)
372372

373+
def type_as(self, other):
374+
return self.astype(other.dtype)
375+
373376
def _scalar_add_(var, value):
374377
return paddle.scale(var, 1.0, value)
375378

@@ -1109,6 +1112,7 @@ def register_hook(self, hook):
11091112
('ndimension', ndimension),
11101113
('ndim', _ndim),
11111114
('astype', astype),
1115+
('type_as', type_as),
11121116
('size', _size_),
11131117
('T', _T_),
11141118
('mT', _mT_),

0 commit comments

Comments
 (0)