Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def clone(func, *args, **kwargs):
@implements(
[
aten.detach.default,
aten.empty_like.default,
]
)
def nf4_detach(aten_op, args, kwargs=None):
Expand Down Expand Up @@ -956,6 +957,13 @@ def decorator(func):
@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
tensor = args[0]
if len(args) <= 1:
if "device" in kwargs:
# Tensor.to(device, non_blocking)
device = kwargs["device"]
updated_attrs = call_from_inner_tensors(tensor, "to", args[1:], kwargs)
updated_attrs["device"] = device
return NF4Tensor(*construct_nf4_args(tensor, updated_attrs))
if isinstance(args[1], torch.dtype):
# Tensor.to(dtype, non_blocking, copy, memory_format)
return tensor.get_original_weight().to(*args[1:], **kwargs)
Expand Down