Skip to content

Commit 48f9ab0

Browse files
Address comments
1 parent 723b5b8 commit 48f9ab0

8 files changed

Lines changed: 284 additions & 263 deletions

File tree

shared/kpack/python/rocm_kpack/coff/surgery.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,54 @@ def load(cls, path: Path) -> "CoffSurgery":
179179
data = bytearray(path.read_bytes())
180180
return cls(data, path)
181181

182+
@staticmethod
def has_fatbin_section(file_path: Path) -> bool:
    """Report whether a PE image has a section named ``.hip_fat``.

    Only the DOS header, the PE signature, the COFF header and the
    section header table are read from disk; section contents are
    never loaded.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True if any section header carries the name ``.hip_fat``.
        False if no section does, if the file lacks the ``MZ`` magic
        or the PE signature, or if the file ends before the headers
        are complete.

    Errors raised while opening or reading the file propagate to the
    caller.
    """
    wanted = b".hip_fat"

    with open(file_path, "rb") as stream:
        dos_header = stream.read(DOS_HEADER_SIZE)
        if len(dos_header) < DOS_HEADER_SIZE or dos_header[:2] != b"MZ":
            return False

        # The little-endian dword at 0x3C of the DOS header is the file
        # offset of the PE signature.
        (pe_start,) = struct.unpack_from("<I", dos_header, 0x3C)
        stream.seek(pe_start)
        if stream.read(4) != PE_SIGNATURE:
            return False

        coff_header = stream.read(COFF_HEADER_SIZE)
        if len(coff_header) < COFF_HEADER_SIZE:
            return False

        # COFF header: NumberOfSections at offset 2, optional-header
        # size at offset 16.
        (section_count,) = struct.unpack_from("<H", coff_header, 2)
        (optional_size,) = struct.unpack_from("<H", coff_header, 16)

        # The section table follows the 4-byte signature, the COFF
        # header and the optional header.
        table_size = section_count * SECTION_HEADER_SIZE
        stream.seek(pe_start + 4 + COFF_HEADER_SIZE + optional_size)
        table = stream.read(table_size)

    if len(table) < table_size:
        return False

    # Every section header starts with an 8-byte, NUL-padded name field.
    return any(
        table[start : start + 8].rstrip(b"\x00") == wanted
        for start in range(0, table_size, SECTION_HEADER_SIZE)
    )
229+
182230
# =========================================================================
183231
# Properties
184232
# =========================================================================

shared/kpack/python/rocm_kpack/database_handlers.py

