Fix decomposeLinearWithBias to shard all created tensorviews (#5563)
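For context, here is a minimal, illustrative Python sketch of the decomposition the title refers to. It is an assumption about the shape of the rewrite, not nvFuser's actual C++ implementation: decomposing linear-with-bias creates intermediate tensorviews (the unbiased matmul result, the broadcasted bias add), and each of them must be sharded like the original output.

```python
import torch

def decomposed_linear_with_bias(inp, weight, bias):
    """Illustrative only: linear-with-bias split into two steps.

    In nvFuser, each step materializes a new TensorView. The fix in
    this PR is to shard *all* such created tensorviews, not just the
    final output, so intermediates carry the same device mesh as `out`.
    """
    mm = torch.nn.functional.linear(inp, weight)  # intermediate: unbiased matmul
    out = mm + bias                               # intermediate: broadcasted bias add
    return out
```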
```diff
@@ -6,7 +6,7 @@
 import torch

 import nvfuser_direct as nvfuser
-from nvfuser_direct import DataType, FusionDefinition
+from nvfuser_direct import DataType, FusionDefinition, PythonProfiler


 # Avoid doing this when possible. This test started to exist before nvFuser
```
|
```diff
@@ -197,50 +197,61 @@ def _multidevice_schedule(fd: FusionDefinition):
 def test_linear_reduce_scatter(multidevice_direct_test):
     d = multidevice_direct_test.size
     mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
-    e = 768
+    b, s, e = 3, 5, 7

     def _definition(fd: FusionDefinition):
-        inp = fd.define_tensor([-1, -1, d * e])
-        weight = fd.define_tensor([e, d * e])
-        out = fd.ops.linear(inp, weight, None)
+        inp = fd.define_tensor([-1, d * s, d * e], dtype=DataType.BFloat16)
+        weight = fd.define_tensor([-1, d * e], dtype=DataType.BFloat16)
+        bias = fd.define_tensor([e], dtype=DataType.BFloat16)
+        out = fd.ops.linear(inp, weight, bias)
         fd.add_output(out)

     def _multidevice_schedule(fd: FusionDefinition):
-        inp, weight = fd.fusion.inputs()
+        inp, weight, bias = fd.fusion.inputs()
         (out,) = fd.fusion.outputs()
-        for t in [inp, weight, out]:
-            t.set_device_mesh(mesh)
-            t.split(-1, d, inner_split=False)
-            t.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
+        bias.set_device_mesh(mesh)
+        for tv in [inp, weight, out]:
+            tv.set_device_mesh(mesh)
+            tv.split(-1, d, inner_split=False)
+            tv.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)

         # Scatter
         out.split(1, d, inner_split=False)
         out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)

     torch.cuda.set_device(multidevice_direct_test.local_rank)

-    # set b=1 as a temporary fix for the test to pass.
-    # TODO: set b>1 once reduce scatter is fixed.
-    b, s = 2, 1024
-    unsharded_inp = torch.randn(b, s, d * e)
-    unsharded_weight = torch.randn(e, d * e)
+    unsharded_inp = torch.randint(-2, 3, (b, d * s, d * e)).to(torch.bfloat16)
+    unsharded_weight = torch.randint(-2, 3, (e, d * e)).to(torch.bfloat16)
+    bias = torch.randint(-2, 3, (e,)).to(torch.bfloat16)
     inp = multidevice_direct_test.shard_tensor(unsharded_inp, -1, mesh)
     weight = multidevice_direct_test.shard_tensor(unsharded_weight, -1, mesh)

     with FusionDefinition() as fd:
         _definition(fd)
         _multidevice_schedule(fd)

-    (out,) = fd.execute([inp, weight])
+    with PythonProfiler() as prof:
+        (out,) = fd.execute([inp, weight, bias.cuda()])
```
Collaborator: Does this synchronize? Could we miss kernels?

Author: Is this what you are referring to?

Collaborator: There's a difference between cudaStreamSynchronize and cudaDeviceSynchronize, though: both block the calling host thread, but the former waits only for the given stream to finish while the latter waits for all work on the device.

Author: You're right. I assumed FusionProfiler/PythonProfiler synchronize at start but not on stop, so I will add an explicit call here. Note to self: check whether FusionProfiler should synchronize before reading data.
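Following up on that thread, here is a minimal sketch of the explicit synchronization the author mentions. The placement of `torch.cuda.synchronize()` is an assumption about where such a call would go, not the committed fix, and it reuses `fd`, `inp`, `weight`, and `bias` from the test above.

```python
import torch

# Assumption: PythonProfiler synchronizes when it starts but not when it
# stops, so kernels still in flight could be missed. An explicit device
# synchronize blocks the host until all queued GPU work (including the
# reduce-scatter) has finished before the profiler reads its data.
with PythonProfiler() as prof:
    (out,) = fd.execute([inp, weight, bias.cuda()])
    torch.cuda.synchronize()  # hypothetical placement of the explicit call
```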
The hunk continues below the review thread (the conditional in the new assertion is parenthesized here so that it compares the kernel count against `1 if d > 1 else 0`, rather than asserting a bare `0` when `d == 1`):

```diff
-    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None)
-    # rtol is the same as the default for fp32. atol is slightly increased.
+    # Only one reduce scatter kernel should be scheduled.
+    assert (
+        len(
+            [
+                kp
+                for kp in prof.profile.kernel_profiles
+                if kp.scheduler == "communication"
+            ]
+        )
+        == (1 if d > 1 else 0)
+    )
+
+    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, bias)
     torch.testing.assert_close(
         out,
         multidevice_direct_test.shard_tensor(unsharded_out, 1, mesh),
-        rtol=1.3e-6,
-        atol=1e-3,
     )
```
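A side note on why the test switched from `torch.randn` to small-integer `torch.randint` inputs: with entries in {-2, ..., 2} and small reduction extents, every partial product and sum is a small integer that bfloat16 represents exactly, so the sharded output can be checked without loose tolerances. A self-contained illustration (the shapes here are arbitrary, not the test's):

```python
import torch

# Small integers survive a bfloat16 matmul exactly: with entries in
# {-2, ..., 2} and a reduction of length 8, |dot product| <= 32, and
# every integer up to 256 is exactly representable in bfloat16.
a = torch.randint(-2, 3, (4, 8)).to(torch.bfloat16)
b = torch.randint(-2, 3, (8, 4)).to(torch.bfloat16)
ref = (a.float() @ b.float()).to(torch.bfloat16)
torch.testing.assert_close(a @ b, ref, rtol=0, atol=0)
```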
Review comment: Caused the wrong scheduler name in profiler output.
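For reference, a sketch of how the reported scheduler names could be inspected. It reuses `prof` and the `kernel_profiles` / `scheduler` fields from the test above; everything else is assumed.

```python
from collections import Counter

# Assumes `prof` is the PythonProfiler instance from the test above.
names = Counter(kp.scheduler for kp in prof.profile.kernel_profiles)
print(names)  # the reduce-scatter kernel should report "communication"
```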