Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,16 @@ def reset_profile(self):
Adds or resets the extra attributes.
"""

# ensure no residual operator statistics remain from a prior profiling
# session. In certain workflows, patched operator wrappers can emit
# bookkeeping information even after hooks are removed (for example when
# modules invoke autograd-generated graphs). Clearing the global
# collectors prevents stale entries from leaking into the next
# invocation which would otherwise inflate the aggregated FLOP/MAC
# counts when `start_profile` is called repeatedly.
module_flop_count.clear()
module_mac_count.clear()

def get_param_count_and_ep(param):
"""
Return the number of parameters in the layer, whether the layer is an MoE layer,
Expand Down Expand Up @@ -227,6 +237,8 @@ def remove_profile_attrs(module):
del module.__duration__

self.model.apply(remove_profile_attrs)
module_flop_count.clear()
module_mac_count.clear()
logger.info("Flops profiler finished")

def get_total_flops(self, as_string=False):
Expand Down
6 changes: 6 additions & 0 deletions docs/_tutorials/flops-profiler.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ In the summary profile, the DeepSpeed Flops Profiler outputs the number of param

The DeepSpeed Flops Profiler also measures significant modules at different model depths (aggregated profile) and module-specific profile in the model architecture (detailed profile). Using these profiles, DeepSpeed users can understand how each layer or submodule contributes to the overall model complexity/performance. Then users can adjust or refactor the model design to improve performance. For example, using the profiler, DeepSpeed users can quantitatively tell if stacking smaller layers is lighter or more performant than having bigger ones. The aggregated and detailed profiles also allow users to quickly identify bottleneck modules. In the BERT-Large example above, using the DeepSpeed Flops Profiler, we find that BertLayer is the most significant layer and contains quite a few dropout, softmax, and layer norm along with linear modules. These modules are not heavy in flops and would trigger many GPU kernel invocations and create excessive read/write requests to memory. The pattern shown in the detailed profile suggests this is a perfect match for kernel fusion, and we developed fused transformer-kernels to reduce data movement (see [DeepSpeedBert](/tutorials/bert-pretraining)). After applying our optimizations, we see a 25% improvement in FLOPS per GPU and overall training samples/second in the DeepSpeed Flops Profiler output.

### Resolving stale FLOP/MAC counters between runs

Prior to DeepSpeed 0.15.4, repeatedly calling `start_profile()`/`stop_profile()` inside a single Python process could dramatically over-report the FLOP/MAC totals. The profiler relies on two global collectors (`module_flop_count` and `module_mac_count`) that receive contributions from monkey-patched functional wrappers. When profiling stopped, some autograd-generated graphs could still be executing and continue appending entries to those collectors even though the main module hooks had already been removed. The next call to `start_profile()` therefore began with the collectors pre-populated, causing the first modules to absorb a large amount of stale statistics and inflating the aggregated totals exposed through `get_total_flops()` or `get_total_macs()`.

The fix clears both lists whenever a profiling session is reset or fully ended. Each run now starts from a clean state, so the global collectors only contain the operations triggered within the current `start_profile()`/`stop_profile()` block. If you previously observed the totals growing across identical inputs, upgrade to the latest DeepSpeed release to pick up this fix.

The DeepSpeed Flops Profiler can be used with the DeepSpeed runtime without any user code change or be used independently from DeepSpeed as a standalone package. When using DeepSpeed for model training, the profiler can be enabled in the DeepSpeed [configuration file](/docs/config-json/#flops-profiler). As a standalone package, the profiler API can be used in both training and inference code. The DeepSpeed profiler is still under active development and includes just initial features. Stay connected for more exciting features to be added soon.

## Flops Measurement
Expand Down
Loading