Lines changed: 78 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,29 @@ def detect(self, path: Path, prefix_root: Path) -> Optional[str]:
5252
Detect if a file belongs to this database type and extract bundle key.
5353
5454
Args:
55-
path: Path to the file
55+
path: Path to the file (must be under prefix_root)
5656
prefix_root: Root of the prefix for relative path computation
5757
5858
Returns:
5959
Bundle key (e.g., 'gfx1100', 'gfx11', 'gfx12_0') if file matches,
6060
None otherwise. Bundle keys correspond to entries in the
6161
rocm-bootstrap hierarchy.
62+
63+
Raises:
64+
ValueError: If path is not under prefix_root (caller bug).
6265
"""
6366
pass
6467

68+
def _relative_path(self, path: Path, prefix_root: Path) -> str:
69+
"""Compute forward-slash relative path, raising on bad input."""
70+
try:
71+
return path.relative_to(prefix_root).as_posix()
72+
except ValueError:
73+
raise ValueError(
74+
f"{self.name()} handler: path {path} is not under "
75+
f"prefix_root {prefix_root}"
76+
) from None
77+
6578
def should_move(self, path: Path) -> bool:
6679
"""
6780
Determine if this file should be moved to architecture-specific artifact.
@@ -88,32 +101,26 @@ def detect(self, path: Path, prefix_root: Path) -> Optional[str]:
88101
89102
Pattern: lib/rocblas/library/*_gfx*.{co,hsaco,dat}
90103
"""
91-
try:
92-
rel_path = path.relative_to(prefix_root)
93-
# Use as_posix() for consistent forward-slash paths on all platforms
94-
path_str = rel_path.as_posix()
95-
96-
# Check if it's in rocblas/library directory
97-
if "rocblas/library" not in path_str:
98-
return None
99-
100-
# Check file extension
101-
if path.suffix not in [".co", ".hsaco", ".dat"]:
102-
return None
103-
104-
# Extract architecture from filename
105-
# Look for patterns like _gfx1100, _gfx1101, gfx1102, etc.
106-
match = _GFX_ARCH_PATTERN.search(path.name)
107-
if match:
108-
return match.group(0)
109-
110-
# Some .dat files don't have architecture suffix but are generic
111-
# We don't move those
104+
path_str = self._relative_path(path, prefix_root)
105+
106+
# Check if it's in rocblas/library directory
107+
if "rocblas/library" not in path_str:
112108
return None
113109

114-
except (ValueError, AttributeError):
110+
# Check file extension
111+
if path.suffix not in [".co", ".hsaco", ".dat"]:
115112
return None
116113

114+
# Extract architecture from filename
115+
# Look for patterns like _gfx1100, _gfx1101, gfx1102, etc.
116+
match = _GFX_ARCH_PATTERN.search(path.name)
117+
if match:
118+
return match.group(0)
119+
120+
# Some .dat files don't have architecture suffix but are generic
121+
# We don't move those
122+
return None
123+
117124

118125
class HipBLASLtHandler(DatabaseHandler):
119126
"""Handler for hipBLASLt kernel files."""
@@ -127,29 +134,23 @@ def detect(self, path: Path, prefix_root: Path) -> Optional[str]:
127134
128135
Pattern: lib/hipblaslt/library/*_gfx*.{co,hsaco,dat}
129136
"""
130-
try:
131-
rel_path = path.relative_to(prefix_root)
132-
# Use as_posix() for consistent forward-slash paths on all platforms
133-
path_str = rel_path.as_posix()
134-
135-
# Check if it's in hipblaslt/library directory
136-
if "hipblaslt/library" not in path_str:
137-
return None
138-
139-
# Check file extension
140-
if path.suffix not in [".co", ".hsaco", ".dat"]:
141-
return None
142-
143-
# Extract architecture from filename
144-
match = _GFX_ARCH_PATTERN.search(path.name)
145-
if match:
146-
return match.group(0)
137+
path_str = self._relative_path(path, prefix_root)
147138

139+
# Check if it's in hipblaslt/library directory
140+
if "hipblaslt/library" not in path_str:
148141
return None
149142

150-
except (ValueError, AttributeError):
143+
# Check file extension
144+
if path.suffix not in [".co", ".hsaco", ".dat"]:
151145
return None
152146

147+
# Extract architecture from filename
148+
match = _GFX_ARCH_PATTERN.search(path.name)
149+
if match:
150+
return match.group(0)
151+
152+
return None
153+
153154

154155
class HipSparseLtHandler(DatabaseHandler):
155156
"""Handler for hipSPARSELt Tensile kernel files."""
@@ -163,25 +164,20 @@ def detect(self, path: Path, prefix_root: Path) -> Optional[str]:
163164
164165
Pattern: lib/hipsparselt/library/*_gfx*.{co,hsaco,dat}
165166
"""
166-
try:
167-
rel_path = path.relative_to(prefix_root)
168-
path_str = rel_path.as_posix()
169-
170-
if "hipsparselt/library" not in path_str:
171-
return None
172-
173-
if path.suffix not in [".co", ".hsaco", ".dat"]:
174-
return None
175-
176-
match = _GFX_ARCH_PATTERN.search(path.name)
177-
if match:
178-
return match.group(0)
167+
path_str = self._relative_path(path, prefix_root)
179168

169+
if "hipsparselt/library" not in path_str:
180170
return None
181171

182-
except (ValueError, AttributeError):
172+
if path.suffix not in [".co", ".hsaco", ".dat"]:
183173
return None
184174

175+
match = _GFX_ARCH_PATTERN.search(path.name)
176+
if match:
177+
return match.group(0)
178+
179+
return None
180+
185181

186182
class AotritonHandler(DatabaseHandler):
187183
"""Handler for AOTriton kernel image directories.
@@ -219,22 +215,18 @@ def detect(self, path: Path, prefix_root: Path) -> Optional[str]:
219215
Returns:
220216
Bundle key (e.g., 'gfx11', 'gfx12_0', 'gfx942') or None.
221217
"""
222-
try:
223-
rel_path = path.relative_to(prefix_root)
224-
path_parts = rel_path.parts
218+
path_str = self._relative_path(path, prefix_root)
219+
path_parts = Path(path_str).parts
225220

226-
for i, part in enumerate(path_parts[:-1]):
227-
if part == "aotriton.images" and i + 1 < len(path_parts):
228-
arch_dir = path_parts[i + 1]
229-
if arch_dir.startswith("amd-gfx"):
230-
raw_name = arch_dir[4:] # strip "amd-" prefix
231-
return self._BUNDLE_MAP.get(raw_name, raw_name)
232-
break
233-
234-
return None
221+
for i, part in enumerate(path_parts[:-1]):
222+
if part == "aotriton.images" and i + 1 < len(path_parts):
223+
arch_dir = path_parts[i + 1]
224+
if arch_dir.startswith("amd-gfx"):
225+
raw_name = arch_dir[4:] # strip "amd-" prefix
226+
return self._BUNDLE_MAP.get(raw_name, raw_name)
227+
break
235228

236-
except (ValueError, AttributeError, IndexError):
237-
return None
229+
return None
238230

239231

240232
class MIOpenHandler(DatabaseHandler):
@@ -261,34 +253,29 @@ def detect(self, path: Path, prefix_root: Path) -> Optional[str]:
261253
262254
Pattern: share/miopen/db/gfx*.{db.txt,fdb.txt,model}
263255
"""
264-
try:
265-
rel_path = path.relative_to(prefix_root)
266-
path_str = rel_path.as_posix()
267-
268-
if "miopen/db" not in path_str:
269-
return None
270-
271-
if not path.is_file():
272-
return None
256+
path_str = self._relative_path(path, prefix_root)
273257

274-
# Match .model, .db.txt, .fdb.txt, .OpenCL.fdb.txt, .HIP.fdb.txt
275-
name = path.name
276-
if not (
277-
name.endswith(".model")
278-
or name.endswith(".db.txt")
279-
or name.endswith(".fdb.txt")
280-
):
281-
return None
282-
283-
match = _MIOPEN_ARCH_PATTERN.search(name)
284-
if match:
285-
return match.group(0)
258+
if "miopen/db" not in path_str:
259+
return None
286260

261+
if not path.is_file():
287262
return None
288263

289-
except (ValueError, AttributeError):
264+
# Match .model, .db.txt, .fdb.txt, .OpenCL.fdb.txt, .HIP.fdb.txt
265+
name = path.name
266+
if not (
267+
name.endswith(".model")
268+
or name.endswith(".db.txt")
269+
or name.endswith(".fdb.txt")
270+
):
290271
return None
291272

273+
match = _MIOPEN_ARCH_PATTERN.search(name)
274+
if match:
275+
return match.group(0)
276+
277+
return None
278+
292279

293280
# Registry of available handlers
294281
AVAILABLE_HANDLERS = {

shared/kpack/python/rocm_kpack/elf/surgery.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,58 @@ def load(cls, path: Path) -> "ElfSurgery":
148148
data = bytearray(path.read_bytes())
149149
return cls(data, path)
150150

151+
@staticmethod
def has_fatbin_section(file_path: Path) -> bool:
    """Report whether an ELF64 file has a section named ``.hip_fatbin``.

    Only the ELF header, the section header table and the section-name
    string table are read; section contents are never loaded.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True if some section header's name is exactly ``.hip_fatbin``.
        False if there is no such section, or if the file is too
        short, lacks the ELF magic, is not ELFCLASS64, has no section
        header table, has an out-of-range ``e_shstrndx``, or has a
        truncated section header table.

    Raises:
        OSError: If the file cannot be opened or read.
        struct.error: If a section header field lies outside the bytes
            read (for example when ``e_shentsize`` is too small).
    """
    with open(file_path, "rb") as f:
        ehdr = f.read(ELF64_EHDR_SIZE)
        if len(ehdr) < ELF64_EHDR_SIZE:
            return False
        if ehdr[:4] != b"\x7fELF":
            return False
        if ehdr[4] != 2:  # EI_CLASS must be ELFCLASS64
            return False

        # All header fields are decoded as little-endian.
        e_shoff = struct.unpack_from("<Q", ehdr, 40)[0]
        e_shentsize = struct.unpack_from("<H", ehdr, 58)[0]
        e_shnum = struct.unpack_from("<H", ehdr, 60)[0]
        e_shstrndx = struct.unpack_from("<H", ehdr, 62)[0]

        if e_shoff == 0 or e_shnum == 0 or e_shstrndx >= e_shnum:
            return False

        # Read the section header table.
        f.seek(e_shoff)
        shtab_size = e_shentsize * e_shnum
        shtab = f.read(shtab_size)
        if len(shtab) < shtab_size:
            return False

        # Locate and read the section-name string table (.shstrtab):
        # sh_offset is at +24 and sh_size at +32 within its header.
        strtab_entry = e_shstrndx * e_shentsize
        sh_offset = struct.unpack_from("<Q", shtab, strtab_entry + 24)[0]
        sh_size = struct.unpack_from("<Q", shtab, strtab_entry + 32)[0]
        f.seek(sh_offset)
        shstrtab = f.read(sh_size)

    # Including the terminating NUL makes the comparison an exact name
    # match rather than a prefix match.
    target = b".hip_fatbin\x00"
    if target not in shstrtab:
        return False

    # Compare the name stored at each header's own sh_name offset.
    # Looking up only the first occurrence of the target in the string
    # table would miss the section when an earlier name merely ends
    # with ".hip_fatbin" (e.g. ".rela.hip_fatbin") and the real
    # section's sh_name points at a later copy of the string.
    for i in range(e_shnum):
        sh_name = struct.unpack_from("<I", shtab, i * e_shentsize)[0]
        if shstrtab.startswith(target, sh_name):
            return True

    return False
202+
151203
# =========================================================================
152204
# Properties
153205
# =========================================================================

shared/kpack/python/rocm_kpack/kpack_transform.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ def read_kpack_ref_marker(binary_path: Path) -> dict | None:
129129
def is_fat_binary(binary_path: Path) -> bool:
130130
"""Check if a binary contains HIP fat binary sections.
131131
132-
This is a quick check that doesn't require full parsing.
132+
Fast-path: reads only headers (ELF section header table or PE section
133+
headers), not the full binary. Works for both ELF and PE/COFF.
133134
134135
Args:
135136
binary_path: Path to binary
136137
137138
Returns:
138-
True if binary appears to contain device code, False otherwise
139+
True if binary contains a .hip_fatbin (ELF) or .hip_fat (COFF)
140+
section, False otherwise (including non-ELF/non-PE files).
139141
"""
140142
try:
141143
fmt = detect_binary_format(binary_path)
@@ -145,10 +147,8 @@ def is_fat_binary(binary_path: Path) -> bool:
145147
if fmt == "elf":
146148
from .elf.surgery import ElfSurgery
147149

148-
surgery = ElfSurgery.load(binary_path)
149-
return surgery.find_section(".hip_fatbin") is not None
150+
return ElfSurgery.has_fatbin_section(binary_path)
150151
else:
151152
from .coff.surgery import CoffSurgery
152153

153-
surgery = CoffSurgery.load(binary_path)
154-
return surgery.find_section(".hip_fat") is not None
154+
return CoffSurgery.has_fatbin_section(binary_path)

0 commit comments

Comments
 (0)