Export with callback#2612

Merged
awni merged 5 commits into main from
export_callback
Oct 9, 2025
Conversation

@awni
Member

@awni awni commented Sep 22, 2025

Basic usage:

import mlx.core as mx

def fn(x):
    return mx.log(mx.abs(x))

def callback(args):
    print(args)

mx.export_function(callback, fn, mx.ones((2, 2)))

Member

@angeloskath angeloskath left a comment

I think it looks absolutely fine.

Basically one can implement exactly export using a callback which means it's great. I guess the only thing to talk about is the use of the namer instead of uint64_t and perhaps the fact that serialize isn't actually implemented as a callback.

The main thing is that FunctionExporter currently has state, e.g. for saving the constants only once, which is not implementable with the callback as it is now (it would be, by providing uint64_t ids instead of names).

@awni
Member Author

awni commented Sep 23, 2025

> I guess the only thing to talk about is the use of the namer instead of uint64_t

Good point re constants. The namer could also be persistent on the FunctionExporter. So it's really a stylistic preference if you prefer character names or integer identifiers?

> the fact that serialize isn't actually implemented as a callback.

It's a good idea, but somewhat messy in practice. The export-with-callback path does a lot of work to make all these different types (shapes, dtypes, all the primitive state) fit into one type of data structure, which is very convenient for Python but basically wasted work for C++ that would need to be undone.

Even so it may be worth it to have a single code path for that.. I'll play around with it a bit.

@awni
Member Author

awni commented Sep 23, 2025

> Good point re constants. The namer could also be persistent on the FunctionExporter. So it's really a stylistic preference if you prefer character names or integer identifiers?

Another option (which I just implemented cause it is quite simple) is to only send in new constants when doing multiple graphs per FunctionExporter. It seems like what one would do anyway .. so it saves some work for the user.. but maybe it's not expected?

@awni
Member Author

awni commented Sep 23, 2025

> the fact that serialize isn't actually implemented as a callback.

Another thing that differs slightly between the callback path and the regular path is how we deal with kwarg inputs.

For the regular path we save the kwargs and when you call it in C++ you have to use a dictionary like {"x": x}

We might want to add kwargs to the callback path. And in general, I think we may need to do something in the callback case to ensure the inputs are well-ordered 🤔 . Like

def fun(a, b):
  return a - b

vs

def fun(a, b):
  return b - a

It's not clear from the exported context how the order of the function arguments maps to the graph inputs and to the primitive's inputs.
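
The ambiguity can be illustrated without MLX. The following is only a sketch: `evaluate_subtract` and the array names "A" and "B" are hypothetical, and the values are plain floats. It shows that the same two values produce different results depending on how graph inputs are wired into a subtract node, so a consumer has to rely on the graph-level input list following the call order.

```python
def evaluate_subtract(graph_inputs, subtract_inputs, values):
    """Evaluate a single subtract node.

    graph_inputs:    names of the graph inputs, in call order (hypothetical)
    subtract_inputs: names wired into the subtract node as (lhs, rhs)
    values:          concrete values, matched to graph_inputs by position
    """
    env = dict(zip(graph_inputs, values))
    lhs, rhs = (env[name] for name in subtract_inputs)
    return lhs - rhs


# fun(a, b) = a - b: the node consumes the inputs in call order.
a_minus_b = evaluate_subtract(["A", "B"], ["A", "B"], (5.0, 3.0))

# fun(a, b) = b - a: same graph inputs, but the node's operands are swapped.
b_minus_a = evaluate_subtract(["A", "B"], ["B", "A"], (5.0, 3.0))
```

Only the pairing of `graph_inputs` with `values` ties a caller's arguments to the graph, which is why the order of that list needs to be well defined.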

@awni
Member Author

awni commented Sep 24, 2025

I guess to add to the above, the order of the inputs sent to the callback will match the order in the function call.

def fn(x, y):
    return x - y

def callback(args):
    print(args)

mx.export_function(callback, fn, mx.ones((2, 2)), mx.ones((2, 2)))

So that would print:

{'inputs': [('A', (2, 2), mlx.core.float32), ('B', (2, 2), mlx.core.float32)]}
{'outputs': [('C', (2, 2), mlx.core.float32)]}
{'constants': []}
{'state': [], 'primitive': 'Subtract', 'outputs': [('C', (2, 2), mlx.core.float32)], 'inputs': [('A', (2, 2), mlx.core.float32), ('B', (2, 2), mlx.core.float32)]}

Here A and B are x and y respectively. It can get messy with kwargs though:

mx.export_function(callback, fn, y=mx.ones((2, 2)), x=mx.ones((2, 2)))

Then the order would be lost. So I think it does make sense to keep the non-kwarg inputs separate from the kwarg inputs to the callback.

CC @junpeiz in case you have some thoughts on that.

@gokulkrishna98

Hi, thanks for implementing the export function. I would like to make a couple of suggestions; please correct me if my understanding is wrong.

  1. Currently the state in the primitive callback is a positional list; it would be helpful if it were a dict with meaningful keys.
def test_fn(x):
    return mx.sum(x, axis=1, keepdims=True)

def callback(args):
    print(args)

x = mx.random.normal((1, 16, 16, 3))
mx.export_function(callback, test_fn, x)

Output:

{'inputs': [('A', (1, 16, 16, 3), mlx.core.float32)]}
{'outputs': [('B', (1, 1, 16, 3), mlx.core.float32)]}
{'constants': []}
{'state': [2, [1]], 'primitive': 'Reduce', 'outputs': [('B', (1, 1, 16, 3), mlx.core.float32)], 'inputs': [('A', (1, 16, 16, 3), mlx.core.float32)]} 
  2. In the context of keyword arguments, did you mean that there will be a constraint that the order of array info in the inputs callback always matches the order in the function definition (not the order in which the keyword arguments are passed)?

  3. Quality-of-life suggestion: would it be possible to introduce a key that describes which type of callback has occurred?
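
Until named state exists, a consumer could keep its own table of field names per primitive. The sketch below is an illustration under assumptions: `STATE_FIELDS`, `name_state`, and the field names "reduce_type" and "axes" are hypothetical guesses based on the Reduce output above, not something MLX provides, and the meaning of the integer 2 is not stated in this thread.

```python
# Hypothetical, hand-maintained table: primitive name -> names for its
# positional state entries. The field names are guesses, not an MLX API.
STATE_FIELDS = {
    "Reduce": ("reduce_type", "axes"),
}


def name_state(primitive, state):
    """Attach names to positional state when the primitive is known.

    Falls back to a single 'positional' entry for unknown primitives or
    when the number of entries does not match the table.
    """
    fields = STATE_FIELDS.get(primitive)
    if fields is None or len(fields) != len(state):
        return {"positional": list(state)}
    return dict(zip(fields, state))


# Values copied from the Reduce record in the output above.
record = {"state": [2, [1]], "primitive": "Reduce"}
named = name_state(record["primitive"], record["state"])
unknown = name_state("Subtract", [])
```

The fallback keeps the consumer working when it meets a primitive that is not in its table, at the cost of leaving that state uninterpreted.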

@awni
Member Author

awni commented Sep 24, 2025

  1. The state is an array. Adding names is possible in theory. This is kind of where the abstraction breaks because either way you have to know about MLX primitives (inspect them) to know what the arguments are. So to some extent that mapping has to be done regardless of if we include names or not. I didn't include names because doing it in a clean way is not so simple.. but I will chew on it a bit.

  2. The order of inputs matches the order they are passed to the function (regardless of keyword arguments or not). So if I call a function like f(x=x, y=y) then x will be first and y will be second. In Python you could call the same function like f(y=y, x=x) and then in the export the order would be y first. So in order to resolve that ambiguity I think it is necessary to make a distinction between keyword arguments and non-keyword arguments during export.

  3. Yes, that is possible.
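
The Python side of point 2 can be shown in isolation. This is a minimal sketch in plain Python, not MLX code: `call_order`, `bind`, and the `slots` table are hypothetical names. Python preserves the order in which keyword arguments are written at the call site, so an exporter that only records positions sees two different orders for the same function, while a name-keyed table does not.

```python
def call_order(**kwargs):
    """Return keyword names in the order the caller wrote them."""
    return list(kwargs)


def bind(slots, **kwargs):
    """Place keyword values into positions using a name -> position table."""
    values = [None] * len(slots)
    for name, value in kwargs.items():
        values[slots[name]] = value
    return values


# The same call spelled two ways yields two different orders.
order_xy = call_order(x=2.0, y=3.0)
order_yx = call_order(y=3.0, x=2.0)

# Hypothetical table recorded once at export time.
slots = {"x": 0, "y": 1}
bound_xy = bind(slots, x=2.0, y=3.0)
bound_yx = bind(slots, y=3.0, x=2.0)
```

Both `bind` calls produce the same positional list, which is the property a separate record for keyword inputs would give an importer.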

@gokulkrishna98

gokulkrishna98 commented Sep 24, 2025

For the second point, I ran the following experiment. The default implementation preserves the order in the function definition (in the export, x is always first and y is second).

import mlx.core as mx

def test_fn(x, y):
    return x - y

def callback(args):
    print(args)

x = mx.random.normal((1, 16, 16, 3))
y = mx.random.normal((1,))

print("---- default call ----")
mx.export_function(callback, test_fn, x, y)
print("---- keyword call ---- y=y, x=x")
mx.export_function(callback, test_fn, y=y, x=x)

Output:

---- default call ----
{'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('B', (1,), mlx.core.float32)]}
{'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)]}
{'constants': []}
{'state': [(1, 16, 16, 3)], 'primitive': 'Broadcast', 'outputs': [('D', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('B', (1,), mlx.core.float32)]}
{'state': [], 'primitive': 'Subtract', 'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('D', (1, 16, 16, 3), mlx.core.float32)]}
---- keyword call ---- y=y, x=x
{'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('B', (1,), mlx.core.float32)]}
{'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)]}
{'constants': []}
{'state': [(1, 16, 16, 3)], 'primitive': 'Broadcast', 'outputs': [('D', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('B', (1,), mlx.core.float32)]}
{'state': [], 'primitive': 'Subtract', 'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('D', (1, 16, 16, 3), mlx.core.float32)]}
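
In the trace above, every primitive appears after the records that define its inputs (Broadcast produces D before Subtract consumes it). The thread does not say whether that order is guaranteed, so a consumer may prefer to check it. The sketch below does that in plain Python; `check_defined_before_use` is a hypothetical helper, the records use the key names shown in this output, and dtypes are written as strings so the snippet runs without MLX.

```python
def check_defined_before_use(graph_inputs, primitives):
    """Verify each primitive only reads names that are already defined.

    Returns the set of all defined array names, or raises ValueError at
    the first primitive that reads an undefined name.
    """
    defined = set(graph_inputs)
    for prim in primitives:
        missing = [name for name, *_ in prim["inputs"] if name not in defined]
        if missing:
            raise ValueError(f"{prim['primitive']} reads undefined {missing}")
        defined.update(name for name, *_ in prim["outputs"])
    return defined


big = (1, 16, 16, 3)
# Primitive records copied from the default-call output above.
primitives = [
    {"state": [big], "primitive": "Broadcast",
     "outputs": [("D", big, "float32")],
     "inputs": [("B", (1,), "float32")]},
    {"state": [], "primitive": "Subtract",
     "outputs": [("C", big, "float32")],
     "inputs": [("A", big, "float32"), ("D", big, "float32")]},
]

defined = check_defined_before_use(["A", "B"], primitives)
```

If the records arrived in the opposite order, the check would fail on Subtract because D would not be defined yet.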

@awni
Member Author

awni commented Sep 24, 2025

No, the point is about keyword arguments, not positional arguments. Positional arguments will have their order preserved as expected. And actually it uncovered a bug in the current implementation.

Here is a test I just added which currently fails:

        def fn(x, y):
            return x - y
        
        mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
        imported = mx.import_function(path)
        out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
        self.assertEqual(out.item(), -1.0)
        out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
        self.assertEqual(out.item(), 1.0)

@awni
Member Author

awni commented Sep 24, 2025

OK, regarding 2 and 3, I made a couple of updates. Now it includes a type field for every callback call. There are 5 types: inputs, outputs, constants, primitive, and now keyword_inputs. The last one contains a list which maps the keyword names to the array names in the graph.

You can export a function with kwargs to test it out:

import mlx.core as mx

def fn(x, y):
    return x - y

def callback(args):
    print(args)

mx.export_function(callback, fn, x=mx.array(1.0), y=mx.array(1.0))

You should get the following:

{'inputs': [('A', (), mlx.core.float32), ('B', (), mlx.core.float32)], 'type': 'inputs'}
{'keywords': [('x', 'A'), ('y', 'B')], 'type': 'keyword_inputs'}
{'outputs': [('C', (), mlx.core.float32)], 'type': 'outputs'}
{'constants': [], 'type': 'constants'}
{'arguments': [], 'outputs': [('C', (), mlx.core.float32)], 'name': 'Subtract', 'inputs': [('A', (), mlx.core.float32), ('B', (), mlx.core.float32)], 'type': 'primitive'}
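
A consumer of these records might dispatch on the `type` field. The following is one hypothetical way to do that in plain Python: `GraphCollector` and its attribute names are illustrative and not part of MLX, and dtypes are written as strings so the snippet runs without MLX installed. The records are copied from the output above.

```python
class GraphCollector:
    """Hypothetical consumer that sorts callback records by their 'type'."""

    def __init__(self):
        self.inputs = []       # graph input names, in order
        self.keywords = {}     # keyword name -> graph array name
        self.outputs = []      # graph output names, in order
        self.constants = []    # constant entries, stored as received
        self.primitives = []   # (name, input names, output names, arguments)

    def __call__(self, record):
        kind = record["type"]
        if kind == "inputs":
            self.inputs = [name for name, *_ in record["inputs"]]
        elif kind == "keyword_inputs":
            self.keywords = dict(record["keywords"])
        elif kind == "outputs":
            self.outputs = [name for name, *_ in record["outputs"]]
        elif kind == "constants":
            self.constants.extend(record["constants"])
        elif kind == "primitive":
            self.primitives.append((
                record["name"],
                [name for name, *_ in record["inputs"]],
                [name for name, *_ in record["outputs"]],
                record["arguments"],
            ))
        else:
            raise ValueError(f"unexpected record type: {kind!r}")


f32 = "float32"  # stand-in for mlx.core.float32
records = [
    {"inputs": [("A", (), f32), ("B", (), f32)], "type": "inputs"},
    {"keywords": [("x", "A"), ("y", "B")], "type": "keyword_inputs"},
    {"outputs": [("C", (), f32)], "type": "outputs"},
    {"constants": [], "type": "constants"},
    {"arguments": [], "outputs": [("C", (), f32)], "name": "Subtract",
     "inputs": [("A", (), f32), ("B", (), f32)], "type": "primitive"},
]

collector = GraphCollector()
for record in records:
    collector(record)
```

Since the collector is a plain callable, it could presumably be passed to `mx.export_function` in place of the printing callback, though that has not been verified here.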

@gokulkrishna98

Thanks, this looks cool !

@awni awni marked this pull request as ready for review October 8, 2025 22:46
@awni
Member Author

awni commented Oct 8, 2025

I think this is ok to merge. The interface between the callback and export may need to change based on feedback, but so far I think it is good enough to land.

@awni awni requested a review from angeloskath October 8, 2025 22:47
Member

@angeloskath angeloskath left a comment

👍

@awni awni merged commit e89e8b4 into main Oct 9, 2025
6 checks passed
@awni awni deleted the export_callback branch October 9, 2025 02:24
faisalmemon pushed a commit to faisalmemon/mlx that referenced this pull request Oct 30, 2025
* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
3 participants