
Fix compile multi capture #2678

Merged
awni merged 2 commits into main from fix_compiel_multi_capture
Nov 3, 2025


awni (Member) commented Oct 16, 2025

A possible solution to #2674.

Basically, deep copy the graph before applying optimizations in compile_simplify. I would like to measure the overhead from this before merging it.

I don't see any noticeable difference in compile times on the MNIST 100-layer benchmark.
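To illustrate why copying first helps, here is a minimal pure-Python sketch, not the MLX implementation. The `Node` class, `simplify_in_place`, and `toy_compile` are hypothetical stand-ins: a toy optimization pass mutates nodes in place, and two graphs share one captured subgraph. Without a copy, compiling the first graph changes what the second one sees; with a deep copy, the original stays untouched.

```python
import copy


class Node:
    """Toy graph node (hypothetical): an op name plus its input nodes."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)


def simplify_in_place(root):
    """Toy optimization pass: rewrites 'mul_by_one' nodes to 'identity' in place."""
    stack, seen = [root], set()
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if node.op == "mul_by_one":
            node.op = "identity"
        stack.extend(node.inputs)
    return root


def toy_compile(root, copy_first):
    """Optionally deep copy the graph, then run the in-place pass on it."""
    graph = copy.deepcopy(root) if copy_first else root
    return simplify_in_place(graph)


def shared_capture():
    """Build two graphs that both reference the same captured subgraph."""
    captured = Node("mul_by_one", [Node("constant")])
    f_graph = Node("add", [Node("input"), captured])
    g_graph = Node("sub", [Node("input"), captured])
    return captured, f_graph, g_graph


# Without a copy: optimizing f's graph also rewrites the node g depends on.
captured, f_graph, g_graph = shared_capture()
toy_compile(f_graph, copy_first=False)
op_seen_by_g_without_copy = g_graph.inputs[1].op

# With a copy: only the private copy is rewritten; g's view is unchanged.
captured, f_graph, g_graph = shared_capture()
compiled_f = toy_compile(f_graph, copy_first=True)
op_seen_by_g_with_copy = g_graph.inputs[1].op
```

The cost of this approach is the copy itself, which is the overhead being measured above.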

awni force-pushed the fix_compiel_multi_capture branch from e5ea030 to 520e9cb, October 16, 2025 00:10
awni changed the title from "[WIP Fix compile multi capture" to "[WIP] Fix compile multi capture", Oct 16, 2025
awni force-pushed the fix_compiel_multi_capture branch from 520e9cb to c473719, October 16, 2025 14:27
awni force-pushed the fix_compiel_multi_capture branch from 017da68 to f8b6f8a, October 16, 2025 14:41
awni changed the title from "[WIP] Fix compile multi capture" to "Fix compile multi capture", Oct 16, 2025
awni (Member, Author) commented Oct 16, 2025

This should be working, and there is minor to no perf penalty, so I'll move it out of draft. It's not a super elegant solution, but I don't see a better way. Open to ideas.

awni marked this pull request as ready for review October 16, 2025 14:43

angeloskath (Member) left a comment

Just to be clear, assuming simplify and fuse are disabled, previously we would reuse the computation of b since both compiles would be using the same graph but now the computation will be done twice.

Personally, I think this is completely fine and a pretty good solution actually. If a user is that conscious about the evaluation of these constants, they need only evaluate them manually before compiling and the computation will be shared.

A minor caveat, where one would need to pre-evaluate some constants, is code of the following form:

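import mlx.core as mx

# c is a large constant (4 GiB once evaluated as float32) captured by every compiled lambda
c = mx.ones((1024, 1024, 1024)) * 4
fs = [mx.compile(lambda x: x * c * i) for i in range(10)]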

which would keep a copy of c per function in fs (after it is called once) while previously it wouldn't.

Let me know if I am missing something.

awni (Member, Author) commented Oct 28, 2025

> previously we would reuse the computation of b since both compiles would be using the same graph but now the computation will be done twice.

I don't think that's the case. Previously we would still recompute the captured part of the graph for both because when you actually evaluate the compiled tape there is no short-circuit for already evaluated subparts of the graph (maybe there should be?).

> which would keep a copy of c per function in fs (after it is called once) while previously it wouldn't.

I think for your example it would also recompute c per function. So we already had this sub-optimal behavior before this change.

In Python we can attempt to fix it by inspecting the closure and treating captured arrays as implicit inputs. (I had something implemented like that a while back if you recall, but it gets a bit messy). In C++ I don't think we can distinguish between constants created inside the compiled function vs captured inputs without some serious shenanigans.
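As a rough illustration of the closure-inspection idea, here is a minimal sketch that finds captured arrays using only standard Python function attributes (`__closure__` and `__code__.co_freevars`). It is not the earlier implementation mentioned above. The `Array` class is a hypothetical stand-in for an array type such as `mx.array`, and `captured_arrays` is an illustrative helper name.

```python
class Array:
    """Hypothetical stand-in for an array type such as mx.array."""

    def __init__(self, name):
        self.name = name


def captured_arrays(fn, array_type=Array):
    """Return {free variable name: object} for closure cells holding array_type."""
    cells = fn.__closure__ or ()
    found = {}
    # co_freevars and __closure__ are parallel: one name per cell.
    for name, cell in zip(fn.__code__.co_freevars, cells):
        try:
            value = cell.cell_contents
        except ValueError:
            # The cell is empty (the variable has not been assigned yet).
            continue
        if isinstance(value, array_type):
            found[name] = value
    return found


def make_fn():
    c = Array("c")   # captured array: a candidate implicit input
    scale = 2.0      # captured non-array: ignored by the helper

    def fn(x):
        return (x, c, scale)

    return fn


fn = make_fn()
implicit_inputs = captured_arrays(fn)
```

This only covers variables captured from an enclosing function scope. Module-level globals are not stored in `__closure__`, and arrays nested inside captured containers or objects would need extra traversal, which hints at why this approach can get messy.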

@awni awni merged commit 93d76b0 into main Nov 3, 2025
6 checks passed
@awni awni deleted the fix_compiel_multi_capture branch November 3, 2025 14:33
