Why can the built-in sum change the dtype of ndl.Tensor? #11

@xnuohz

I hit the following error when testing SGD:

    @data.setter
    def data(self, value):
        assert isinstance(value, Tensor)
>       assert value.dtype == self.dtype, "%s %s" % (
            value.dtype,
            self.dtype,
        )
E       AssertionError: float64 float32

Then I found that one line in the function compute_gradient_of_variables causes this error:

node.grad = sum(node_to_output_grads_list[node])

After changing it as follows, everything works:

node_grads = node_to_output_grads_list[node]
node.grad = node_grads[0] if len(node_grads) == 1 else sum(node_grads)
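
The single-element special case above can also be avoided by never introducing a start value at all. The following is only a sketch: `sum_node_list` is a hypothetical helper name, and `timedelta` is used purely as a standard-library stand-in for a type that should not be mixed with a plain `int`. It assumes the list is never empty.

```python
from datetime import timedelta
from functools import reduce
from operator import add


def sum_node_list(values):
    """Add the items pairwise, left to right, without an implicit int 0 start.

    Hypothetical helper name. Raises TypeError on an empty list because
    reduce has no initial value to fall back on.
    """
    return reduce(add, values)


# timedelta is only a stand-in for a type that should never be combined
# with a plain int, the way the gradient tensors are in this issue.
single = [timedelta(seconds=1)]
several = [timedelta(seconds=1), timedelta(seconds=2)]

print(sum_node_list(single) is single[0])  # True: a lone item is returned untouched
print(sum_node_list(several))
```

With a helper like this, the line in question could read `node.grad = sum_node_list(node_to_output_grads_list[node])`, assuming every node in the list has at least one incoming gradient.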

The dtypes I see in pdb look weird. Maybe I am doing something wrong.

(Pdb) node_grads
[needle.Tensor(1.0)]
(Pdb) node_grads[0].dtype
dtype('float32')
(Pdb) sum(node_grads).dtype
dtype('float64')
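
A likely explanation for the output above is that the built-in `sum` starts from the integer `0`, so even a one-element list is evaluated as `0 + node_grads[0]`, which goes through the tensor's scalar-addition path rather than returning the tensor unchanged. The sketch below illustrates this with a toy class. `FakeTensor` is hypothetical and is not ndl.Tensor, and the float32 to float64 promotion on scalar addition is hard-coded as an assumption to mirror the pdb session. Whether a real tensor promotes this way depends on how its scalar add is implemented and on the promotion rules of the underlying array library.

```python
class FakeTensor:
    """Hypothetical stand-in for ndl.Tensor; tracks only a value and a dtype name."""

    def __init__(self, value, dtype="float32"):
        self.value = value
        self.dtype = dtype

    def __add__(self, other):
        if isinstance(other, FakeTensor):
            # Tensor + Tensor: the dtype is preserved in this toy model.
            return FakeTensor(self.value + other.value, self.dtype)
        # Tensor + Python scalar: modelled as promoting to float64
        # (an assumption that mirrors the pdb output in this issue).
        return FakeTensor(self.value + other, "float64")

    # 0 + tensor falls back to this after int.__add__ returns NotImplemented.
    __radd__ = __add__


grads = [FakeTensor(1.0)]

via_sum = sum(grads)                   # evaluates 0 + grads[0]
via_start = sum(grads[1:], grads[0])   # starts from the first tensor, no int 0

print(via_sum.dtype, via_start.dtype)  # float64 float32
```

Passing the first gradient as the `start` argument keeps the addition entirely between tensors, which is why the dtype survives in the second call.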
