@@ -44,6 +44,53 @@ def find_python_files(directory: str) -> List[str]:
4444def analyze_imports (nemo_root : str , file_path : str ) -> Set [str ]:
4545 """Analyze a Python file and return its NeMo package dependencies using AST parsing."""
4646 imports = set ()
47+ visited = set () # Track visited modules to prevent circular imports
48+
49+ def get_init_imports (module_path : str , depth : int = 0 ) -> Dict [str , str ]:
50+ """Recursively analyze imports from __init__.py files and map them to their final destinations."""
51+ # Prevent infinite recursion
52+ if depth > 10 or module_path in visited : # Limit depth to 10 levels
53+ return {}
54+
55+ visited .add (module_path )
56+ init_path = os .path .join (module_path , '__init__.py' )
57+ if not os .path .exists (init_path ):
58+ return {}
59+
60+ try :
61+ with open (init_path , 'r' , encoding = 'utf-8' ) as f :
62+ init_tree = ast .parse (f .read (), filename = init_path )
63+
64+ import_map = {}
65+ for node in ast .walk (init_tree ):
66+ if isinstance (node , ast .ImportFrom ) and node .module and node .module .startswith ('nemo.' ):
67+ if node .names :
68+ for name in node .names :
69+ if name .name == '*' :
70+ continue
71+
72+ # Get the full module path for the import
73+ module_parts = node .module .split ('.' )
74+ module_dir = os .path .join (nemo_root , * module_parts )
75+
76+ # If the imported module has an __init__.py, recursively analyze it
77+ if os .path .exists (os .path .join (module_dir , '__init__.py' )):
78+ sub_imports = get_init_imports (module_dir , depth + 1 )
79+ if name .name in sub_imports :
80+ import_map [name .name ] = sub_imports [name .name ]
81+ else :
82+ # If not found in sub-imports, it might be from the module itself
83+ module_file = os .path .join (module_dir , f"{ module_parts [- 1 ]} .py" )
84+ if os .path .exists (module_file ):
85+ import_map [name .name ] = f"{ node .module } .{ name .name } "
86+ else :
87+ # Direct module import
88+ import_map [name .name ] = f"{ node .module } .{ name .name } "
89+
90+ return import_map
91+ except Exception as e :
92+ print (f"Error analyzing { init_path } : { e } " )
93+ return {}
4794
4895 try :
4996 with open (file_path , 'r' , encoding = 'utf-8' ) as f :
@@ -68,14 +115,31 @@ def analyze_imports(nemo_root: str, file_path: str) -> Set[str]:
68115 if name .name == '*' :
69116 continue
70117
71- imports .add (f"{ node .module } .{ name .name } " )
118+ # Check if this is an __init__ import
119+ module_path = os .path .join (nemo_root , * parts )
120+ init_imports = get_init_imports (module_path )
121+
122+ if name .name in init_imports :
123+ # Use the mapped import path
124+ imports .add (init_imports [name .name ])
125+ else :
126+ imports .add (f"{ node .module } .{ name .name } " )
72127
73128 elif module_type in find_top_level_packages (nemo_root ):
74129 if node .names :
75130 for name in node .names :
76131 if name .name == '*' :
77132 continue
78- imports .add (f"{ node .module } .{ name .name } " )
133+
134+ # Check if this is an __init__ import
135+ module_path = os .path .join (nemo_root , * parts )
136+ init_imports = get_init_imports (module_path )
137+
138+ if name .name in init_imports :
139+ # Use the mapped import path
140+ imports .add (init_imports [name .name ])
141+ else :
142+ imports .add (f"{ node .module } .{ name .name } " )
79143
80144 except Exception as e :
81145 print (f"Error analyzing { file_path } : { e } " )
@@ -256,7 +320,7 @@ def build_dependency_graph(nemo_root: str) -> Dict[str, List[str]]:
256320 new_deps .append ("unit-tests" )
257321
258322 if (
259- "nemo.collections" in deps
323+ "nemo.collections" in dep
260324 and "nemo.collections.asr" not in dep
261325 and "nemo.collections.tts" not in dep
262326 and "nemo.collections.speechlm" not in dep
0 commit comments