Skip to content

Commit 803213a

Browse files
fifield and Copilot authored
Add a callable library function to parse_trace.py (#2712)
Co-authored-by: Copilot <[email protected]>
1 parent ed69824 commit 803213a

4 files changed

Lines changed: 745 additions & 107 deletions

File tree

python/utils/parse_trace.py

Lines changed: 191 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
NumTraceTypes = 4
2020
NUM_EVENTS = 8 # number of events we can view per trace
2121

22-
# DEBUG = False
23-
# DEBUG = True
24-
2522

2623
def parse_args():
2724
parser = argparse.ArgumentParser()
@@ -42,8 +39,8 @@ def parse_args():
4239
# Check for valid trace packets data
4340
# 1) if only 1 trace packet
4441
# 2) if first trace packet is all 0's
45-
def check_for_valid_trace(filename, trace_pkts, of):
46-
if DEBUG:
42+
def check_for_valid_trace(filename, trace_pkts, of=None, debug=False):
43+
if debug and of:
4744
print("len(trace_pkts): ", str(len(trace_pkts)), file=of)
4845
print("trace_pkts[0]:", trace_pkts[0], file=of)
4946
if len(trace_pkts) < 2 or trace_pkts[0] == "00000000":
@@ -407,7 +404,15 @@ def lookupEventNameInStr(event, pid, pid_events):
407404
# loc -
408405
# pid_events -
409406
def deactivate_events(
410-
multiples, active_events, timer, cycles, pid, trace_type, loc, pid_events
407+
multiples,
408+
active_events,
409+
timer,
410+
cycles,
411+
pid,
412+
trace_type,
413+
loc,
414+
pid_events,
415+
trace_events,
411416
):
412417
for k in active_events.keys(): # an active event
413418
if cycles > 0 or (cycles == 0 and not k in multiples):
@@ -450,15 +455,15 @@ def activate_event(event, tt, loc, timer, pid, active_events, pid_events, trace_
450455
#
451456
# commands: list (idx = trace type, value = byte_stream_dict)
452457
# byte_stream_dict: dict (key = row,col, value = list of commands)
453-
def convert_commands_to_json(trace_events, commands, pid_events, of):
458+
def convert_commands_to_json(trace_events, commands, pid_events, of=None, debug=False):
454459
# byte_stream_dict for each trace type.
455460
for [tt, byte_stream_dict] in enumerate(commands): # tt = trace type
456461

457462
for loc, command in byte_stream_dict.items(): # row,col with list of commands
458463
timer = 0 # TODO Some way to set this or sync this between trace types and row,col
459464
# timer on each execution is the time for the last execution
460465
# so we by default will increment it by 1 for each event
461-
if DEBUG:
466+
if debug and of:
462467
print(
463468
"tt: "
464469
+ str(tt)
@@ -493,7 +498,7 @@ def convert_commands_to_json(trace_events, commands, pid_events, of):
493498
for i in range(8): # 8 max events at a time
494499
active_events[i] = 0
495500

496-
if DEBUG:
501+
if debug and of:
497502
print("num commands:", len(command), file=of)
498503
for c in command:
499504
t = c["type"]
@@ -512,6 +517,7 @@ def convert_commands_to_json(trace_events, commands, pid_events, of):
512517
tt,
513518
loc,
514519
pid_events,
520+
trace_events,
515521
)
516522
timer = timer + cycles
517523
activate_event(
@@ -541,6 +547,7 @@ def convert_commands_to_json(trace_events, commands, pid_events, of):
541547
tt,
542548
loc,
543549
pid_events,
550+
trace_events,
544551
)
545552
timer = timer + cycles
546553

@@ -574,6 +581,7 @@ def convert_commands_to_json(trace_events, commands, pid_events, of):
574581
tt,
575582
loc,
576583
pid_events,
584+
trace_events,
577585
)
578586
timer = timer + cycles
579587
if len(multiple_list) > 1:
@@ -930,108 +938,184 @@ def align_column_start_index(events, commands):
930938

931939

932940
# ------------------------------------------------------------------------------
933-
# Script execution start - Open trace file and convert to commands
941+
# Library API
934942
# ------------------------------------------------------------------------------
935943

936-
opts = parse_args()
937944

938-
DEBUG = opts.debug
939-
if DEBUG:
940-
print("Debug mode enable\n")
945+
def parse_trace(trace_buffer, mlir_module_str, colshift=None, debug=False):
    """
    Turn an AIE trace buffer into a list of events in Trace Event Format.

    Args:
        trace_buffer: iterable of trace words (described by the original
            author as a numpy array of uint32); each word is converted with
            int() and rendered as lowercase hex, zero-padded to 8 digits
        mlir_module_str: MLIR module text handed to parse_mlir_trace_events
        colshift: column shift forwarded to parse_mlir_trace_events; when
            None, align_column_start_index is applied to the parsed events
        debug: forwarded to check_for_valid_trace and convert_commands_to_json

    Returns:
        list: the trace events list populated by setup_trace_metadata and
        convert_commands_to_json

    Raises:
        ValueError: if check_for_valid_trace rejects the converted packets
    """
    # The existing pipeline helpers consume hex strings without a '0x' prefix.
    hex_words = [f"{int(raw_word):08x}" for raw_word in trace_buffer]

    # Event configuration comes from the MLIR text.
    events_by_pid = parse_mlir_trace_events(mlir_module_str, colshift)

    # NOTE(review): of=None is passed here, so the `debug and of` guarded
    # prints in the helper cannot fire -- confirm whether debug output is
    # expected from the library entry point.
    packets_ok = check_for_valid_trace(
        "<numpy_array>", hex_words, of=None, debug=debug
    )
    if not packets_ok:
        raise ValueError("Invalid trace data: empty or all zeros")

    # trim -> de-interleave -> byte streams -> command dictionaries
    trimmed = trim_trace_pkts(hex_words)
    sorted_pkts = trace_pkts_de_interleave(trimmed)
    streams = convert_to_byte_stream(sorted_pkts)
    decoded_commands = convert_to_commands(streams, False)

    # Without an explicit column shift, align columns from the decoded data.
    if colshift is None:
        events_by_pid = align_column_start_index(events_by_pid, decoded_commands)

    # Both helpers below fill this list in place: metadata first, then events.
    events_out = []
    setup_trace_metadata(events_out, events_by_pid)
    convert_commands_to_json(
        events_out, decoded_commands, events_by_pid, of=None, debug=debug
    )
    return events_out
999+
1000+
1001+
# ------------------------------------------------------------------------------
1002+
# Script execution start - Open trace file and convert to commands
1003+
# ------------------------------------------------------------------------------
1004+
1005+
1006+
def _dump_pkt_type_dicts(title, per_type_dicts, of):
    """Write `title`, then every key/value of each per-packet-type dict, to `of`."""
    print(title, file=of)
    for idx, dict_i in enumerate(per_type_dicts):
        print("pkt type", idx, ":", file=of)
        for key, value in dict_i.items():
            print(key, value, file=of)
        print("", file=of)


def main():
    """Command-line interface entry point.

    Reads the trace packet file (opts.input) and the MLIR file (opts.mlir),
    runs the conversion pipeline, and writes the trace events as JSON to
    opts.output. When opts.debug is set, intermediate results are written to
    the same output file ahead of the JSON.

    Exits with status 1 when the input or output file cannot be opened, when
    reading or parsing the MLIR raises, or when check_for_valid_trace rejects
    the trace packets.
    """
    opts = parse_args()

    debug = opts.debug
    if debug:
        print("Debug mode enable\n")

    # set colshift based on optional argument
    colshift = int(opts.colshift) if opts.colshift else None

    try:
        with open(opts.input, "r") as f:
            # Create array of trace packets (one per line of the input file)
            trace_pkts = f.read().split("\n")
    except Exception:
        print(
            "ERROR:",
            opts.input,
            "could not be opened. Check for valid trace source file.",
        )
        sys.exit(1)

    try:
        with open(opts.mlir, "r") as mf:
            mlir_module_str = mf.read()
        pid_events = parse_mlir_trace_events(mlir_module_str, colshift)
    except Exception as e:
        print("ERROR:", opts.mlir, "could not be opened. Check for valid MLIR file.", e)
        sys.exit(1)

    try:
        of = open(opts.output, "w")
    except Exception:
        # Report the output path; the message previously named the MLIR file.
        print(
            "ERROR:",
            opts.output,
            "could not be opened. Check for valid output JSON file.",
        )
        sys.exit(1)

    # `with` closes the output file on every exit path, including the
    # sys.exit(1) below and any exception raised by the pipeline.
    with of:
        if debug:
            print("DEBUG mode enabled:", file=of)
            print("pkt type 0: core tile", file=of)
            print("pkt type 1: core mem tile", file=of)
            print("pkt type 2: shim tile", file=of)
            print("pkt type 3: mem tile", file=of)
            print("", file=of)
            print("DEBUG: trace_pkts", file=of)
            print(trace_pkts, file=of)
            print("", file=of)

            _dump_pkt_type_dicts("DEBUG: pid events\n", pid_events, of)

        if not check_for_valid_trace(opts.input, trace_pkts, of, debug):
            sys.exit(1)

        trimmed_trace_pkts = trim_trace_pkts(trace_pkts)
        if debug:
            lines_removed = len(trace_pkts) - len(trimmed_trace_pkts)
            print("DEBUG: trimmed ", lines_removed, " lines", file=of)

        trace_pkts_sorted = trace_pkts_de_interleave(trimmed_trace_pkts)
        if debug:
            _dump_pkt_type_dicts("DEBUG: trace_pkts_sorted", trace_pkts_sorted, of)

        byte_streams = convert_to_byte_stream(trace_pkts_sorted)
        if debug:
            _dump_pkt_type_dicts("DEBUG: byte stream", byte_streams, of)

        commands_0 = convert_to_commands(byte_streams, False)

        if debug:
            print("DEBUG: commands_0", file=of)
            for idx, dict_i in enumerate(commands_0):
                print("pkt type", idx, ":", file=of)
                for key, commands in dict_i.items():
                    print(key, file=of)
                    for i in commands:
                        print("\t", i, file=of)
                print("", file=of)

        # No explicit column shift: align columns from the decoded commands.
        if colshift is None:
            pid_events = align_column_start_index(pid_events, commands_0)

        trace_events = []

        setup_trace_metadata(trace_events, pid_events)

        convert_commands_to_json(trace_events, commands_0, pid_events, of, debug)

        # One event per line: break after each top-level ", {" separator.
        print(
            json.dumps(trace_events).replace("'", '"').replace(", {", ",\n{"),
            file=of,
        )
1118+
1119+
1120+
if __name__ == "__main__":
1121+
main()

0 commit comments

Comments
 (0)