Skip to content

Commit bf2eea0

Browse files
authored
Ko3n1g/ci/fix dependency tree (#13448)
* ci: Fix deps tree for tests Signed-off-by: oliver könig <[email protected]> * fix Signed-off-by: oliver könig <[email protected]> * fix Signed-off-by: oliver könig <[email protected]> * ci: Small fixes Signed-off-by: oliver könig <[email protected]> * f Signed-off-by: oliver könig <[email protected]> * ci: Deps tree Signed-off-by: oliver könig <[email protected]> * f Signed-off-by: oliver könig <[email protected]> * ci: Solve __init__ imports Signed-off-by: oliver könig <[email protected]> --------- Signed-off-by: oliver könig <[email protected]>
1 parent d4da6b4 commit bf2eea0

File tree

1 file changed

+67
-3
lines changed

1 file changed

+67
-3
lines changed

.github/scripts/nemo_dependencies.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,53 @@ def find_python_files(directory: str) -> List[str]:
4444
def analyze_imports(nemo_root: str, file_path: str) -> Set[str]:
4545
"""Analyze a Python file and return its NeMo package dependencies using AST parsing."""
4646
imports = set()
47+
visited = set() # Track visited modules to prevent circular imports
48+
49+
def get_init_imports(module_path: str, depth: int = 0) -> Dict[str, str]:
50+
"""Recursively analyze imports from __init__.py files and map them to their final destinations."""
51+
# Prevent infinite recursion
52+
if depth > 10 or module_path in visited: # Limit depth to 10 levels
53+
return {}
54+
55+
visited.add(module_path)
56+
init_path = os.path.join(module_path, '__init__.py')
57+
if not os.path.exists(init_path):
58+
return {}
59+
60+
try:
61+
with open(init_path, 'r', encoding='utf-8') as f:
62+
init_tree = ast.parse(f.read(), filename=init_path)
63+
64+
import_map = {}
65+
for node in ast.walk(init_tree):
66+
if isinstance(node, ast.ImportFrom) and node.module and node.module.startswith('nemo.'):
67+
if node.names:
68+
for name in node.names:
69+
if name.name == '*':
70+
continue
71+
72+
# Get the full module path for the import
73+
module_parts = node.module.split('.')
74+
module_dir = os.path.join(nemo_root, *module_parts)
75+
76+
# If the imported module has an __init__.py, recursively analyze it
77+
if os.path.exists(os.path.join(module_dir, '__init__.py')):
78+
sub_imports = get_init_imports(module_dir, depth + 1)
79+
if name.name in sub_imports:
80+
import_map[name.name] = sub_imports[name.name]
81+
else:
82+
# If not found in sub-imports, it might be from the module itself
83+
module_file = os.path.join(module_dir, f"{module_parts[-1]}.py")
84+
if os.path.exists(module_file):
85+
import_map[name.name] = f"{node.module}.{name.name}"
86+
else:
87+
# Direct module import
88+
import_map[name.name] = f"{node.module}.{name.name}"
89+
90+
return import_map
91+
except Exception as e:
92+
print(f"Error analyzing {init_path}: {e}")
93+
return {}
4794

4895
try:
4996
with open(file_path, 'r', encoding='utf-8') as f:
@@ -68,14 +115,31 @@ def analyze_imports(nemo_root: str, file_path: str) -> Set[str]:
68115
if name.name == '*':
69116
continue
70117

71-
imports.add(f"{node.module}.{name.name}")
118+
# Check if this is an __init__ import
119+
module_path = os.path.join(nemo_root, *parts)
120+
init_imports = get_init_imports(module_path)
121+
122+
if name.name in init_imports:
123+
# Use the mapped import path
124+
imports.add(init_imports[name.name])
125+
else:
126+
imports.add(f"{node.module}.{name.name}")
72127

73128
elif module_type in find_top_level_packages(nemo_root):
74129
if node.names:
75130
for name in node.names:
76131
if name.name == '*':
77132
continue
78-
imports.add(f"{node.module}.{name.name}")
133+
134+
# Check if this is an __init__ import
135+
module_path = os.path.join(nemo_root, *parts)
136+
init_imports = get_init_imports(module_path)
137+
138+
if name.name in init_imports:
139+
# Use the mapped import path
140+
imports.add(init_imports[name.name])
141+
else:
142+
imports.add(f"{node.module}.{name.name}")
79143

80144
except Exception as e:
81145
print(f"Error analyzing {file_path}: {e}")
@@ -256,7 +320,7 @@ def build_dependency_graph(nemo_root: str) -> Dict[str, List[str]]:
256320
new_deps.append("unit-tests")
257321

258322
if (
259-
"nemo.collections" in deps
323+
"nemo.collections" in dep
260324
and "nemo.collections.asr" not in dep
261325
and "nemo.collections.tts" not in dep
262326
and "nemo.collections.speechlm" not in dep

0 commit comments

Comments
 (0)