loader: freeze MPS graph cache after warmup for graph DataLoaders #10689

Open
anagnorisis2peripeteia wants to merge 4 commits into pyg-team:master from anagnorisis2peripeteia:mps-graph-cache-policy

Conversation


@anagnorisis2peripeteia anagnorisis2peripeteia commented May 6, 2026

Summary

Depends on: pytorch/pytorch#182648

Related issues: pytorch/pytorch#77753, pytorch/pytorch#164299 (MPS graph-cache OOM reports)

The problem

PyG's Batch.from_data_list() concatenates variable-size graphs, so the MPSGraph shape key for a batch is (sum_nodes, sum_edges): a 2-D combinatorial space. With shuffled DataLoaders (required for SGD convergence), batch compositions effectively never repeat, so every batch triggers a new MPSGraph compilation that is cached indefinitely:

P(repeat) ≈ 1/C(N, batch_size) ≈ 0

A 100-epoch training run on ZINC (10k molecules, BS=32) produces ~31,000 unique batch shapes. At ~55 KB/entry that is ~1.7 GB of graph cache that is never reused — a known source of OOM on Apple Silicon (pytorch/pytorch#77753, pytorch/pytorch#164299).
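You can see the scale of the problem without any MPS hardware. The standalone sketch below counts unique (sum_nodes, sum_edges) keys under shuffling; the node/edge distributions are synthetic ZINC-like stand-ins (an assumption, for illustration only):

import random

random.seed(0)
n_graphs, batch_size, epochs = 10_000, 32, 100
# Synthetic per-graph statistics, roughly ZINC-like (assumption, not real data).
nodes = [random.randint(10, 38) for _ in range(n_graphs)]
edges = [2 * n + random.randint(0, 6) for n in nodes]

seen = set()
order = list(range(n_graphs))
for _ in range(epochs):
    random.shuffle(order)
    for i in range(0, n_graphs, batch_size):
        idx = order[i:i + batch_size]
        # The MPSGraph shape key for the concatenated batch:
        seen.add((sum(nodes[j] for j in idx), sum(edges[j] for j in idx)))

# With these synthetic stats the count comes out on the order of 10^4,
# i.e. almost every batch lands on a fresh key, as P(repeat) above predicts.
print(f"unique (sum_nodes, sum_edges) keys after {epochs} epochs: {len(seen)}")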

The fix

After a short warmup that populates the cache with the most common shapes, freeze_graph_cache() makes the cache read-only. Hits continue to be served; rare shapes compile on the fly and are immediately discarded. No unbounded growth.

Warmup formula (derived from CLT on the 2-D shape distribution):

eff_2d    = sqrt(BS) * std(node_counts) * sqrt(BS) * std(edge_counts) * 2π
freeze_at = min(total_iters // 4, max(5, int(sqrt(eff_2d))))

Computed from actual dataset statistics in DataLoader.__init__. Not hand-tuned.
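For reference, a minimal sketch of how the _mps_freeze_at() helper (named in the change list below) could compute this. The body is my reconstruction from the formula above, not the PR's actual code:

import math
from statistics import pstdev

def _mps_freeze_at(node_counts, edge_counts, batch_size, total_iters):
    # Returns None when freezing would be pointless (uniform graph sizes),
    # matching the std == 0 no-op condition listed below.
    std_n, std_e = pstdev(node_counts), pstdev(edge_counts)
    if std_n == 0 or std_e == 0:
        return None
    # CLT: the sum over a batch spreads by sqrt(BS) * std in each dimension;
    # 2*pi scales the product into an effective 2-D shape-space area.
    eff_2d = math.sqrt(batch_size) * std_n * math.sqrt(batch_size) * std_e * 2 * math.pi
    return min(total_iters // 4, max(5, int(math.sqrt(eff_2d))))

For the ZINC numbers reported below (eff_2d = 9,536 over 200 iterations), this yields min(50, max(5, 97)) = 50, matching freeze_at = 50.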

Usage

Zero user changes required — the freeze happens automatically:

from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)

for epoch in range(100):
    for batch in loader:          # freeze_graph_cache() fires at iter freeze_at, first epoch only
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
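If you want to trigger the freeze yourself (e.g. after a custom warmup), the guarded call below is the manual equivalent. Note that torch.mps.freeze_graph_cache only exists once pytorch/pytorch#182648 lands, hence the hasattr check:

import torch

# No-op on non-MPS machines and on PyTorch builds without the upstream API.
if torch.backends.mps.is_available() and hasattr(torch.mps, "freeze_graph_cache"):
    torch.mps.freeze_graph_cache()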

Benchmark results

MacBook Pro M1 Max, PyTorch 2.13.0a0, ZINC subset (12k molecules, BS=32).
Training loop (forward + backward + optimizer.step).
eff_2d=9,536, freeze_at=50 — computed from the ZINC node/edge distribution.

Strategy               ΔRSS MB   ms/iter   note
────────────────────────────────────────────────────────────────
always (baseline)      +1370.7    43.2ms   → ~1.7 GB / 100 epochs
freeze_after_warmup       +0.0    69.6ms   (1.61×, flat)
clear_per_iter (prev)     +0.0   112.0ms   (2.59×, flat)
never                     +0.0   143.7ms   (3.33×, flat)

freeze_after_warmup is 38% faster than clearing after each iteration and 52% faster than never-cache. The clear_per_iter overhead comes from destroying forward→backward graph reuse every iteration; freeze preserves compiled graphs for warmup-seen shapes through the full forward+backward cycle.

Reproduce with: python3 benchmark/mps_graph_cache.py
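The ΔRSS column can be measured along these lines (a sketch; the actual benchmark script may instrument differently):

import resource
import sys

def rss_mb() -> float:
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux.
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return peak / 2**20 if sys.platform == "darwin" else peak / 2**10

before = rss_mb()
# ... run 200 training iterations under the chosen strategy ...
# Peak RSS is a reasonable proxy here because the graph cache only grows.
print(f"ΔRSS: {rss_mb() - before:+.1f} MB")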

Changes

  • torch_geometric/loader/dataloader.py: add _mps_freeze_at() helper and override DataLoader.__iter__ to call torch.mps.freeze_graph_cache() at the computed warmup step
  • benchmark/mps_graph_cache.py: self-contained ZINC training benchmark reproducing the above results

The feature is a strict no-op when any of the following hold (see the guard sketch after this list):

  • MPS is unavailable (non-Apple hardware)
  • torch.mps.freeze_graph_cache is absent (PyTorch < 2.13)
  • the dataset does not expose num_nodes / num_edges
  • node or edge counts are uniform (std == 0, caching is already optimal)
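Putting those conditions together, the guard inside the __iter__ override might look roughly like this. This is a sketch reconstructed from the description above; attribute names such as _freeze_at are assumptions, not the PR's actual code:

import torch

class DataLoader(torch.utils.data.DataLoader):  # simplified stand-in for PyG's loader
    def __iter__(self):
        mps_mod = getattr(torch, "mps", None)
        freeze_fn = getattr(mps_mod, "freeze_graph_cache", None)  # absent pre-2.13
        armed = (
            torch.backends.mps.is_available()
            and freeze_fn is not None
            and getattr(self, "_freeze_at", None) is not None  # None when std == 0
        )
        for i, batch in enumerate(super().__iter__()):
            yield batch
            if armed and i + 1 == self._freeze_at:
                freeze_fn()             # cache becomes read-only from here on
                self._freeze_at = None  # fires once: first epoch only
                armed = False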

Test Plan

  • ZINC training on MPS: always → +1370 MB RSS over 200 iters; freeze_after_warmup → +0.0 MB RSS (flat)
  • Performance: freeze is 1.61× vs always baseline (vs 2.59× for clear_per_iter)
  • Benchmark: benchmark/mps_graph_cache.py reproduces these results
  • CI: existing DataLoader tests pass (freeze is a no-op until PyTorch 2.13 merges)

cc @rusty1s @mananshah99 @akihironitta

@anagnorisis2peripeteia marked this pull request as draft May 6, 2026 14:19
@anagnorisis2peripeteia changed the title from "Auto-set MPS graph cache policy to never for shuffled DataLoaders" to "loader: freeze MPS graph cache after warmup for graph DataLoaders" May 6, 2026
@anagnorisis2peripeteia marked this pull request as ready for review May 6, 2026 23:36

PyG's Batch.from_data_list() concatenates variable-size graphs into a
single tensor keyed by (sum_nodes, sum_edges).  With a shuffled DataLoader
every batch produces a unique shape, so MPSGraph compiles a new computation
graph for each batch and caches it indefinitely.  On a 10k-molecule dataset
with batch_size=32 over 100 epochs this produces ~31,000 unique shapes at
~55 KB each — ~1.7 GB of graph cache that is never reused.

DataLoader now calls torch.mps.freeze_graph_cache() automatically after a
short warmup phase.  The warmup length is derived from the dataset's node
and edge count distributions using the effective 2-D shape space formula:

  eff_2d    = sqrt(BS) * std(node_counts) * sqrt(BS) * std(edge_counts) * 2π
  freeze_at = max(5, int(sqrt(eff_2d)))

After freeze the cache is read-only: hits are served from compiled graphs,
rare shapes compile on the fly and are immediately discarded; no unbounded
growth.

Benchmark (M1 Max, PyTorch 2.13, ZINC subset, BS=32, training loop):
  always (baseline)    +1370 MB RSS   43 ms/iter  → ~1.7 GB/100 epochs
  freeze_after_warmup     +0 MB RSS   70 ms/iter  (1.61x, flat)
  clear_per_iter (prev)   +0 MB RSS  112 ms/iter  (2.59x, flat)
  never                   +0 MB RSS  144 ms/iter  (3.33x, flat)

freeze_after_warmup is 38% faster than clearing after each iteration and
52% faster than never-cache, because it preserves compiled graphs for shapes
seen during warmup through the full forward+backward+step cycle.

The feature is a no-op when:
- MPS is unavailable
- torch.mps.freeze_graph_cache is absent (PyTorch < 2.13)
- the dataset does not expose num_nodes / num_edges statistics
- node or edge counts are uniform (std == 0, caching already optimal)

Depends on: pytorch/pytorch#182648