diff --git a/graph_net/paddle/test_compiler.py b/graph_net/paddle/test_compiler.py
index e211c913f..4d9f4fc22 100644
--- a/graph_net/paddle/test_compiler.py
+++ b/graph_net/paddle/test_compiler.py
@@ -25,7 +25,7 @@ def set_seed(random_seed):
 
 
 def get_hardward_name(args):
-    if args.device == "cuda":
+    if test_compiler_util.is_gpu_device(args.device):
         hardware = paddle.device.cuda.get_device_name(0)
     elif args.device == "cpu":
         hardware = platform.processor()
@@ -64,15 +64,15 @@ def get_synchronizer_func(args):
     return paddle.device.synchronize
 
 
-def get_model(args):
+def get_model(model_path):
     model_class = load_class_from_file(
-        f"{args.model_path}/model.py", class_name="GraphModule"
+        f"{model_path}/model.py", class_name="GraphModule"
     )
     return model_class()
 
 
-def get_input_dict(args):
-    inputs_params = utils.load_converted_from_text(f"{args.model_path}")
+def get_input_dict(model_path):
+    inputs_params = utils.load_converted_from_text(f"{model_path}")
     params = inputs_params["weight_info"]
     inputs = inputs_params["input_info"]
 
@@ -81,8 +81,8 @@ def get_input_dict(args):
     return state_dict
 
 
-def get_input_spec(args):
-    inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
+def get_input_spec(model_path):
+    inputs_params_list = utils.load_converted_list_from_text(f"{model_path}")
     input_spec = [None] * len(inputs_params_list)
     for i, v in enumerate(inputs_params_list):
         dtype = v["info"]["dtype"]
@@ -94,7 +94,7 @@ def get_compiled_model(args, model):
     if args.compiler == "nope":
         return model
 
-    input_spec = get_input_spec(args)
+    input_spec = get_input_spec(args.model_path)
     build_strategy = paddle.static.BuildStrategy()
     compiled_model = paddle.jit.to_static(
         model,
@@ -110,7 +110,7 @@ def get_static_model(args, model):
     static_model = paddle.jit.to_static(
         model,
-        input_spec=get_input_spec(args),
+        input_spec=get_input_spec(args.model_path),
         full_graph=True,
         backend=None,
     )
@@ -138,7 +138,7 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
         flush=True,
     )
 
-    if "cuda" in args.device:
+    if test_compiler_util.is_gpu_device(args.device):
         """
         Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings,
         With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench
@@ -249,8 +249,8 @@ def transfer_to_float(origin_outputs):
 
 def test_single_model(args):
     synchronizer_func = get_synchronizer_func(args)
-    input_dict = get_input_dict(args)
-    model = get_model(args)
+    input_dict = get_input_dict(args.model_path)
+    model = get_model(args.model_path)
     model.eval()
 
     test_compiler_util.print_basic_config(
@@ -259,6 +259,7 @@
 
     # Run on eager mode
     eager_success = False
+    eager_time_stats = {}
    try:
         print("Run model in eager mode.", file=sys.stderr, flush=True)
         static_model = get_static_model(args, model)
@@ -275,6 +276,7 @@
 
     # Run on compiling mode
     compiled_success = False
+    compiled_time_stats = {}
     try:
         print("Run model in compiled mode.", file=sys.stderr, flush=True)
         compiled_model = get_compiled_model(args, model)
@@ -293,9 +295,9 @@
     if eager_success and compiled_success:
         check_outputs(args, expected_out, compiled_out)
 
-        test_compiler_util.print_times_and_speedup(
-            args, eager_time_stats, compiled_time_stats
-        )
+    test_compiler_util.print_times_and_speedup(
+        args, eager_time_stats, compiled_time_stats
+    )
 
 
 def get_cmp_equal(expected_out, compiled_out):
@@ -366,20 +368,12 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
 
 
 def test_multi_models(args):
-    test_samples = None
-    if args.allow_list is not None:
-        assert os.path.isfile(args.allow_list)
-        graphnet_root = path_utils.get_graphnet_root()
-        print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
-        verified_samples = []
-        with open(args.verified_samples_list_path, "r") as f:
-            for line in f.readlines():
-                test_samples.append(os.path.join(graphnet_root, line.strip()))
+    test_samples = test_compiler_util.get_allow_samples(args.allow_list)
 
     sample_idx = 0
     failed_samples = []
     for model_path in path_utils.get_recursively_model_path(args.model_path):
-        if verified_samples is None or os.path.abspath(model_path) in verified_samples:
+        if test_samples is None or os.path.abspath(model_path) in test_samples:
             print(
                 f"[{sample_idx}] test_compiler, model_path: {model_path}",
                 file=sys.stderr,
@@ -415,11 +409,24 @@ def test_multi_models(args):
 def main(args):
     assert os.path.isdir(args.model_path)
     assert args.compiler in {"cinn", "nope"}
+    assert args.device in ["cuda", "dcu", "cpu"]
 
     initalize_seed = 123
     set_seed(random_seed=initalize_seed)
 
     if path_utils.is_single_model_dir(args.model_path):
+        if paddle.device.is_compiled_with_cuda():
+            device_id = int(paddle.device.get_device().split(":")[-1])
+            device_count = paddle.device.cuda.device_count()
+            gpu_util, mem_util = test_compiler_util.get_device_utilization(
+                device_id, device_count, get_synchronizer_func(args)
+            )
+            if gpu_util is not None and mem_util is not None:
+                print(
+                    f"Device status: gpu_id {device_id}, gpu_util {gpu_util:.2f}%, mem_util {mem_util:.2f}%",
+                    file=sys.stderr,
+                    flush=True,
+                )
         test_single_model(args)
     else:
         test_multi_models(args)
diff --git a/graph_net/paddle/utils.py b/graph_net/paddle/utils.py
index 2148cecef..ee4c9bce5 100644
--- a/graph_net/paddle/utils.py
+++ b/graph_net/paddle/utils.py
@@ -214,9 +214,10 @@ def replay_tensor(info):
     else:
         if mean is not None and std is not None:
             tensor = paddle.empty(shape=shape, dtype=dtype)
-            paddle.nn.init.trunc_normal_(
-                tensor=tensor, mean=mean, std=std, a=min_val, b=max_val
+            initializer = paddle.nn.initializer.TruncatedNormal(
+                mean=mean, std=std, a=min_val, b=max_val
             )
+            initializer(tensor)
             return tensor.to(dtype).to(device)
         else:
             return (
diff --git a/graph_net/test_compiler_util.py b/graph_net/test_compiler_util.py
index 84b0aba21..02069d8b5 100644
--- a/graph_net/test_compiler_util.py
+++ b/graph_net/test_compiler_util.py
@@ -3,10 +3,14 @@
 import sys
 import json
 import time
+import subprocess
+import shutil
 import numpy as np
 from dataclasses import dataclass
 from contextlib import contextmanager
 
+from graph_net import path_utils
+
 
 @dataclass
 class DurationBox:
@@ -23,6 +27,103 @@ def naive_timer(duration_box, synchronizer_func):
     duration_box.value = (end - start) * 1000  # Store in milliseconds
 
 
+def is_gpu_device(device):
+    return "cuda" in device or "dcu" in device
+
+
+def get_device_utilization(device_id, device_count, synchronizer_func):
+    current_pid = os.getpid()
+
+    if shutil.which("nvidia-smi"):
+        try:
+            cuda_devices_str = os.getenv("CUDA_VISIBLE_DEVICES", "")
+            if cuda_devices_str != "":
+                cuda_devices = list(map(int, cuda_devices_str.split(",")))
+            else:
+                cuda_devices = list(range(device_count))
+            selected_gpu_id = cuda_devices[device_id]
+
+            print(
+                f"Check the status of GPU {selected_gpu_id} for 5 times.",
+                file=sys.stderr,
+                flush=True,
+            )
+            selected_gpu_uuid, max_gpu_util, max_mem_util = None, 0.0, 0.0
+            for i in range(5):
+                synchronizer_func()
+                time.sleep(1)
+
+                output = (
+                    subprocess.check_output(
+                        [
+                            "nvidia-smi",
+                            f"--query-gpu=index,gpu_uuid,utilization.gpu,memory.used,memory.total",
+                            "--format=csv,noheader,nounits",
+                        ]
+                    )
+                    .decode()
+                    .strip()
+                )
+                for line in output.split("\n"):
+                    if line.strip():
+                        (
+                            gpu_id,
+                            selected_gpu_uuid,
+                            gpu_util,
+                            used_mem,
+                            mem_total,
+                        ) = line.split(", ")
+                        if int(gpu_id) == selected_gpu_id:
+                            break
+
+                gpu_util = float(gpu_util)
+                mem_util = float(used_mem) * 100 / float(mem_total)
+                print(
+                    f"- gpu_id: {selected_gpu_id}, gpu_uuid: {selected_gpu_uuid}, gpu_util: {gpu_util:.2f}%, used_mem: {used_mem}, mem_total: {mem_total}",
+                    file=sys.stderr,
+                    flush=True,
+                )
+
+                max_gpu_util = gpu_util if gpu_util > max_gpu_util else max_gpu_util
+                max_mem_util = mem_util if mem_util > max_mem_util else max_mem_util
+
+            other_tasks = []
+            output = (
+                subprocess.check_output(
+                    [
+                        "nvidia-smi",
+                        f"--query-compute-apps=gpu_uuid,pid,used_memory",
+                        "--format=csv,noheader,nounits",
+                    ]
+                )
+                .decode()
+                .strip()
+            )
+            for line in output.split("\n"):
+                if line.strip():
+                    gpu_uuid, pid, used_memory = line.split(", ")
+                    if gpu_uuid == selected_gpu_uuid and int(pid) != current_pid:
+                        other_tasks.append(line)
+            # Note: in a docker container, the current_pid may be different from the one captured by nvidia-smi.
+            print(
+                f"Note: There are {len(other_tasks)} tasks running on GPU {selected_gpu_id} (current_pid:{current_pid}).",
+                file=sys.stderr,
+                flush=True,
+            )
+            for task in other_tasks:
+                gpu_uuid, pid, used_memory = task.split(", ")
+                print(
+                    f"- gpu_uuid:{gpu_uuid}, pid:{pid}, used_memory:{used_memory}",
+                    file=sys.stderr,
+                    flush=True,
+                )
+            return max_gpu_util, max_mem_util
+        except subprocess.CalledProcessError:
+            pass
+
+    return None, None
+
+
 def get_timing_stats(elapsed_times):
     stats = {
         "mean": float(f"{np.mean(elapsed_times):.6g}"),
@@ -75,24 +176,33 @@ def print_basic_config(args, hardware_name, compile_framework_version):
     )
 
 
-def print_running_status(args, eager_success, compiled_success):
+def print_running_status(args, eager_success, compiled_success=None):
     def convert_to_str(b):
         return "success" if b else "failed"
 
-    print_with_log_prompt(
-        "[Result][status]",
-        f"eager:{convert_to_str(eager_success)} compiled:{convert_to_str(compiled_success)}",
-        args.log_prompt,
-    )
+    if compiled_success is not None:
+        print_with_log_prompt(
+            "[Result][status]",
+            f"eager:{convert_to_str(eager_success)} compiled:{convert_to_str(compiled_success)}",
+            args.log_prompt,
+        )
+    else:
+        print_with_log_prompt(
+            "[Result][status]",
+            f"eager:{convert_to_str(eager_success)}",
+            args.log_prompt,
+        )
 
 
 def print_times_and_speedup(args, eager_stats, compiled_stats):
-    print_with_log_prompt(
-        "[Performance][eager]:", json.dumps(eager_stats), args.log_prompt
-    )
-    print_with_log_prompt(
-        "[Performance][compiled]:", json.dumps(compiled_stats), args.log_prompt
-    )
+    if eager_stats:
+        print_with_log_prompt(
+            "[Performance][eager]:", json.dumps(eager_stats), args.log_prompt
+        )
+    if compiled_stats:
+        print_with_log_prompt(
+            "[Performance][compiled]:", json.dumps(compiled_stats), args.log_prompt
+        )
 
     e2e_speedup = 0
     gpu_speedup = 0
@@ -103,7 +213,7 @@ def print_times_and_speedup(args, eager_stats, compiled_stats):
     if eager_e2e_time_ms > 0 and compiled_e2e_time_ms > 0:
         e2e_speedup = eager_e2e_time_ms / compiled_e2e_time_ms
 
-    if "cuda" in args.device:
+    if is_gpu_device(args.device):
         eager_gpu_time_ms = eager_stats.get("gpu", {}).get("mean", 0)
         compiled_gpu_time_ms = compiled_stats.get("gpu", {}).get("mean", 0)
 
@@ -113,7 +223,7 @@ def print_times_and_speedup(args, eager_stats, compiled_stats):
 
     if e2e_speedup > 0:
         print_with_log_prompt("[Speedup][e2e]:", f"{e2e_speedup:.5f}", args.log_prompt)
-    if "cuda" in args.device and gpu_speedup > 0:
+    if is_gpu_device(args.device) and gpu_speedup > 0:
         print_with_log_prompt("[Speedup][gpu]:", f"{gpu_speedup:.5f}", args.log_prompt)
 
 
@@ -224,3 +334,18 @@ def check_allclose(
         compiled_out=compiled_out,
         **kwargs,
     )
+
+
+def get_allow_samples(allow_list):
+    if allow_list is None:
+        return None
+
+    assert os.path.isfile(allow_list), f"{allow_list} is not a regular file."
+    graphnet_root = path_utils.get_graphnet_root()
+    print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
+    test_samples = []
+    with open(allow_list, "r") as f:
+        for line in f.readlines():
+            test_samples.append(os.path.join(graphnet_root, line.strip()))
+
+    return test_samples
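
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the two new test_compiler_util helpers are intended to be driven. The argparse wiring here is hypothetical, and it assumes the module is importable as graph_net.test_compiler_util, mirroring the `from graph_net import path_utils` import added above.

    import argparse

    from graph_net import test_compiler_util

    parser = argparse.ArgumentParser()
    # Device naming follows the new assertion in main(): "cuda", "dcu", or "cpu".
    parser.add_argument("--device", default="cuda", choices=["cuda", "dcu", "cpu"])
    # Optional allow-list file with one sample path per line, relative to the GraphNet root.
    parser.add_argument("--allow-list", default=None)
    args = parser.parse_args()

    # is_gpu_device() covers DCU as well as CUDA, replacing the earlier
    # string comparison against "cuda" only.
    if test_compiler_util.is_gpu_device(args.device):
        print("GPU timing path will be used")

    # get_allow_samples() returns None when no allow-list is given, so callers
    # can fall back to testing every model path discovered under the model root.
    test_samples = test_compiler_util.get_allow_samples(args.allow_list)
    if test_samples is not None:
        print(f"restricting the run to {len(test_samples)} samples")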