feat: improve how device switch is handled between the metric device and the input tensors device #3043
Conversation
… on separate devices.
Thanks a lot for the PR @MarcBresson!
Yes, that's correct: since the code overwrites it and the data usually stays consistently on the same device, the current code works. Concerning this PR's code, I think it would be good to follow the logic below:
What do you think?
There was a slowdown in early PyTorch versions; we should measure it now. Please use torch.utils.benchmark to get some numbers: https://github.com/vfdev-5/pth-inductor-dev/blob/eb01fa071a2337c7037e8a7e961b2147c5fc8b42/perf_flip.py#L52-L61
```python
import torch
import torch.utils.benchmark as benchmark


def rand_to_device(shape: tuple, created_on_device, to_device):
    # Create a tensor on one device and optionally move it to another one.
    on_device = torch.rand(shape, device=created_on_device)
    if to_device is not None:
        t = on_device.to(device=to_device)


results = []
min_run_time = 5
shape = (12, 4, 256, 256)
available_devices = [torch.device("cuda:0"), torch.device("cpu")]

for from_device in available_devices:
    for to_device in available_devices + [None]:
        print(f"{from_device} to {to_device} measurements")
        results.append(
            benchmark.Timer(
                stmt=f"fn({shape}, created_on_device, to_device)",
                globals={
                    "fn": rand_to_device,
                    "created_on_device": from_device,
                    "to_device": to_device,
                },
                num_threads=torch.get_num_threads(),
                label=f"{from_device} to {to_device} measurements",
            ).blocked_autorange(min_run_time=min_run_time)
        )
```

I had an error when calling results, so here is the raw output:
>>> cuda:0 to cuda:0 measurements
Median: 113.61 us
IQR: 0.70 us (113.34 to 114.04)
430 measurements, 100 runs per measurement, 6 threads
>>> cuda:0 to cpu measurements
Median: 4.95 ms
IQR: 0.14 ms (4.89 to 5.03)
11 measurements, 100 runs per measurement, 6 threads
>>> cuda:0 to None measurements
Median: 110.68 us
IQR: 0.30 us (110.45 to 110.75)
46 measurements, 1000 runs per measurement, 6 threads
>>> cpu to cuda:0 measurements
Median: 23.79 ms
IQR: 6.15 ms (22.10 to 28.24)
20 measurements, 10 runs per measurement, 6 threads
WARNING: Interquartile range is 25.8% of the median measurement.
This suggests significant environmental influence.
>>> cpu to cpu measurements
Median: 21.80 ms
IQR: 3.26 ms (21.11 to 24.37)
21 measurements, 10 runs per measurement, 6 threads
WARNING: Interquartile range is 14.9% of the median measurement.
This could indicate system fluctuation.
>>> cpu to None measurements
Median: 20.86 ms
IQR: 5.49 ms (19.88 to 25.37)
21 measurements, 10 runs per measurement, 6 threads
WARNING: Interquartile range is 26.3% of the median measurement.
This suggests significant environmental influence.

It seems like there is a really slight slowdown (we must compare …
I think that is a great suggestion: always perform the computation on GPU if either the kernel or the input tensor is on it.
If either the metric device or the update input device is a GPU, this commit will put the other one on GPU.
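A minimal sketch of that idea (illustrative only, not the actual diff in this PR; the helper name and the `kernel` argument are assumptions):

```python
import torch

def put_on_common_device(kernel: torch.Tensor, y_pred: torch.Tensor):
    # If either tensor already lives on a CUDA device, move the other one there
    # so the whole computation runs on the GPU.
    if y_pred.device.type == "cuda" and kernel.device != y_pred.device:
        kernel = kernel.to(y_pred.device)
    elif kernel.device.type == "cuda" and y_pred.device != kernel.device:
        y_pred = y_pred.to(kernel.device)
    return kernel, y_pred
```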
Here is how you can do it. I changed the perf code a bit to measure just the `.to()` call:

```python
import torch
import torch.utils.benchmark as benchmark

results = []
min_run_time = 5
available_devices = [torch.device("cuda"), torch.device("cpu")]

for shape in [(12, 4, 256, 256), (8, 3, 512, 512)]:
    for from_device in available_devices:
        data = torch.rand(shape, device=from_device)
        for to_device in available_devices:
            print(f"{shape} -> {from_device} to {to_device} measurements")
            results.append(
                benchmark.Timer(
                    stmt="data.to(to_device, non_blocking=False)",
                    globals={
                        "data": data,
                        "to_device": to_device,
                    },
                    description=f"{from_device} to {to_device}",
                    num_threads=torch.get_num_threads(),
                    label="Device to device",
                    sub_label=f"{tuple(shape)}",
                ).blocked_autorange(min_run_time=min_run_time)
            )

compare = benchmark.Compare(results)
compare.print()
```

On my infra it gives:

I agree that it is not a big deal, but we also have to think that in the real application …
You should include a …

With this we can see that creating an element on cuda and then moving it to the same cuda device takes ~800 ns, while not trying to move it at all only takes ~250 ns. The figures there are strangely very different from what I had on my first run #3043 (comment)...

--- EDIT ---

According to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch.Tensor.to, a same-device `.to()` should return `self` and probably do nothing, so we are just observing the cost of calling a Python function, ~250 ns (which is reasonable).
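For reference, this same-device no-op behaviour is easy to check directly (a minimal sketch, assuming a CUDA device is available; the expected results are in the comments):

```python
import torch

x = torch.rand(12, 4, 256, 256, device="cuda:0")

# When the dtype and device already match, Tensor.to() returns the very same
# tensor object, so no copy or kernel launch happens.
y = x.to("cuda:0")
print(y is x)   # True

# Moving to a genuinely different device allocates and returns a new tensor.
z = x.to("cpu")
print(z is x)   # False
```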
(there are bugs, I am writing tests and will correct them)
The comparison with self._device was not possible because it can be created
with `torch.device("cuda")`, which is not equal to `torch.device("cuda:0")`,
the actual device of a tensor created with `torch.device("cuda")`. This change
will have a bigger performance hit when self._kernel is not on the same device
as y_pred, as it will need to be moved onto y_pred's device every time
update() is called.
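A small illustration of that device-equality pitfall (a minimal sketch, assuming a CUDA machine):

```python
import torch

metric_device = torch.device("cuda")      # index is None here
y_pred = torch.rand(2, 2, device=metric_device)

print(y_pred.device)                              # cuda:0 -- the index gets resolved
print(metric_device == y_pred.device)             # False: "cuda" != "cuda:0"
print(metric_device.type == y_pred.device.type)   # True: comparing the type works
```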
I fixed everything. One thing to note is the (assumed) performance hit when self._kernel is not on the same device as y_pred, since it will need to be moved onto y_pred's device every time update() is called.
I am writing new tests for the variable channel size, and will push soon with all the changes that you suggested.
oh no, conflicts, what have I done?
Co-authored-by: vfdev <[email protected]>
vfdev-5 left a comment
LGTM, thanks @MarcBresson!
@vfdev-5 I investigated the weird code, and it turns out the kernel can never have more than 2 dimensions (it is only computed in _uniform() or _gaussian(), which both output 2-dim tensors).
I wrote this little fix with a warning if the update tensors are not on the same device as the metric.
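A minimal sketch of what such a fix could look like (illustrative only, not the exact code of this PR; the class, the `_kernel` attribute, and the warning text are assumptions):

```python
import warnings
import torch

class KernelMetricSketch:
    """Toy stand-in for a metric that keeps a convolution kernel as state."""

    def __init__(self, device: torch.device = torch.device("cpu")):
        self._device = device
        self._kernel = torch.ones(11, 11, device=device)  # placeholder kernel

    def update(self, y_pred: torch.Tensor) -> None:
        kernel = self._kernel
        if y_pred.device != kernel.device:
            warnings.warn(
                f"y_pred is on {y_pred.device} but the metric kernel is on "
                f"{kernel.device}; the kernel will be moved to {y_pred.device} "
                "on every update() call, which has a performance cost.",
                stacklevel=2,
            )
            kernel = kernel.to(y_pred.device)
        # ... compute the metric with `kernel` and y_pred ...
```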
Do you know if calling .to(device) on a tensor that is already on that device will cause slowdowns?
Check list: