|
46 | 46 | '__rsub__', |
47 | 47 | '__mul__', |
48 | 48 | '__rmul__', |
| 49 | + '__div__', |
49 | 50 | '__truediv__', |
| 51 | + '__rdiv__', |
50 | 52 | '__rtruediv__', |
51 | 53 | '__matmul__', |
52 | 54 | ] |
@@ -170,6 +172,9 @@ def _scalar_rsub_(var, value): |
170 | 172 | def _scalar_mul_(var, value): |
171 | 173 | return _scalar_elementwise_op_(var, value, 0.0) |
172 | 174 |
|
| 175 | + def _scalar_div_(var, value): |
| 176 | + return _scalar_elementwise_op_(var, 1.0 / value, 0.0) |
| 177 | + |
173 | 178 | # for binary operator such as elementwise, compare |
174 | 179 | def _binary_creator_(method_name, |
175 | 180 | op_type, |
@@ -200,10 +205,7 @@ def __impl__(self, other_var): |
200 | 205 | if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_: |
201 | 206 | self = astype(self, 'float32') |
202 | 207 | # here use `scale` replace `elementwise` to get better performance |
203 | | - # but only +, -, * can use this method |
204 | | - # NOTE(chentianyu03): / can not use `scale` method,because the result of |
205 | | - # `scale` method (self*(1/other_var)) do not exactly equal with the result |
206 | | - # of `elementwise_div` method. |
| 208 | + # but only +, -, *, / can use this method |
207 | 209 | if scalar_method is not None: |
208 | 210 | return scalar_method(self, other_var) |
209 | 211 | else: |
@@ -296,8 +298,12 @@ def __impl__(self, other_var): |
296 | 298 | ## a*b == b*a. Do not need to reverse explicitly |
297 | 299 | ('__rmul__', |
298 | 300 | _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)), |
| 301 | + ('__div__', _binary_creator_('__div__', 'elementwise_div', False, |
| 302 | + _scalar_div_)), |
299 | 303 | ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', |
300 | | - False, None)), |
| 304 | + False, _scalar_div_)), |
| 305 | + ('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, |
| 306 | + None)), |
301 | 307 | ('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, |
302 | 308 | None)), |
303 | 309 | ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, |
|
0 commit comments