loader: freeze MPS graph cache after warmup for graph DataLoaders#10689
Open
anagnorisis2peripeteia wants to merge 4 commits intopyg-team:masterfrom
Open
loader: freeze MPS graph cache after warmup for graph DataLoaders#10689anagnorisis2peripeteia wants to merge 4 commits intopyg-team:masterfrom
anagnorisis2peripeteia wants to merge 4 commits intopyg-team:masterfrom
Conversation
b04e683 to
0ac92dc
Compare
7736b41 to
a5b69c3
Compare
60da048 to
2f3baae
Compare
PyG's Batch.from_data_list() concatenates variable-size graphs into a single tensor keyed by (sum_nodes, sum_edges). With a shuffled DataLoader every batch produces a unique shape, so MPSGraph compiles a new computation graph for each batch and caches it indefinitely. On a 10k-molecule dataset with batch_size=32 over 100 epochs this produces ~31,000 unique shapes at ~55 KB each — ~1.7 GB of graph cache that is never reused. DataLoader now calls torch.mps.freeze_graph_cache() automatically after a short warmup phase. The warmup length is derived from the dataset's node and edge count distributions using the effective 2-D shape space formula: eff_2d = sqrt(BS) * std(node_counts) * sqrt(BS) * std(edge_counts) * 2π freeze_at = max(5, int(sqrt(eff_2d))) After freeze the cache is read-only: hits are served from compiled graphs, rare shapes compile on-the-fly and are immediately discarded — no unbounded growth. Benchmark (M1 Max, PyTorch 2.13, ZINC subset, BS=32, training loop): always (baseline) +1370 MB RSS 43 ms/iter → ~1.7 GB/100 epochs freeze_after_warmup +0 MB RSS 70 ms/iter (1.61x, flat) clear_per_iter (prev) +0 MB RSS 112 ms/iter (2.59x, flat) never +0 MB RSS 144 ms/iter (3.33x, flat) freeze_after_warmup is 38% faster than clearing after each iteration and 52% faster than never-cache, because it preserves compiled graphs for shapes seen during warmup through the full forward+backward+step cycle. The feature is a no-op when: - MPS is unavailable - torch.mps.freeze_graph_cache is absent (PyTorch < 2.13) - the dataset does not expose num_nodes / num_edges statistics - node or edge counts are uniform (std == 0, caching already optimal) Depends on: pytorch/pytorch#182648
c02b41d to
27baaf9
Compare
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Depends on: pytorch/pytorch#182648
Related issues:
The problem
PyG's
Batch.from_data_list()concatenates variable-size graphs: the shape key is(sum_nodes, sum_edges)— a 2-D combinatorial space. With shuffled DataLoaders (required for SGD convergence), batch compositions never repeat and every batch triggers a new MPSGraph compilation cached indefinitely:A 100-epoch training run on ZINC (10k molecules, BS=32) produces ~31,000 unique batch shapes. At ~55 KB/entry that is ~1.7 GB of graph cache that is never reused — a known source of OOM on Apple Silicon (pytorch/pytorch#77753, pytorch/pytorch#164299).
The fix
After a short warmup that populates the cache with the most common shapes,
freeze_graph_cache()makes the cache read-only. Hits continue to be served; rare shapes compile on-the-fly and are immediately discarded. No unbounded growth.Warmup formula (derived from CLT on the 2-D shape distribution):
Computed from actual dataset statistics in
DataLoader.__init__. Not hand-tuned.Usage
Zero user changes required — the freeze happens automatically:
Benchmark results
MacBook Pro M1 Max, PyTorch 2.13.0a0, ZINC subset (12k molecules, BS=32).
Training loop (forward + backward + optimizer.step).
eff_2d=9,536,freeze_at=50— computed from the ZINC node/edge distribution.freeze_after_warmupis 38% faster than clearing after each iteration and 52% faster than never-cache. Theclear_per_iteroverhead comes from destroying forward→backward graph reuse every iteration;freezepreserves compiled graphs for warmup-seen shapes through the full forward+backward cycle.Reproduce with:
python3 benchmark/mps_graph_cache.pyChanges
torch_geometric/loader/dataloader.py: add_mps_freeze_at()helper and overrideDataLoader.__iter__to calltorch.mps.freeze_graph_cache()at the computed warmup stepbenchmark/mps_graph_cache.py: self-contained ZINC training benchmark reproducing the above resultsThe feature is a strict no-op when:
torch.mps.freeze_graph_cacheis absent (PyTorch < 2.13)num_nodes/num_edgesTest Plan
always→ +1370 MB RSS over 200 iters;freeze_after_warmup→ +0.0 MB RSS (flat)freezeis 1.61× vsalwaysbaseline (vs 2.59× forclear_per_iter)benchmark/mps_graph_cache.pyreproduces these resultscc @rusty1s @mananshah99 @akihironitta