Force-pushed from e6e1291 to 1dd2c63.
angeloskath
left a comment
I think it looks absolutely fine.
Basically, one can implement export exactly using a callback, which is great. I guess the only things to talk about are the use of the namer instead of uint64_t and perhaps the fact that serialize isn't actually implemented as a callback.
The main thing would be that FunctionExporter currently has state, e.g. for saving the constants only once, which is not implementable with the callback as it is now (it would be if we provided uint64_t ids instead of names).
Good point re constants. The namer could also be persistent on the
It's a good idea, but somewhat messy in practice. The export-with-callback path does a lot of work to make all these different types (shapes, dtypes, all the primitive state) fit into one type of data structure, which is very convenient for Python but basically wasted work for C++, where it would need to be undone. Even so, it may be worth it to have a single code path for that. I'll play around with it a bit.
Force-pushed from 1dd2c63 to 69e3f65.
Another option (which I just implemented because it is quite simple) is to only send in new constants when doing multiple graphs per FunctionExporter. It seems like what one would do anyway, so it saves some work for the user, but maybe it's not expected?
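On the consumer side, sending only new constants means a callback shared across several graphs has to accumulate them itself. A minimal sketch in plain Python, not taken from the PR: `ConstantStore` is a hypothetical name, and the `(name, value)` layout of each `constants` entry is an assumption, since the thread only ever shows an empty list.

```python
class ConstantStore:
    """Hypothetical callback that remembers constants across several graphs."""

    def __init__(self):
        self.constants = {}  # name -> value, filled in as records arrive

    def __call__(self, record):
        if "constants" in record:
            # Assumed layout: a list of (name, value) pairs.
            for name, value in record["constants"]:
                # Keep the first value seen for a name; later graphs are
                # expected to send only constants that were not sent before.
                self.constants.setdefault(name, value)


store = ConstantStore()
# First graph: both constants are new, so both are sent.
store({"constants": [("C0", 1.0), ("C1", 2.0)]})
# Second graph: only the constant that was not sent before arrives.
store({"constants": [("C2", 3.0)]})
print(sorted(store.constants))
```

A consumer that treats each graph independently would see an incomplete constant list for the second graph, which is probably the "maybe it's not expected" concern.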
Force-pushed from 69e3f65 to 9fcfcf0.
Another thing that differs slightly between the callback path and the regular path is how we deal with kwarg inputs. For the regular path we save the kwargs and when you call it in C++ you have to use a dictionary like We might want to add kwargs to the callback path. And in general, I think we may need to do something in the callback case to ensure the inputs are well-ordered 🤔 . Like vs It's not clear from the exported context how the order of the function arguments maps to the inputs / primitive's inputs.
I guess to add to the above, the order of the inputs sent to the callback will match the order in the function call. So that would print: And Then the order would be lost. So I think it does make sense to keep the non-kwarg inputs separate from the kwarg inputs to the callback. CC @junpeiz in case you have some thoughts on that.
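The ordering concern can be reproduced in plain Python without MLX. This is only an illustration, not how the exporter handles it: `call_order` and `split_inputs` are hypothetical helpers, and sorting keyword inputs by name is just one possible way to get a call-site-independent order.

```python
def call_order(**kwargs):
    # Python keeps keyword arguments in the order used at the call site.
    return list(kwargs)


def split_inputs(*args, **kwargs):
    # Positional inputs keep their order; keyword inputs are kept separately,
    # keyed by name and sorted so the result does not depend on the call site.
    return list(args), sorted(kwargs.items())


order_xy = call_order(x=1.0, y=2.0)
order_yx = call_order(y=2.0, x=1.0)
same = split_inputs(x=1.0, y=2.0) == split_inputs(y=2.0, x=1.0)
print(order_xy, order_yx, same)
```

The two `call_order` results differ even though the calls are semantically identical, which is why a flat input list is ambiguous for keyword calls.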
Hi, thanks for implementing the export function. I would like to make a couple of suggestions; please do correct me if my understanding is wrong.

import mlx.core as mx

def test_fn(x):
    return mx.sum(x, axis=1, keepdims=True)

def callback(args):
    print(args)

x = mx.random.normal((1, 16, 16, 3))
mx.export_function(callback, test_fn, x)

Output:
{'inputs': [('A', (1, 16, 16, 3), mlx.core.float32)]}
{'outputs': [('B', (1, 1, 16, 3), mlx.core.float32)]}
{'constants': []}
{'state': [2, [1]], 'primitive': 'Reduce', 'outputs': [('B', (1, 1, 16, 3), mlx.core.float32)], 'inputs': [('A', (1, 16, 16, 3), mlx.core.float32)]}
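A callback will usually want to do more than print. The following is a rough sketch of one that sorts the records into a single structure, based only on the record shapes visible in the output above. `make_collector` is a hypothetical helper, and the string `"float32"` stands in for `mlx.core.float32` so the snippet runs without MLX.

```python
def make_collector():
    """Return (graph, callback); the callback files each record into graph."""
    graph = {"inputs": [], "outputs": [], "constants": [], "primitives": []}

    def callback(record):
        # Primitive records also carry 'inputs' and 'outputs' keys,
        # so they have to be recognised first.
        if "primitive" in record:
            graph["primitives"].append(record)
            return
        for key in ("inputs", "outputs", "constants"):
            if key in record:
                graph[key].extend(record[key])

    return graph, callback


graph, callback = make_collector()
# Records mirroring the printed output above.
callback({"inputs": [("A", (1, 16, 16, 3), "float32")]})
callback({"outputs": [("B", (1, 1, 16, 3), "float32")]})
callback({"constants": []})
callback({
    "state": [2, [1]],
    "primitive": "Reduce",
    "outputs": [("B", (1, 1, 16, 3), "float32")],
    "inputs": [("A", (1, 16, 16, 3), "float32")],
})
print([p["primitive"] for p in graph["primitives"]])
```

With the real API the same `callback` would presumably be passed as the first argument to `mx.export_function`, as in the example above.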
For the second point, I performed the following experiment. The default implementation preserves the order in the function definition (in the export, x is always first and y is second).

def test_fn(x, y):
    return x - y

def callback(args):
    print(args)

x = mx.random.normal((1, 16, 16, 3))
y = mx.random.normal((1,))
print("---- default call ----")
mx.export_function(callback, test_fn, x, y)
print("---- keyword call ---- y=y, x=x")
mx.export_function(callback, test_fn, y=y, x=x)

Output:
---- default call ----
{'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('B', (1,), mlx.core.float32)]}
{'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)]}
{'constants': []}
{'state': [(1, 16, 16, 3)], 'primitive': 'Broadcast', 'outputs': [('D', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('B', (1,), mlx.core.float32)]}
{'state': [], 'primitive': 'Subtract', 'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('D', (1, 16, 16, 3), mlx.core.float32)]}
---- keyword call ---- y=y, x=x
{'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('B', (1,), mlx.core.float32)]}
{'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)]}
{'constants': []}
{'state': [(1, 16, 16, 3)], 'primitive': 'Broadcast', 'outputs': [('D', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('B', (1,), mlx.core.float32)]}
{'state': [], 'primitive': 'Subtract', 'outputs': [('C', (1, 16, 16, 3), mlx.core.float32)], 'inputs': [('A', (1, 16, 16, 3), mlx.core.float32), ('D', (1, 16, 16, 3), mlx.core.float32)]}
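One thing the output above suggests is that primitives arrive in an order where every input has already been defined, either as a function input or as the output of an earlier primitive. Below is a small sketch that checks this property on records shaped like the ones printed. `primitive_order` is a hypothetical helper, `"float32"` stands in for `mlx.core.float32`, and treating the first element of a constant entry as its name is an assumption.

```python
def primitive_order(records):
    """Return primitive names in emission order, checking names are defined."""
    defined = set()
    declared_outputs = []
    order = []
    for record in records:
        if "primitive" in record:
            for name, _shape, _dtype in record["inputs"]:
                if name not in defined:
                    raise ValueError(
                        f"{record['primitive']} uses undefined input {name}"
                    )
            defined.update(name for name, _s, _d in record["outputs"])
            order.append(record["primitive"])
        elif "inputs" in record:
            defined.update(name for name, _s, _d in record["inputs"])
        elif "constants" in record:
            # Assumption: the first element of each entry is its name.
            defined.update(entry[0] for entry in record["constants"])
        elif "outputs" in record:
            declared_outputs.extend(name for name, _s, _d in record["outputs"])
    missing = [name for name in declared_outputs if name not in defined]
    if missing:
        raise ValueError(f"outputs never produced: {missing}")
    return order


big, small = (1, 16, 16, 3), (1,)
records = [
    {"inputs": [("A", big, "float32"), ("B", small, "float32")]},
    {"outputs": [("C", big, "float32")]},
    {"constants": []},
    {"state": [big], "primitive": "Broadcast",
     "outputs": [("D", big, "float32")], "inputs": [("B", small, "float32")]},
    {"state": [], "primitive": "Subtract",
     "outputs": [("C", big, "float32")],
     "inputs": [("A", big, "float32"), ("D", big, "float32")]},
]
print(primitive_order(records))
```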
No, the point is about keyword arguments, not positional arguments. Positional arguments will have their order preserved as expected. And actually it uncovered a bug in the current implementation. Here is a test I just added which currently fails:

def fn(x, y):
    return x - y

mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
imported = mx.import_function(path)
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
self.assertEqual(out.item(), -1.0)
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
self.assertEqual(out.item(), 1.0)
OK, regarding 2 and 3, I made a couple of updates. Now it includes a You can export a function with kwargs to test it out:

import mlx.core as mx

def fn(x, y):
    return x - y

def callback(args):
    print(args)

mx.export_function(callback, fn, x=mx.array(1.0), y=mx.array(1.0))

You should get the following:
Thanks, this looks cool!
I think this is OK to merge. The interface between the callback and export may need to change based on feedback, but so far I think it is good enough to land.
* export with callback
* export with callback
* Add types, fix kwarg ordering bug + test
* cleanup, test, fix
* typos
Basic usage: