Skip to content

Commit 43876e8

Browse files
authored
make stop_gradient=True for random op in static graph (#33959)
1 parent 3629bf4 commit 43876e8

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

python/paddle/tensor/random.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def bernoulli(x, name=None):
7474
dtype=x.dtype) # maybe set out to int32 ?
7575
helper.append_op(
7676
type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={})
77+
out.stop_gradient = True
7778
return out
7879

7980

@@ -143,6 +144,7 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
143144
outputs={'Out': out},
144145
attrs={'num_samples': num_samples,
145146
'replacement': replacement})
147+
out.stop_gradient = True
146148
return out
147149

148150

@@ -514,6 +516,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
514516
helper.append_op(
515517
type="uniform_random", inputs=inputs, attrs=attrs,
516518
outputs={"Out": out})
519+
out.stop_gradient = True
517520
return out
518521

519522

@@ -615,6 +618,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
615618
out = helper.create_variable_for_type_inference(dtype=dtype)
616619
helper.append_op(
617620
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
621+
out.stop_gradient = True
618622
return out
619623

620624

0 commit comments

Comments
 (0)