Skip to content

Commit 4d0a951

Browse files
committed
Use entry points to locate backends
Use entry points to locate Triton backends instead of iterating over the `triton.backends` package directory. This fixes support for nontrivial install layouts, particularly editable installs made using `package_dir`. Now `setup.py` keeps record of all installed backends in entry points, and `triton.backends` uses them to locate and load the backends. This also technically permits installing third-party backends without altering the Triton installation, by adding additional entry points.
1 parent 8afcacd commit 4d0a951

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

python/triton/backends/__init__.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import importlib
2-
import inspect
3-
import os
2+
import sys
43
from dataclasses import dataclass
54
from .driver import DriverBase
65
from .compiler import BaseBackend
76

7+
if sys.version_info >= (3, 10):
8+
from importlib.metadata import entry_points
9+
else:
10+
from importlib_metadata import entry_points
11+
812

913
def _find_concrete_subclasses(module, base_class):
1014
ret = []
@@ -27,16 +31,11 @@ class Backend:
2731

2832
def _discover_backends():
2933
backends = dict()
30-
root = os.path.dirname(__file__)
31-
for name in os.listdir(root):
32-
if not os.path.isdir(os.path.join(root, name)):
33-
continue
34-
if name.startswith('__'):
35-
continue
36-
compiler = importlib.import_module(f"triton.backends.{name}.compiler")
37-
driver = importlib.import_module(f"triton.backends.{name}.driver")
38-
backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
39-
_find_concrete_subclasses(driver, DriverBase))
34+
for ep in entry_points().select(group="triton.backends"):
35+
compiler = importlib.import_module(f"{ep.value}.compiler")
36+
driver = importlib.import_module(f"{ep.value}.driver")
37+
backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
38+
_find_concrete_subclasses(driver, DriverBase))
4039
return backends
4140

4241

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ def get_entry_points():
598598
"proton-viewer = triton.profiler.viewer:main",
599599
"proton = triton.profiler.proton:main",
600600
]
601+
entry_points["triton.backends"] = [f"{b.name} = triton.backends.{b.name}" for b in backends]
601602
return entry_points
602603

603604

@@ -638,7 +639,10 @@ def get_git_version_suffix():
638639
author_email="[email protected]",
639640
description="A language and compiler for custom Deep Learning operations",
640641
long_description="",
641-
install_requires=["setuptools>=40.8.0"],
642+
install_requires=[
643+
"setuptools>=40.8.0",
644+
"importlib-metadata; python_version < '3.10'",
645+
],
642646
packages=find_packages(where="python") + extra_packages,
643647
package_dir=package_dirs,
644648
entry_points=get_entry_points(),

0 commit comments

Comments
 (0)