Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions python/mlx/_stub_patterns.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
mlx.core.__prefix__:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import sys
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

mlx.core.__suffix__:
from typing import Union
scalar: TypeAlias = Union[int, float, bool]
bool_: Dtype = ...

mlx.core.distributed.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from mlx.core.distributed import Group
from typing import Sequence, Optional, Union

mlx.core.fast.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union

mlx.core.linalg.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Tuple, Union

mlx.core.metal.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union

mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union
4 changes: 0 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,6 @@ def run(self) -> None:
# Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")


class MLXBdistWheel(bdist_wheel):
Expand Down