Description
I hit the following error when testing SGD:
    @data.setter
    def data(self, value):
        assert isinstance(value, Tensor)
>       assert value.dtype == self.dtype, "%s %s" % (
            value.dtype,
            self.dtype,
        )
E       AssertionError: float64 float32
I then found that one line in the function compute_gradient_of_variables causes this error:
node.grad = sum(node_to_output_grads_list[node])
After I changed it as follows, everything works:
node_grads = node_to_output_grads_list[node]
node.grad = node_grads[0] if len(node_grads) == 1 else sum(node_grads)
The dtypes I see in pdb are weird: the single gradient is float32, but the result of sum() over it is float64. Maybe I am wrong.
(Pdb) node_grads
[needle.Tensor(1.0)]
(Pdb) node_grads[0].dtype
dtype('float32')
(Pdb) sum(node_grads).dtype
dtype('float64')
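A possible explanation, sketched below: Python's built-in sum() starts from the integer 0, so even a one-element list is evaluated as 0 + node_grads[0], which goes through the tensor's scalar-add path instead of returning the tensor untouched. The FakeTensor class and the sum_grads helper are hypothetical stand-ins, not needle code, and the promotion to float64 on scalar addition is an assumption modeled on the pdb output above (in needle it probably depends on how the backend handles Python scalars).

```python
from functools import reduce
import operator


class FakeTensor:
    """Hypothetical stand-in for needle.Tensor: holds a value and a dtype name."""

    def __init__(self, value, dtype="float32"):
        self.value = value
        self.dtype = dtype

    def __add__(self, other):
        if isinstance(other, FakeTensor):
            # Tensor + Tensor: keep the dtype of the left operand.
            return FakeTensor(self.value + other.value, self.dtype)
        # Tensor + Python scalar: ASSUMED to promote to float64,
        # mirroring what the pdb session shows.
        return FakeTensor(self.value + other, "float64")

    # 0 + tensor falls back to this, because int does not know FakeTensor.
    __radd__ = __add__


def sum_grads(grads):
    """Add gradients pairwise, without the implicit int 0 that sum() starts from."""
    return reduce(operator.add, grads)


grads = [FakeTensor(1.0)]
builtin_dtype = sum(grads).dtype       # computed as 0 + grads[0]
reduce_dtype = sum_grads(grads).dtype  # returns grads[0] itself
print(builtin_dtype, reduce_dtype)
```

If this is indeed the cause, a reduce-based sum like sum_grads would cover both the single-gradient and the multi-gradient case in one expression, as long as the list is never empty.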