Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 32 additions & 25 deletions graph_net/paddle/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def set_seed(random_seed):


def get_hardward_name(args):
if args.device == "cuda":
if test_compiler_util.is_gpu_device(args.device):
hardware = paddle.device.cuda.get_device_name(0)
elif args.device == "cpu":
hardware = platform.processor()
Expand Down Expand Up @@ -64,15 +64,15 @@ def get_synchronizer_func(args):
return paddle.device.synchronize


def get_model(args):
def get_model(model_path):
model_class = load_class_from_file(
f"{args.model_path}/model.py", class_name="GraphModule"
f"{model_path}/model.py", class_name="GraphModule"
)
return model_class()


def get_input_dict(args):
inputs_params = utils.load_converted_from_text(f"{args.model_path}")
def get_input_dict(model_path):
inputs_params = utils.load_converted_from_text(f"{model_path}")
params = inputs_params["weight_info"]
inputs = inputs_params["input_info"]

Expand All @@ -81,8 +81,8 @@ def get_input_dict(args):
return state_dict


def get_input_spec(args):
inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
def get_input_spec(model_path):
inputs_params_list = utils.load_converted_list_from_text(f"{model_path}")
input_spec = [None] * len(inputs_params_list)
for i, v in enumerate(inputs_params_list):
dtype = v["info"]["dtype"]
Expand All @@ -94,7 +94,7 @@ def get_input_spec(args):
def get_compiled_model(args, model):
if args.compiler == "nope":
return model
input_spec = get_input_spec(args)
input_spec = get_input_spec(args.model_path)
build_strategy = paddle.static.BuildStrategy()
compiled_model = paddle.jit.to_static(
model,
Expand All @@ -110,7 +110,7 @@ def get_compiled_model(args, model):
def get_static_model(args, model):
static_model = paddle.jit.to_static(
model,
input_spec=get_input_spec(args),
input_spec=get_input_spec(args.model_path),
full_graph=True,
backend=None,
)
Expand Down Expand Up @@ -138,7 +138,7 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
flush=True,
)

