Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
run_tests,
)
from torchao.dtypes.nf4tensor import (
NF4Tensor,
linear_nf4,
to_nf4,
_INNER_TENSOR_NAMES_FOR_SHARDING,
Expand Down Expand Up @@ -270,6 +271,14 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):

torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)

@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_empty_like(self, input_size: Union[Tuple[int], int]):
    """torch.empty_like on an NF4Tensor must return a new NF4Tensor that
    honors the `device=` kwarg and preserves the logical size.

    Covers both a 1-D and a 2-D input shape via parametrization.
    """
    nf4_tensor = to_nf4(torch.rand(input_size))
    # NOTE(review): to actually exercise the `device=` argument, the source
    # tensor should be created on GPU so the CPU request forces a device
    # change — TODO confirm and gate with a CUDA-availability check.
    new_tensor = torch.empty_like(nf4_tensor, device=torch.device("cpu"))
    self.assertTrue(isinstance(new_tensor, NF4Tensor))
    # get_device() returns -1 for CPU tensors, so this asserts CPU placement.
    self.assertEqual(new_tensor.get_device(), -1)
    self.assertEqual(new_tensor.size(), nf4_tensor.size())


class TestFSDPOps(TestCase):
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def clone(func, *args, **kwargs):
@implements(
[
aten.detach.default,
aten.empty_like.default,
]
)
def nf4_detach(aten_op, args, kwargs=None):
Expand Down