diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index b9b40d7e41..34dee659f8 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -34,21 +34,12 @@ m = nn.Sequential(
 x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
 convert_to_float8_training(m, config=config)
 
-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
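
For context, a minimal sketch of how the simplified quickstart reads end to end after this change. The `from torchao.float8 import ...` path, the toy model dimensions, and the training loop are assumptions inferred from the surrounding README, not part of this diff.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training  # assumed import path

# toy model and data (assumed; mirrors the README's setup around line 34)
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# convert `torch.nn.Linear` modules to `Float8Linear`, with compute
# and optionally distributed communications in float8
convert_to_float8_training(m)

# enable torch.compile to generate fused kernels for float8 scaling and casting
m = torch.compile(m)

# toy training loop (assumed)
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```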