if "cuda" in args.device:
if test_compiler_util.is_gpu_device(args.device):
"""
Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings,
With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench
Expand Down Expand Up @@ -249,8 +249,8 @@ def transfer_to_float(origin_outputs):

def test_single_model(args):
synchronizer_func = get_synchronizer_func(args)
input_dict = get_input_dict(args)
model = get_model(args)
input_dict = get_input_dict(args.model_path)
model = get_model(args.model_path)
model.eval()

test_compiler_util.print_basic_config(
Expand All @@ -259,6 +259,7 @@ def test_single_model(args):

# Run on eager mode
eager_success = False
eager_time_stats = {}
try:
print("Run model in eager mode.", file=sys.stderr, flush=True)
static_model = get_static_model(args, model)
Expand All @@ -275,6 +276,7 @@ def test_single_model(args):

# Run on compiling mode
compiled_success = False
compiled_time_stats = {}
try:
print("Run model in compiled mode.", file=sys.stderr, flush=True)
compiled_model = get_compiled_model(args, model)
Expand All @@ -293,9 +295,9 @@ def test_single_model(args):
if eager_success and compiled_success:
check_outputs(args, expected_out, compiled_out)

test_compiler_util.print_times_and_speedup(
args, eager_time_stats, compiled_time_stats
)
test_compiler_util.print_times_and_speedup(
args, eager_time_stats, compiled_time_stats
)


def get_cmp_equal(expected_out, compiled_out):
Expand Down Expand Up @@ -366,20 +368,12 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):


def test_multi_models(args):
test_samples = None
if args.allow_list is not None:
assert os.path.isfile(args.allow_list)
graphnet_root = path_utils.get_graphnet_root()
print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
verified_samples = []
with open(args.verified_samples_list_path, "r") as f:
for line in f.readlines():
test_samples.append(os.path.join(graphnet_root, line.strip()))
test_samples = test_compiler_util.get_allow_samples(args.allow_list)

sample_idx = 0
failed_samples = []
for model_path in path_utils.get_recursively_model_path(args.model_path):
if verified_samples is None or os.path.abspath(model_path) in verified_samples:
if test_samples is None or os.path.abspath(model_path) in test_samples:
print(
f"[{sample_idx}] test_compiler, model_path: {model_path}",
file=sys.stderr,
Expand Down Expand Up @@ -415,11 +409,24 @@ def test_multi_models(args):
def main(args):
assert os.path.isdir(args.model_path)
assert args.compiler in {"cinn", "nope"}
assert args.device in ["cuda", "dcu", "cpu"]

initalize_seed = 123
set_seed(random_seed=initalize_seed)

if path_utils.is_single_model_dir(args.model_path):
if paddle.device.is_compiled_with_cuda():
device_id = int(paddle.device.get_device().split(":")[-1])
device_count = paddle.device.cuda.device_count()
gpu_util, mem_util = test_compiler_util.get_device_utilization(
device_id, device_count, get_synchronizer_func(args)
)
if gpu_util is not None and mem_util is not None:
print(
f"Device status: gpu_id {device_id}, gpu_util {gpu_util:.2f}%, mem_util {mem_util:.2f}%",
file=sys.stderr,
flush=True,
)
test_single_model(args)
else:
test_multi_models(args)
Expand Down
5 changes: 3 additions & 2 deletions graph_net/paddle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,10 @@ def replay_tensor(info):
else:
if mean is not None and std is not None:
tensor = paddle.empty(shape=shape, dtype=dtype)
paddle.nn.init.trunc_normal_(
tensor=tensor, mean=mean, std=std, a=min_val, b=max_val
initializer = paddle.nn.initializer.TruncatedNormal(
mean=mean, std=std, a=min_val, b=max_val
)
initializer(tensor)
return tensor.to(dtype).to(device)
else:
return (
Expand Down
153 changes: 139 additions & 14 deletions graph_net/test_compiler_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import sys
import json
import time
import subprocess
import shutil
import numpy as np
from dataclasses import dataclass
from contextlib import contextmanager

from graph_net import path_utils


@dataclass
class DurationBox:
Expand All @@ -23,6 +27,103 @@ def naive_timer(duration_box, synchronizer_func):
duration_box.value = (end - start) * 1000 # Store in milliseconds


def is_gpu_device(device):
return "cuda" in device or "dcu" in device


def get_device_utilization(device_id, device_count, synchronizer_func):
current_pid = os.getpid()

if shutil.which("nvidia-smi"):
try:
cuda_devices_str = os.getenv("CUDA_VISIBLE_DEVICES", "")
if cuda_devices_str != "":
cuda_devices = list(map(int, cuda_devices_str.split(",")))
else:
cuda_devices = list(range(device_count))
selected_gpu_id = cuda_devices[device_id]

print(
f"Check the status of GPU {selected_gpu_id} for 5 times.",
file=sys.stderr,
flush=True,
)
selected_gpu_uuid, max_gpu_util, max_mem_util = None, 0.0, 0.0
for i in range(5):
synchronizer_func()
time.sleep(1)

output = (
subprocess.check_output(
[
"nvidia-smi",
f"--query-gpu=index,gpu_uuid,utilization.gpu,memory.used,memory.total",
"--format=csv,noheader,nounits",
]
)
.decode()
.strip()
)
for line in output.split("\n"):
if line.strip():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些代码缩进层级太深,一般都意味着可维护性问题。下次可以多用列表解析(list comprehension)

(
gpu_id,
selected_gpu_uuid,
gpu_util,
used_mem,
mem_total,
) = line.split(", ")
if int(gpu_id) == selected_gpu_id:
break

gpu_util = float(gpu_util)
mem_util = float(used_mem) * 100 / float(mem_total)
print(
f"- gpu_id: {selected_gpu_id}, gpu_uuid: {selected_gpu_uuid}, gpu_util: {gpu_util:.2f}%, used_mem: {used_mem}, mem_total: {mem_total}",
file=sys.stderr,
flush=True,
)

max_gpu_util = gpu_util if gpu_util > max_gpu_util else max_gpu_util
max_mem_util = mem_util if mem_util > max_mem_util else max_mem_util

other_tasks = []
output = (
subprocess.check_output(
[
"nvidia-smi",
f"--query-compute-apps=gpu_uuid,pid,used_memory",
"--format=csv,noheader,nounits",
]
)
.decode()
.strip()
)
for line in output.split("\n"):
if line.strip():
gpu_uuid, pid, used_memory = line.split(", ")
if gpu_uuid == selected_gpu_uuid and int(pid) != current_pid:
other_tasks.append(line)
# Note: in docker container, the current_pid maybe different from that captured by nvidia-smi.
print(
f"Note: There are {len(other_tasks)} tasks running on GPU {selected_gpu_id} (current_pid:{current_pid}).",
file=sys.stderr,
flush=True,
)
for task in other_tasks:
gpu_uuid, pid, used_memory = task.split(", ")
print(
f"- gpu_uuid:{gpu_uuid}, pid:{pid}, used_memory:{used_memory}",
file=sys.stderr,
flush=True,
)
return max_gpu_util, max_mem_util
except subprocess.CalledProcessError:
pass

return None, None


def get_timing_stats(elapsed_times):
stats = {
"mean": float(f"{np.mean(elapsed_times):.6g}"),
Expand Down Expand Up @@ -75,24 +176,33 @@ def print_basic_config(args, hardware_name, compile_framework_version):
)


def print_running_status(args, eager_success, compiled_success):
def print_running_status(args, eager_success, compiled_success=None):
def convert_to_str(b):
return "success" if b else "failed"

print_with_log_prompt(
"[Result][status]",
f"eager:{convert_to_str(eager_success)} compiled:{convert_to_str(compiled_success)}",
args.log_prompt,
)
if compiled_success is not None:
print_with_log_prompt(
"[Result][status]",
f"eager:{convert_to_str(eager_success)} compiled:{convert_to_str(compiled_success)}",
args.log_prompt,
)
else:
print_with_log_prompt(
"[Result][status]",
f"eager:{convert_to_str(eager_success)}",
args.log_prompt,
)


def print_times_and_speedup(args, eager_stats, compiled_stats):
print_with_log_prompt(
"[Performance][eager]:", json.dumps(eager_stats), args.log_prompt
)
print_with_log_prompt(
"[Performance][compiled]:", json.dumps(compiled_stats), args.log_prompt
)
if not eager_stats:
print_with_log_prompt(
"[Performance][eager]:", json.dumps(eager_stats), args.log_prompt
)
if not compiled_stats:
print_with_log_prompt(
"[Performance][compiled]:", json.dumps(compiled_stats), args.log_prompt
)

e2e_speedup = 0
gpu_speedup = 0
Expand All @@ -103,7 +213,7 @@ def print_times_and_speedup(args, eager_stats, compiled_stats):
if eager_e2e_time_ms > 0 and compiled_e2e_time_ms > 0:
e2e_speedup = eager_e2e_time_ms / compiled_e2e_time_ms

if "cuda" in args.device:
if is_gpu_device(args.device):
eager_gpu_time_ms = eager_stats.get("gpu", {}).get("mean", 0)
compiled_gpu_time_ms = compiled_stats.get("gpu", {}).get("mean", 0)

Expand All @@ -113,7 +223,7 @@ def print_times_and_speedup(args, eager_stats, compiled_stats):
if e2e_speedup > 0:
print_with_log_prompt("[Speedup][e2e]:", f"{e2e_speedup:.5f}", args.log_prompt)

if "cuda" in args.device and gpu_speedup > 0:
if is_gpu_device(args.device) and gpu_speedup > 0:
print_with_log_prompt("[Speedup][gpu]:", f"{gpu_speedup:.5f}", args.log_prompt)


Expand Down Expand Up @@ -224,3 +334,18 @@ def check_allclose(
compiled_out=compiled_out,
**kwargs,
)


def get_allow_samples(allow_list):
if allow_list is None:
return None

assert os.path.isfile(allow_list), f"{allow_list} is not a regular file."
graphnet_root = path_utils.get_graphnet_root()
print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
test_samples = []
with open(allow_list, "r") as f:
for line in f.readlines():
test_samples.append(os.path.join(graphnet_root, line.strip()))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数怎么总是返回None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修复


return test_samples