diff --git a/test/profiler/test_device_spec.py b/test/profiler/test_device_spec.py new file mode 100644 index 0000000000..1ede428fe0 --- /dev/null +++ b/test/profiler/test_device_spec.py @@ -0,0 +1,70 @@ +import pytest + +cuda_driver = pytest.importorskip( + "triton.runtime.driver", reason="requires triton cuda driver module" +) +import itertools + +import torch +from utils import patch_device + +from torchao.profiler.device_spec import ( + _AVAILABLE_GPU_SPECS, + CUDADeviceSpec, + get_chip_name, +) + +# -------------------- Device Spec Tests ------------------- # +DEVICE_NAMES = ["h100 sxm", "a100", "nvidia geforce rtx 4090"] +DTYPES = [torch.float32, torch.bfloat16, torch.float16] +USE_TENSORCORES = [True, False] +DEVICE_CONFIGS = itertools.product(DEVICE_NAMES, DTYPES, USE_TENSORCORES) + + +@pytest.mark.parametrize( + "device_name, dtype, use_tensorcores", DEVICE_CONFIGS, ids=lambda x: str(x) +) +def test_device_spec(device_name, dtype, use_tensorcores): + with patch_device(device_name): + device_spec = CUDADeviceSpec(dtype=dtype, use_tensorcores=use_tensorcores) + if dtype == torch.float32 and use_tensorcores: + dtype = "tfloat32" + chip_name = get_chip_name(device_name) + expected_flops = _AVAILABLE_GPU_SPECS[chip_name][dtype] + assert device_spec.flops_per_s == expected_flops + assert device_spec.flops_by_dtype[dtype] == expected_flops + assert ( + device_spec.roofline_balancepoint == expected_flops / device_spec.bandwidth + ) + + with pytest.raises(AssertionError): + device_spec.flops_per_s = None + print(device_spec.roofline_balancepoint) + # Prevent setting attributes not in named fields to guard against user error + with pytest.raises(AttributeError): + device_spec.FLOPs = None + + +def test_empty_device_spec(): + device_name = "fake device" + with patch_device(device_name): + with pytest.raises(AssertionError): + _ = CUDADeviceSpec() + + # Ok to instantiate as long as fields are filled + _ = CUDADeviceSpec( + name=device_name, + flops_per_s=1.0, + bandwidth=1.0, + dtype=torch.float32, + use_tensorcores=True, + ) + device_name = DEVICE_NAMES[0] + + with patch_device(device_name): + # All critical fields will be auto-filled except for dtype (and vram, but vram is not used for downstream calcs atm) + _ = CUDADeviceSpec(dtype=torch.float32) + + # No dtype specified + with pytest.raises(AssertionError): + _ = CUDADeviceSpec() diff --git a/test/profiler/test_performance_counter.py b/test/profiler/test_performance_counter.py new file mode 100644 index 0000000000..2cd1a33581 --- /dev/null +++ b/test/profiler/test_performance_counter.py @@ -0,0 +1,530 @@ +import pytest + +# Skip if transformers is not installed +transformers = pytest.importorskip("transformers") +LlamaConfig = transformers.models.llama.modeling_llama.LlamaConfig +LlamaForCausalLM = transformers.models.llama.modeling_llama.LlamaForCausalLM + +import json +import tempfile +import time +import unittest +from contextlib import contextmanager +from dataclasses import asdict +from pathlib import Path +from typing import Union +from unittest.mock import patch + +import torch +from parameterized import parameterized_class +from utils import ( + PerfCounterManagerTestConfig, + PerfCounterResult, + PerfCounterTestConfig, + PerfStatsTestConfig, + attn_io_check, + ffn_io_check, + get_leaf_nodes, + get_test_name, + patch_device, + qkv_proj_io_check, +) + +from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec +from torchao.profiler.performance_counter import ( + CUDAPerformanceTimer, + PerformanceCounterMode, + 
PerformanceStats, + PerformanceTimer, + TransformerPerformanceCounter, +) +from torchao.utils import TORCH_VERSION_AFTER_2_5 + +# ------------------- PerformanceCounter Tests ------------------- # + +PERFCOUNTER_TEST_CONFIGS = [ + PerfCounterTestConfig( + name="3.5B", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=32 // 2, + hidden_size=4096 // 2, + intermediate_size=11008 // 2, + num_attention_heads=32 // 2, + vocab_size=32000 // 2, + ), + PerfCounterTestConfig( + name="1.25B", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=32 // 4, + hidden_size=4096 // 4, + intermediate_size=11008 // 4, + num_attention_heads=32 // 4, + vocab_size=32000 // 4, + ), + PerfCounterTestConfig( + name="tiny", + batch_size=1, + seqlen=128, + dtype=torch.float16, + num_hidden_layers=1, + hidden_size=4096 // 4, + intermediate_size=11008 // 4, + num_attention_heads=32 // 4, + vocab_size=32000 // 4, + ), +] + + +@unittest.skipIf( + not TORCH_VERSION_AFTER_2_5, "PerformanceCounter requires torch >= 2.5+." +) +@unittest.skipIf(not torch.cuda.is_available(), "PerformanceCounter requires CUDA") +@parameterized_class( + [asdict(cfg) for cfg in PERFCOUNTER_TEST_CONFIGS], class_name_func=get_test_name +) +class PerformanceCounterTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + model_cfg = LlamaConfig( + num_hidden_layers=cls.num_hidden_layers, + hidden_size=cls.hidden_size, + intermediate_size=cls.intermediate_size, + num_attention_heads=cls.num_attention_heads, + vocab_size=cls.vocab_size, + ) + + # Note we set some options manually since the model doesn't seem to be initialized correctly + # when these options are set in LlamaConfig + model_cfg._attn_implementation = "sdpa" + cls.model = model = LlamaForCausalLM(model_cfg).to(cls.dtype).to("cuda") + cls.model_config = model.config + cls.element_size = cls.dtype.itemsize + + input_ids = torch.randint( + 0, model.config.vocab_size, (cls.batch_size, cls.seqlen), device="cuda" + ) + with torch.no_grad(): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION + ): + with PerformanceCounterMode() as perf_counter: + _ = model(input_ids) + cls.perf_counter = perf_counter + cls.summary_flops = perf_counter.get_summary_flop_counts() + cls.summary_io = perf_counter.get_summary_io_counts() + cls.flops_by_op = perf_counter.get_flop_counts() + cls.io_by_op = perf_counter.get_io_counts() + + def test_qkv_proj(self): + batch_size, seqlen = self.batch_size, self.seqlen + element_size = self.element_size + + assert len(self.summary_flops) == len(self.summary_io) + assert self.summary_flops.keys() == self.summary_io.keys() + + # Attn Projections + for k in ["q_proj", "k_proj", "v_proj"]: + # Flops check + proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) + assert len(proj_keys) == self.model.config.num_hidden_layers + expected_flops = ( + 2 + * batch_size + * seqlen + * self.model_config.hidden_size + * self.model_config.hidden_size + ) + assert expected_flops == self.summary_flops[proj_keys[0]] + + # io check + expected_size = qkv_proj_io_check( + self.model_config, batch_size, seqlen, element_size + ) + assert expected_size == self.summary_io[proj_keys[0]] + + def test_attn(self): + batch_size, seqlen = self.batch_size, self.seqlen + element_size = self.element_size + model_config = self.model.config + + attention_keys = get_leaf_nodes(self.summary_flops.keys(), "self_attn") + for k in attention_keys: + flops = self.flops_by_op[k] + io_movement = self.io_by_op[k] + for op, 
count in flops.items(): + if "attention" in op.__name__: + expected_flops = ( + 2 * 2 * batch_size * seqlen * seqlen * model_config.hidden_size + ) + assert expected_flops == count + for op, count in io_movement.items(): + if "attention" in op.__name__: + # Check approx equal due to other small artifacts returned by sdpa.mem_efficient_attention + # See #https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L867 + # Check within 100 bytes + expected_size = attn_io_check( + model_config, batch_size, seqlen, element_size + ) + assert abs(expected_size - count) < 100 + + def test_ffn(self): + batch_size, seqlen = self.batch_size, self.seqlen + element_size = self.element_size + + for k in ["up_proj", "gate_proj", "down_proj"]: + proj_keys = get_leaf_nodes(self.summary_flops.keys(), k) + assert len(proj_keys) == self.model.config.num_hidden_layers + expected_flops = ( + 2 + * batch_size + * seqlen + * self.model_config.hidden_size + * self.model_config.intermediate_size + ) + assert expected_flops == self.summary_flops[proj_keys[0]] + + # io check + expected_size = ffn_io_check( + self.model_config, batch_size, seqlen, element_size, k + ) + assert expected_size == self.summary_io[proj_keys[0]] + + +# ------------------- PerformanceStats Tests ------------------- # + +PERFSTATS_TEST_CONFIGS = [ + PerfStatsTestConfig( + label="with_device", + num_tokens=128, + latency=0.1, + total_flops=123e9, + total_io=123e6, + flops_summary={"a": 234e12, "b": 345e9}, + io_summary={"a": 1, "b": 2}, + flop_counts={"a": 234e12, "b": 345e9}, + io_counts={"a": 1, "b": 2}, + device_bandwidth=1e9, + device_flops_per_s=23e9, + ), + PerfStatsTestConfig( + label="no_device", + num_tokens=128, + latency=0.1, + total_flops=123e9, + total_io=123e6, + flops_summary={"a": 234e12, "b": 345e9}, + io_summary={"a": 1, "b": 2}, + flop_counts={"a": 234e12, "b": 345e9}, + io_counts={"a": 1, "b": 2}, + device_bandwidth=None, + device_flops_per_s=None, + ), +] + + +@pytest.mark.parametrize("cfg", PERFSTATS_TEST_CONFIGS, ids=lambda cfg: cfg.label) +def test_performance_stats(cfg: PerfStatsTestConfig): + stats = PerformanceStats(**asdict(cfg)) + num_tokens = cfg.num_tokens + latency = cfg.latency + total_flops = cfg.total_flops + total_io = cfg.total_io + device_bandwidth = cfg.device_bandwidth + device_flops_per_s = cfg.device_flops_per_s + + # Test derived metrics + assert stats.token_throughput == num_tokens / latency + assert stats.achieved_bandwidth == total_io / latency + assert stats.achieved_flops_per_s == total_flops / latency + if device_bandwidth is not None: + assert ( + stats.bandwidth_utilization == stats.achieved_bandwidth / device_bandwidth + ) + assert stats.theoretical_io_latency == total_io / device_bandwidth + else: + assert stats.bandwidth_utilization is None + assert stats.theoretical_io_latency is None + if device_flops_per_s is not None: + assert ( + stats.flops_utilization == stats.achieved_flops_per_s / device_flops_per_s + ) + assert stats.theoretical_compute_latency == total_flops / device_flops_per_s + else: + assert stats.flops_utilization is None + assert stats.theoretical_compute_latency is None + + # Test str - stats should be formatted to closest power of 10 ** 3 with 2 decimal places of precision + stats_str = str(stats) + + # Base Stats + expected_io_str = ".12 GB" + expected_flops_str = ".12 TFLOPs" + assert expected_io_str in stats_str + assert expected_flops_str in stats_str + + # Derived Stats + expected_io_throughput_str = "1.23 GB/s" + 
expected_flops_throughput_str = "1.23 TFLOPs/s" + assert expected_io_throughput_str in stats_str + assert expected_flops_throughput_str in stats_str + + # Utilization Stats + if device_bandwidth is not None: + expected_bandwidth_utilization_str = ( + f"{stats.achieved_bandwidth / device_bandwidth:.4f}" + ) + expected_io_latency_str = f"{stats.theoretical_io_latency:.2f} s" + assert expected_bandwidth_utilization_str in stats_str + assert expected_io_latency_str in stats_str + + if device_flops_per_s is not None: + expected_flops_utilization_str = ( + f"{stats.achieved_flops_per_s / device_flops_per_s:.4f}" + ) + expected_compute_latency_str = f"{stats.theoretical_compute_latency:.2f} s" + assert expected_flops_utilization_str in stats_str + assert expected_compute_latency_str in stats_str + + +# ------------------- TransformerPerformanceCounter Tests ------------------- # + +PERFCOUNTERMANAGER_TEST_CONFIGS = [ + PerfCounterManagerTestConfig( + "no_device", (1, 1024, 4096, 4096), PerformanceTimer, torch.bfloat16, (None, 0) + ), + PerfCounterManagerTestConfig( + "a100", + (1, 1024, 4096, 4096), + CUDAPerformanceTimer, + torch.bfloat16, + ("A100", 2e12), + ), +] + + +@unittest.skipIf( + not TORCH_VERSION_AFTER_2_5, "TransformerPerformanceCounter requires torch >= 2.5+." +) +@unittest.skipIf( + not torch.cuda.is_available(), "TransformerPerformanceCounter requires CUDA" +) +@parameterized_class( + [asdict(cfg) for cfg in PERFCOUNTERMANAGER_TEST_CONFIGS], + class_name_func=get_test_name, +) +class TestTransformerPerformanceCounter(unittest.TestCase): + @classmethod + def setUpClass(cls): + shape, timer_cls, dtype = cls.shape, cls.timer_cls, cls.dtype + batch_size, query_len, in_features, out_features = shape + num_tokens = batch_size * query_len + element_size = dtype.itemsize + a = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + b = torch.randn(in_features, out_features, dtype=dtype, device="cuda") + # Set up device spec + device_name, bandwidth = cls.device_spec + if device_name is not None: + with patch_device(device_name): + device_spec = CUDADeviceSpec(dtype=torch.bfloat16, bandwidth=bandwidth) + + else: + device_spec = None + + # Stateful class level objects, which will be used in individual tests + cls.cm = cm = TransformerPerformanceCounter( + timer_cls=timer_cls, device_spec=device_spec + ) + cls.FLOAT_TOL = 1e-5 + cls.expected = expected = {} + + # Start count for a + start = time.perf_counter() + with cm.count("a", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + + latency = end - start + expected_flops = 2 * num_tokens * in_features * out_features + expected_io = ( + num_tokens * in_features + + in_features * out_features + + num_tokens * out_features + ) * element_size + + expected["a"] = PerfCounterResult( + name="a", + latency=latency, + flops=expected_flops, + io=expected_io, + total_flops=expected_flops, + total_io=expected_io, + ) + + # Start count for b + start = time.perf_counter() + with cm.count("b", num_tokens=num_tokens): + _ = torch.matmul(a, b) + end = time.perf_counter() + latency = end - start + + expected["b"] = PerfCounterResult( + name="b", + latency=latency, + flops=expected_flops, + io=expected_io, + total_flops=cm.total_flops, + total_io=cm.total_io, + ) + + def test_perf_stats_a(self): + cm: TransformerPerformanceCounter = self.cm + expected = self.expected["a"] + + counts = cm.get_counts() + assert "a" in counts + + # Check captured performance stats + psa: PerformanceStats = counts["a"] + # Raw metrics + # 
Latency won't be exact since timing external to the profiler + assert abs(psa.latency - expected.latency) < 1e-1 # +/- 100ms + assert psa.total_flops == expected.flops + assert psa.total_io == expected.io + + # Derived metrics + assert psa.token_throughput == psa.num_tokens / psa.latency + assert psa.achieved_flops_per_s == psa.total_flops / psa.latency + assert psa.achieved_bandwidth == psa.total_io / psa.latency + + def test_perf_stats_b(self): + cm: TransformerPerformanceCounter = self.cm + assert "a" in cm.counts + assert "b" in cm.counts + psa = cm.counts["a"] + psb = cm.counts["b"] + expected = self.expected["b"] + assert abs(psb.latency - expected.latency) < 1e-1 # +/- 100ms + assert psb.total_flops == expected.flops + assert psb.total_io == expected.io + + # check that **total** flops and io after matmul `b` has run accounts for both matmuls + # also check that these global properties are updated correctly in the manager object + assert ( + expected.total_flops == psa.total_flops + psb.total_flops == cm.total_flops + ) + assert expected.total_io == psa.total_io + psb.total_io == cm.total_io + assert cm.total_time == psa.latency + psb.latency + + def test_stats_summary(self): + cm: TransformerPerformanceCounter = self.cm + FLOAT_TOL = self.FLOAT_TOL + psa = cm.counts["a"] + psb = cm.counts["b"] + summary: PerformanceStats = cm.stats_summary + + # Raw stats + assert summary.num_tokens == psa.num_tokens + psb.num_tokens + assert summary.total_io == psa.total_io + psb.total_io + assert summary.total_flops == psa.total_flops + psb.total_flops + assert summary.latency == psa.latency + psb.latency + + # Derived stats + expected_token_throughput = (psa.num_tokens + psb.num_tokens) / ( + psa.latency + psb.latency + ) + expected_io_throughput = (psa.total_io + psb.total_io) / ( + psa.latency + psb.latency + ) + expected_flops_throughput = (psa.total_flops + psb.total_flops) / ( + psa.latency + psb.latency + ) + assert abs(summary.token_throughput - expected_token_throughput) < FLOAT_TOL + assert abs(summary.achieved_bandwidth - expected_io_throughput) < FLOAT_TOL + assert abs(summary.achieved_flops_per_s - expected_flops_throughput) < FLOAT_TOL + + device_spec = cm.device_spec + if device_spec is not None: + expected_bandwidth_utilization = ( + expected_io_throughput / device_spec.bandwidth + ) + expected_flops_utilization = ( + expected_flops_throughput / device_spec.flops_per_s + ) + assert ( + abs(summary.bandwidth_utilization - expected_bandwidth_utilization) + < FLOAT_TOL + ) + assert ( + abs(summary.flops_utilization - expected_flops_utilization) < FLOAT_TOL + ) + else: + assert summary.bandwidth_utilization is None + assert summary.flops_utilization is None + + def test_json(self): + cm: TransformerPerformanceCounter = self.cm + psa: PerformanceStats = cm.counts["a"] + psb: PerformanceStats = cm.counts["b"] + device_spec: Union[DeviceSpec, None] = cm.device_spec + + with tempfile.TemporaryDirectory() as tmp_dir: + json_path = Path(tmp_dir) / "test.json" + cm.to_json(json_path) + + with open(json_path, "r") as f: + perf_dict = json.load(f) + + assert "a" in perf_dict + assert "b" in perf_dict + + # Test basic stats are recorded properly + assert perf_dict["a"]["num_tokens"] == psa.num_tokens + assert perf_dict["a"]["total_io"] == psa.total_io + assert perf_dict["a"]["total_flops"] == psa.total_flops + assert perf_dict["a"]["latency"] == psa.latency + + assert perf_dict["b"]["num_tokens"] == psb.num_tokens + assert perf_dict["b"]["total_io"] == psb.total_io + assert 
perf_dict["b"]["total_flops"] == psb.total_flops + assert perf_dict["b"]["latency"] == psb.latency + + # Test derived properties are present + perf_dict["a"]["achieved_flops_per_s"] == psa.achieved_flops_per_s + perf_dict["a"]["achieved_bandwidth"] == psa.achieved_bandwidth + perf_dict["b"]["achieved_flops_per_s"] == psb.achieved_flops_per_s + perf_dict["b"]["achieved_bandwidth"] == psb.achieved_bandwidth + + if device_spec is not None: + assert perf_dict["a"]["device_flops_per_s"] == device_spec.flops_per_s + assert perf_dict["a"]["device_bandwidth"] == device_spec.bandwidth + assert ( + perf_dict["a"]["theoretical_io_latency"] + == psa.theoretical_io_latency + ) + assert ( + perf_dict["a"]["theoretical_compute_latency"] + == psa.theoretical_compute_latency + ) + assert ( + perf_dict["a"]["bandwidth_utilization"] == psa.bandwidth_utilization + ) + assert perf_dict["a"]["flops_utilization"] == psa.flops_utilization + + assert perf_dict["b"]["device_flops_per_s"] == device_spec.flops_per_s + assert perf_dict["b"]["device_bandwidth"] == device_spec.bandwidth + assert ( + perf_dict["b"]["theoretical_io_latency"] + == psb.theoretical_io_latency + ) + assert ( + perf_dict["b"]["theoretical_compute_latency"] + == psb.theoretical_compute_latency + ) + assert ( + perf_dict["b"]["bandwidth_utilization"] == psb.bandwidth_utilization + ) + assert perf_dict["b"]["flops_utilization"] == psb.flops_utilization diff --git a/test/profiler/utils.py b/test/profiler/utils.py new file mode 100644 index 0000000000..7b2b999809 --- /dev/null +++ b/test/profiler/utils.py @@ -0,0 +1,103 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional +from unittest.mock import patch + +import torch + +from torchao.profiler import PerformanceTimer + + +@contextmanager +def patch_device(device_name): + with patch("torch.cuda.get_device_name", return_value=device_name): + yield + + +@dataclass(frozen=True) +class PerfCounterTestConfig: + name: str + batch_size: int + seqlen: int + dtype: torch.dtype + num_hidden_layers: int + hidden_size: int + intermediate_size: int + num_attention_heads: int + vocab_size: int + + +def get_leaf_nodes(count_keys, module_name): + return [k for k in count_keys if k.endswith(module_name)] + + +def qkv_proj_io_check(model_config, batch_size, seqlen, element_size): + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = model_config.hidden_size * model_config.hidden_size * element_size + output_size = batch_size * seqlen * model_config.hidden_size * element_size + return input_size + weight_size + output_size + + +def attn_io_check(model_config, batch_size, seqlen, element_size): + # queries, keys, values -> factor of 3 + input_size = (batch_size * seqlen * model_config.hidden_size * 3) * element_size + output_size = (batch_size * seqlen * model_config.hidden_size) * element_size + return input_size + output_size + + +def ffn_io_check(model_config, batch_size, seqlen, element_size, module_name): + assert module_name in ["up_proj", "gate_proj", "down_proj"] + + if module_name == "down_proj": + input_size = batch_size * seqlen * model_config.intermediate_size * element_size + else: + input_size = batch_size * seqlen * model_config.hidden_size * element_size + weight_size = ( + model_config.hidden_size * model_config.intermediate_size * element_size + ) + if module_name == "down_proj": + output_size = batch_size * seqlen * model_config.hidden_size * element_size + else: + output_size = ( + batch_size * seqlen * 
model_config.intermediate_size * element_size + ) + + return input_size + weight_size + output_size + + +@dataclass(frozen=True) +class PerfStatsTestConfig: + label: str + num_tokens: int + latency: float + total_flops: float + total_io: float + flops_summary: dict + io_summary: dict + flop_counts: dict + io_counts: dict + device_bandwidth: Optional[float] = None + device_flops_per_s: Optional[float] = None + + +def get_test_name(cls, num, params_dict): + return f"{cls.__name__}_{num}_{params_dict['name']}" + + +@dataclass(frozen=True) +class PerfCounterResult: + name: str + latency: float + flops: float + io: float + total_flops: float + total_io: float + + +@dataclass +class PerfCounterManagerTestConfig: + name: str + shape: tuple[int] + timer_cls: PerformanceTimer + dtype: torch.dtype + device_spec: tuple[Optional[str], int] diff --git a/torchao/_models/llama/perf_profile.py b/torchao/_models/llama/perf_profile.py new file mode 100644 index 0000000000..1a0d4e36c0 --- /dev/null +++ b/torchao/_models/llama/perf_profile.py @@ -0,0 +1,442 @@ +""" + +## Performance Profiling Example + +An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. +- Outputs from gpt-fast are prefixed with GPT-Fast +- Outputs from `torchao.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. + +## Usage +```python +python perf_profile.py --prompt "Hello my name is" --checkpoint_path path/to/model.pth --num_samples 1 --max_new_tokens 2 --save_path performance_stats.json +``` +where `checkpoint_path` is the checkpoint path of the converted model weights per `gpt-fast` and `save_path` specifies where to save performance stats. + + +Running the above command for `llama2-7b` should print the following, with accumulated stats saved to `performance_stats.json` + +``` +Loading model ... 
+Time to load model: 20.14 seconds + +============================== + +Using DeviceSpec(device_type=cuda, name=NVIDIA GeForce RTX 3090, dtype=torch.bfloat16, bandwidth=936.1GB/s, flops=35.6TFLOPs, vram=25.4GB) +Model Config: ModelArgs(block_size=2048, vocab_size=32000, n_layer=32, n_head=32, dim=4096, intermediate_size=11008, n_local_heads=32, head_dim=128, rope_base=10000, norm_eps=1e-05) +Active params, Total Params: 6607343616, 6738415616 + +============================== + +TransformerPerfCounter Metrics +PREFILL_SEQLEN-6: + Latency = 1.26 s + Tokens + Total: 6 tokens + Throughput: 5 tokens/s + IO + Total: 13.25 GB + Throughput: 10.54 GB/s + Theoretical Latency: 14.15 ms + FLOPs + Total: 79.31 GFLOPs + Throughput: 63.06 GFLOPs/s + Theoretical Latency: 2.23 ms + Utilization + Bandwidth: 0.0113 % + FLOPs: 0.0018 % + +============================== + +TransformerPerfCounter Metrics +DECODE_CTX-6_NUM_TOKS-1: + Latency = 0.16 s + Tokens + Total: 1 tokens + Throughput: 6 tokens/s + IO + Total: 13.22 GB + Throughput: 83.27 GB/s + Theoretical Latency: 14.13 ms + FLOPs + Total: 13.22 GFLOPs + Throughput: 83.24 GFLOPs/s + Theoretical Latency: 0.37 ms + Utilization + Bandwidth: 0.0890 % + FLOPs: 0.0023 % + +============================== + +Generated text for sample 0: Hello, my name is [Name + +GPTFast Sample Metrics + Time for inference 1: 6 prompt tokens 2 tokens generated, 1.57 sec total, 1.28 tokens/sec + Bandwidth achieved: 17.22 GB/s + +============================== + +GPTFast Aggregate Stats + Average tokens/sec: 1.28 + Memory used: 13.51 GB + +============================== + +TransformerPerfCounter +Performance Summary: + Latency = 1.42 s + Tokens + Total: 7 tokens + Throughput: 5 tokens/s + IO + Total: 26.47 GB + Throughput: 18.69 GB/s + Theoretical Latency: 28.28 ms + FLOPs + Total: 92.53 GFLOPs + Throughput: 65.33 GFLOPs/s + Theoretical Latency: 2.60 ms + Utilization + Bandwidth: 0.0200 % + FLOPs: 0.0018 % + +Saving performance results to performance_stats.json +``` + +**Notes** +- The discrepancy between `gpt-fast` token throughput and that of `TransformerPerformanceCounter` is due to the fact that gpt-fast` only counts generated tokens (no prefill) +-- so even though the `prefill` phase technically generates `len(prompt) + 1` tokens, it counts the number of tokens generated during this phase as `1`, +whereas `TransformerPerformanceCounter` includes all `prefill` tokens in the total token count. 
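
For post-run analysis, the JSON written to `save_path` can be loaded directly. A minimal sketch (keys follow the fields serialized by `TransformerPerformanceCounter.to_json`, e.g. `num_tokens`, `latency`, `achieved_bandwidth`; the label below assumes the 6-token prompt from the example above):

```python
import json

with open("performance_stats.json") as f:
    stats = json.load(f)

# Each recorded context is keyed by its label, e.g. "PREFILL_SEQLEN-6", "DECODE_CTX-6_NUM_TOKS-1"
prefill = stats["PREFILL_SEQLEN-6"]
print(prefill["num_tokens"], prefill["latency"], prefill["achieved_bandwidth"])
```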
+""" + +import textwrap +import time +from pathlib import Path +from typing import Optional, Tuple, Union + +import torch +from torch.nn.attention import SDPBackend + +from torchao._models.llama.model import Transformer +from torchao._models.llama.tokenizer import get_tokenizer +from torchao.profiler import ( + CUDADeviceSpec, + TransformerPerformanceCounter, + total_model_params, +) + +DEVICE_SPEC: CUDADeviceSpec +PERF_COUNTER: TransformerPerformanceCounter +PERF_COUNTER_PREFIX = "TransformerPerfCounter" +GPT_FAST_PREFIX = "GPTFast" +DELIMITER = "\n" + "=" * 30 + "\n" + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet supported") + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + seqlen = input_pos.shape[-1] + num_tokens = input_pos.numel() + assert num_tokens == seqlen + + step_name = f"prefill_seqlen-{seqlen}".upper() + with PERF_COUNTER.count(step_name, num_tokens=num_tokens): + logits = model(x, input_pos) + next_token = sample(logits, **sampling_kwargs)[0] + print(DELIMITER) + stats_str = PERF_COUNTER.print_summary(labels=[step_name], show=False) + print(f"{PERF_COUNTER_PREFIX} Metrics\n{stats_str}") + + return next_token + + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + context_len = input_pos[-1].item() + num_tokens = input_pos.numel() + assert input_pos.shape[-1] == 1 + assert num_tokens == 1 + + step_name = f"decode_ctx-{context_len}_num_toks-{num_tokens}".upper() + with PERF_COUNTER.count(step_name, num_tokens=num_tokens): + logits = model(x, input_pos) + next_token = sample(logits, **sampling_kwargs) + print(DELIMITER) + stats_str = PERF_COUNTER.print_summary(labels=[step_name], show=False) + print(f"{PERF_COUNTER_PREFIX} Metrics\n{stats_str}") + + return next_token + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.nn.attention.sdpa_kernel( + backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH] + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = 
next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + callback=lambda x: x, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill( + model, prompt.view(1, -1), input_pos, **sampling_kwargs + ).clone() + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(1, -1), + input_pos, + max_new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq[T + 1 :] = torch.cat(generated_tokens) + + return seq + + +def encode_tokens(tokenizer, string, bos=True, device="cuda"): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, device, precision): + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +def main( + prompt: str, + num_samples: int, + max_new_tokens: int, + top_k: int, + temperature: float, + checkpoint_path: Union[Path, str], + save_path: Union[Path, str], + device: str = "cuda", + precision: torch.dtype = torch.bfloat16, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + print(f"{GPT_FAST_PREFIX}") + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision) + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + global DEVICE_SPEC + global PERF_COUNTER + + DEVICE_SPEC = CUDADeviceSpec(dtype=precision) + PERF_COUNTER = TransformerPerformanceCounter(depth=3, device_spec=DEVICE_SPEC) + print(DELIMITER) + print(f"{PERF_COUNTER_PREFIX}") + print(f"Using {DEVICE_SPEC}") + print(f"Model Config: {model.config}") + + num_active_params = total_model_params(model, exclude_embeddings=True) + num_params = total_model_params(model, exclude_embeddings=False) + model_size = num_params * precision.itemsize + print(f"Active params, Total Params: {num_active_params}, {num_params}") + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + 
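    # Seed the RNG so token sampling is reproducible across runs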
torch.manual_seed(1234) + + aggregate_metrics = { + "tokens_per_sec": [], + } + + start = 0 + + for i in range(start, num_samples): + t0 = time.perf_counter() + + y = generate( + model, + encoded, + max_new_tokens, + temperature=temperature, + top_k=top_k, + ) + + t = time.perf_counter() - t0 + txt = tokenizer.decode(y.tolist()) + print(DELIMITER) + print(f"{GPT_FAST_PREFIX}") + print(f"Generated text for sample {i}: {txt}\n") + + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + sample_metrics = textwrap.dedent(f"""\ + {GPT_FAST_PREFIX} Sample Metrics + Time for inference {i+1}: {prompt_length} prompt tokens {tokens_generated} tokens generated, {t:.02f} sec total, {tokens_sec:.02f} tokens/sec + Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s""") + print( + textwrap.indent( + sample_metrics, + prefix=" ", + predicate=lambda line: not line.startswith(GPT_FAST_PREFIX), + ) + ) + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + + # First print aggregate stats from original gpt-fast script + print(DELIMITER) + gpt_stats = textwrap.dedent(f"""\ + {GPT_FAST_PREFIX} Aggregate Stats + Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} + Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB""") + + print( + textwrap.indent( + gpt_stats, + prefix=" ", + predicate=lambda line: not line.startswith(GPT_FAST_PREFIX), + ) + ) + + # Print performance summary from TransformerPerformanceCounter + print(DELIMITER) + total_stats_str = PERF_COUNTER.print_summary(show=False) + print(f"{PERF_COUNTER_PREFIX}\n{total_stats_str}") + print(f"\nSaving performance results to {save_path}") + PERF_COUNTER.to_json(save_path) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="TransformerPerformanceCounter Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=2, help="Maximum number of new tokens." + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." 
+ ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("./checkpoints/7B/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "--save_path", + type=Path, + default=Path("performance_stats.json"), + help="Path to save performance stats.", + ) + args = parser.parse_args() + main(**vars(args)) diff --git a/torchao/profiler/__init__.py b/torchao/profiler/__init__.py new file mode 100644 index 0000000000..e748438e87 --- /dev/null +++ b/torchao/profiler/__init__.py @@ -0,0 +1,23 @@ + +# Re-exports +from .device_spec import CUDADeviceSpec, DeviceSpec +from .performance_counter import ( + CUDAPerformanceTimer, + PerformanceCounterMode, + PerformanceStats, + PerformanceTimer, + TransformerPerformanceCounter, +) +from .utils import total_model_params + +__all__ = [ + "CUDAPerformanceTimer", + "PerformanceCounterMode", + "PerformanceStats", + "PerformanceTimer", + "TransformerPerformanceCounter", + "CUDADeviceSpec", + "DeviceSpec", + "total_model_params", +] + diff --git a/torchao/profiler/device_spec.py b/torchao/profiler/device_spec.py new file mode 100644 index 0000000000..040367583f --- /dev/null +++ b/torchao/profiler/device_spec.py @@ -0,0 +1,421 @@ +from dataclasses import dataclass, field, fields +from typing import Dict, Optional, Union + +import torch + +"""This module contains the device specs for theoretical peak performance calculations. + +- Contains a list of available chips and their corresponding theoretical peak FLOPs performance for various torch.dtypes. +- Exposes a DeviceSpec interface and a concrete CUDADeviceSpec implementation for CUDA gpus. Extendable to other device types. +- Where possible, the CUDADeviceSpec auto-populates its fields by utilizing `torch.cuda` API and `triton.runtime.driver`. 
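A minimal usage sketch (assumes a CUDA device whose name maps to an entry in `_AVAILABLE_GPU_SPECS`):

    >>> spec = CUDADeviceSpec(device=0, dtype=torch.bfloat16)
    >>> spec.flops_per_s            # peak FLOP/s looked up by chip name and dtype
    >>> spec.bandwidth              # DRAM bandwidth in bytes/s, via triton's get_dram_gbps
    >>> spec.roofline_balancepoint  # arithmetic intensity (FLOP/byte) at the roofline ridgepoint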
+ +""" +# Copied from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py +_AVAILABLE_GPU_SPECS: Dict[str, Dict[Union[str, torch.dtype], float]] = { + # Hopper + # source: https://resources.nvidia.com/en-us-tensor-core + "h100 nvl": { + torch.float64: 67e12, + torch.float32: 133.8e12, + "tfloat32": 989.4e12, + torch.bfloat16: 1978.8e12, + torch.float16: 1978.8e12, + torch.int8: 3957.8e12, + }, + "h100 sxm": { + torch.float64: 33.5e12, + torch.float32: 66.9e12, + "tfloat32": 494.7e12, + torch.bfloat16: 989.4e12, + torch.float16: 989.4e12, + torch.int8: 1978.9e12, + }, + "h100 pcie": { + torch.float64: 25.6e12, + torch.float32: 51.2e12, + "tfloat32": 378e12, + torch.bfloat16: 756e12, + torch.float16: 756e12, + torch.int8: 1513e12, + }, + # Ada + # source: https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf + "rtx 4090": { + torch.float32: 82.6e12, + "tfloat32": 82.6e12, + torch.bfloat16: 82.6e12, + torch.float16: 82.6e12, + torch.int8: 660.6e12, + "int4": 1321.2e12, + }, + "rtx 4080": { + torch.float32: 48.7e12, + "tfloat32": 48.7e12, + torch.bfloat16: 48.7e12, + torch.float16: 48.7e12, + torch.int8: 389.9e12, + "int4": 779.8e12, + }, + "l4": { + torch.float32: 30.3e12, + "tfloat32": 60e12, + torch.bfloat16: 121e12, + torch.float16: 121e12, + torch.int8: 242e12, + "int4": 484e12, + }, + "l40": { + torch.float32: 90.5e12, + "tfloat32": 90.5e12, + torch.bfloat16: 181e12, + torch.float16: 181e12, + torch.int8: 362e12, + "int4": 724e12, + }, + # Ampere + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + # sxm and pcie have same flop counts + "a100": { + torch.float64: 9.7e12, + torch.float32: 19.5e12, + "tfloat32": 156e12, + torch.bfloat16: 312e12, + torch.float16: 312e12, + torch.int8: 624e12, + }, + "a6000": { + torch.float32: 38.7e12, + "tfloat32": 77.4e12, + torch.bfloat16: 38.7e12, + torch.float16: 38.7e12, + torch.int8: 309.7e12, + "int4": 619.3e12, + }, + "a40": { + torch.float32: 37.4e12, + "tfloat32": 74.8e12, + torch.bfloat16: 37.4e12, + torch.float16: 37.4e12, + torch.int8: 299.3e12, + "int4": 598.7e12, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf + "a10g": { + torch.float32: 31.2e12, + "tfloat32": 62.5e12, + torch.bfloat16: 125e12, + torch.float16: 125e12, + torch.int8: 250e12, + "int4": 500e12, + }, + "rtx 3090 ti": { + torch.float32: 40e12, + "tfloat32": 40e12, + torch.bfloat16: 40e12, + torch.float16: 40e12, + torch.int8: 320e12, + "int4": 640e12, + }, + "rtx 3090": { + torch.float32: 35.6e12, + "tfloat32": 35.6e12, + torch.bfloat16: 35.6e12, + torch.float16: 35.6e12, + torch.int8: 284e12, + "int4": 568e12, + }, + "rtx 3080 ti": { + torch.float32: 34.1e12, + "tfloat32": 34.1e12, + torch.bfloat16: 34.1e12, + torch.float16: 34.1e12, + torch.int8: 272.8e12, + "int4": 546.6e12, + }, + "rtx 3080": { + torch.float32: 29.8e12, + "tfloat32": 29.8e12, + torch.bfloat16: 29.8e12, + torch.float16: 29.8e12, + torch.int8: 238e12, + "int4": 476e12, + }, + "rtx 3070": { + torch.float32: 20.3e12, + "tfloat32": 20.3e12, + torch.bfloat16: 20.3e12, + torch.float16: 20.3e12, + torch.int8: 162.6e12, + "int4": 325.2e12, + }, + # Turing + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf + # sxm and pcie have same flop counts + "t4": { + torch.float32: 8.1e12, + 
torch.float16: 65e12, + torch.int8: 130e12, + "int4": 260e12, + }, + # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf + "quadro rtx 5000": { + torch.float32: 11.2e12, + torch.float16: 89.2e12, + }, + "rtx 2080 super": { + torch.float32: 11.2e12, + torch.float16: 22.3e12, + torch.int8: 178.4e12, + "int4": 356.8e12, + }, + "rtx 2080 ti": { + torch.float32: 14.2e12, + torch.float16: 28.5e12, + torch.int8: 227.7e12, + "int4": 455.4e12, + }, + "rtx 2080": { + torch.float32: 10.6e12, + torch.float16: 21.2e12, + torch.int8: 169.6e12, + "int4": 339.1e12, + }, + # https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.pdf + "rtx 2070 super": { + torch.float32: 9.1e12, + torch.float16: 18.1e12, + torch.int8: 145e12, + "int4": 290e12, + }, + "titan rtx": { + torch.float32: 16.3e12, + torch.float16: 32.6e12, + torch.int8: 261e12, + "int4": 522e12, + }, + # Volta + # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + "v100 sxm": { + torch.float64: 7.8e12, + torch.float32: 15.7e12, + torch.float16: 125e12, + }, + "v100 pcie": { + torch.float64: 7e12, + torch.float32: 14e12, + torch.float16: 112e12, + }, + "v100s pcie": { + torch.float64: 8.2e12, + torch.float32: 16.4e12, + torch.float16: 130e12, + }, +} + + +# Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/utilities/throughput.py +def get_chip_name(device: int = 0) -> str: + device_name = torch.cuda.get_device_name(device) + chip = device_name.lower() + + if "h100" in chip: + if "hbm3" in chip: + chip = "h100 sxm" + elif "nvl" in chip: + chip = "h100 nvl" + elif "pcie" in chip or "hbm2e" in chip: + chip = "h100 pcie" + elif "l4" in chip: + chip = "l40" if "tesla" in chip else "l4" + elif "geforce rtx" in chip: + number = chip.split(" ")[3] + extra = "" + if "super" in chip: + extra = " super" + elif "ti" in chip: + extra = " ti" + chip = f"rtx {number}{extra}" + elif "a6000" in chip: + chip = "a6000" + elif "a100" in chip: + chip = "a100" + elif "a40" in chip: + chip = "a40" + elif "a10g" in chip: + chip = "a10g" + elif "t4" in chip: + chip = "t4" + elif "quadro rtx 5000" in chip: + chip = "quadro rtx 5000" + elif "titan rtx" in chip: + chip = "titan rtx" + elif "v100-sxm" in chip: + chip = "v100 sxm" + elif "v100-pcie" in chip: + chip = "v100 pcie" + elif "v100s-pcie" in chip: + chip = "v100s pcie" + else: + chip = None + return chip + + +def get_vram(device: int = 0) -> int: + device_props = torch.cuda.get_device_properties(device) + return device_props.total_memory + + +def get_bandwidth(device: int = 0) -> int: + try: + from triton.testing import get_dram_gbps + + bandwidth = get_dram_gbps(device) * 1e9 + except ImportError: + print("Could not import triton to get DRAM Gbps. Please install triton") + bandwidth = None + return bandwidth + + +def get_flops_by_dtype(chip_name: str) -> dict[torch.dtype, float]: + return _AVAILABLE_GPU_SPECS.get(chip_name, None) + + +@dataclass +class DeviceSpec: + """ + Abstract device specs for theoretical peak performance calculations. 
+ + Fields will be auto-populated in __post_init__ if not already specified + and if data is available + - bandwidth (bytes /s) + - flops_per_s (FLOP / s) + - vram (bytes) + - dtype (torch.dtype) dtype used for theoretical peak performance + - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOP / s + """ + + device_type: str + name: Optional[str] = None + bandwidth: Optional[int] = None + flops_per_s: Optional[int] = None + vram: Optional[int] = None + dtype: Optional[torch.dtype] = None + flops_by_dtype: dict = field(default_factory=dict) + + def _post_init_check(self): + assert ( + self.bandwidth is not None + ), "GPU bandwidth is None - please specify the bandwidth in GB/s in order to enable speed of light calculations" + assert ( + self.dtype is not None + ), "GPU dtype is None - please specify the dtype in order to enable speed of light calculations" + assert ( + self.flops_per_s is not None + ), "GPU flops_per_s is None - please specify the flops_per_s in FLOP/s in order to enable speed of light calculations" + self.flops_by_dtype.update({self.dtype: self.flops_per_s}) + + # Not needed for downstream calculations atm, no need to assert + if self.vram is None: + print("GPU vram is None - please specify the vram in bytes") + + def __setattr__(self, name, value): + # Check if the attribute is already defined + if name in {f.name for f in fields(self)}: + super().__setattr__(name, value) + else: + raise AttributeError( + f"Cannot add new attribute '{name}' to {self.__class__.__name__}" + ) + + def __str__(self): + if self.bandwidth is not None: + formatted_bw = f"{self.bandwidth / 1e9:,.1f}GB/s" + if self.flops_per_s is not None: + formatted_flops = f"{self.flops_per_s / 1e12:,.1f}TFLOPs" + if self.vram is not None: + formatted_vram = f"{self.vram / 1e9:,.1f}GB" + return f"DeviceSpec(device_type={self.device_type}, name={self.name}, dtype={self.dtype}, bandwidth={formatted_bw}, flops={formatted_flops}, vram={formatted_vram})" + + @property + def roofline_balancepoint(self): + """ + Arithmetic intensity (FLOP / byte) transition point from + memory-bound to compute-bound regime. + + This is the ridgepoint of the roofline curve. + """ + assert ( + self.bandwidth is not None + ), "Please set bandwidth in order to calculate roofline balancepoint" + assert ( + self.flops_per_s is not None + ), "Please set flops_per_s in order to calculate roofline balancepoint" + + return self.flops_per_s / self.bandwidth + + +@dataclass +class CUDADeviceSpec(DeviceSpec): + """ + CUDA specs for theoretical peak performance, conformant with DeviceSpec interface. + + Fields will be auto-populated in __post_init__ if not specified + and if data is available. + + See _AVAILABLE_GPU_SPECS for a list of available chip data. 
+ + Fields and expected units: + - device (int): CUDA device index + - name (str): name of the device + - bandwidth (bytes /s): memory bandwidth in bytes / s + - flops_per_s (FLOP / s): FLOPs per second + - vram (bytes): VRAM in bytes + - dtype (torch.dtype): dtype used for theoretical peak performance + - flops_by_dtype (dict[Union[torch.dtype, str], float]): mapping from dtype to FLOP / s + - use_tensorcores (bool): whether to use tensorcores if dtype == torch.float32 + """ + + device_type: str = "cuda" + # Device index + device: Optional[int] = 0 + # Whether to use tfloat32 FLOPs for dtype == torch.float32 + # We assume that tensorcores will always be used for fp16, int8, and other sub-single precision dtypes + use_tensorcores: bool = True + + def __post_init__(self): + # Populate fields if not already populated + self.name = torch.cuda.get_device_name(self.device) + + # Memory bandwidth in bytes / s + if self.bandwidth is None: + self.bandwidth = get_bandwidth() + + # FLOPs / s + if self.flops_per_s is None: + chip_name = get_chip_name(self.device) + if chip_name is None: + print(f"No FLOPs data available for device name {self.name}") + else: + flops_by_dtype = get_flops_by_dtype(chip_name) + if flops_by_dtype is not None: + self.flops_by_dtype.update(flops_by_dtype) + + # Populate flops if not already populated + if flops_by_dtype is not None and self.dtype in flops_by_dtype: + self.flops_per_s = flops_by_dtype[self.dtype] + + if self.dtype == torch.float32: + use_tf32 = "tfloat32" in flops_by_dtype and self.use_tensorcores + + if use_tf32: + self.flops_per_s = flops_by_dtype["tfloat32"] + else: + print( + f"Could not find FLOPs for dtype {self.dtype} for device {self.name}" + ) + # Vram + if self.vram is None: + self.vram = get_vram() + + # Issue post check warnings + self._post_init_check() diff --git a/torchao/profiler/performance_counter.py b/torchao/profiler/performance_counter.py new file mode 100644 index 0000000000..d79625d55c --- /dev/null +++ b/torchao/profiler/performance_counter.py @@ -0,0 +1,597 @@ +import inspect +import json +import math +import textwrap +import time +import warnings +from collections import defaultdict +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import asdict, dataclass +from functools import partial +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from torch.utils._pytree import tree_map +from torch.utils.flop_counter import FlopCounterMode + +from .device_spec import DeviceSpec + +aten = torch.ops.aten + + +class DeviceInfoMissing(UserWarning): + pass + + +# Prevent excessive output +warnings.simplefilter("once", DeviceInfoMissing) + + +class PerformanceCounterMode(FlopCounterMode): + """ + ``PerformanceCounterMode`` extends FlopCounterMode to track IO in addition to flops. + + It does this using a ``TorchDispatchMode`` per `FlopCounterMode` and tracks the + inputs and outputs of each operator, organized by module. 
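
A minimal usage sketch (as in the unit tests; `model` and `input_ids` are assumed to be defined):

    >>> with PerformanceCounterMode() as perf_counter:
    ...     _ = model(input_ids)
    >>> flops_by_module = perf_counter.get_summary_flop_counts()
    >>> io_by_module = perf_counter.get_summary_io_counts()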
+ + In addition to the methods exposed by FlopCounterMode, the following methods are + available: + - ``get_io_counts``: returns a dictionary of module names and their associated IO counts by aten operator + - ``get_total_io``: returns the total number of IO operations across all modules + - ``get_summary_io_counts``: returns a summary of the IO counts for each module (totals by operator) + - ``get_summary_flop_counts``: returns a summary of the flop counts for each module (totals by operator) + """ + + def __init__(self, display=False, depth=10, debug=False): + self.debug = debug + self.io_counts = defaultdict(lambda: defaultdict(int)) + super().__init__(display=display, depth=depth) + + def get_io_counts(self): + return {k: dict(v) for k, v in self.io_counts.items()} + + def get_total_io(self): + return sum(self.io_counts["Global"].values()) + + def _get_io_sizes(self, args): + sizes = tree_map( + lambda x: x.numel() * x.element_size() + if isinstance(x, torch.Tensor) + else 0, + args, + ) + if not hasattr(sizes, "__len__"): + sizes = [sizes] + return sizes + + def get_summary_flop_counts(self): + flop_counts = self.get_flop_counts() + return {k: sum(v.values()) for k, v in flop_counts.items()} + + def get_summary_io_counts(self): + io_counts = self.get_io_counts() + return {k: sum(v.values()) for k, v in io_counts.items()} + + def _nearest_power_of_10(self, x): + if x == 0: + return x, 0 + + power = int(math.floor(math.log10(abs(x)) / 3)) + scaled_value = x / (10 ** (3 * power)) + + return scaled_value, power + + def pretty_summary_counts(self, type="flops", precision=2, depth=None): + assert type in ["flops", "io"] + metric_units = { + 0: "", + 1: "k", + 2: "M", + 3: "G", + 4: "T", + 5: "P", + 6: "E", + 7: "Z", + 8: "Y", + } + + if depth is None: + depth = self.depth + summary_counts = ( + self.get_summary_flop_counts() + if type == "flops" + else self.get_summary_io_counts() + ) + keys_to_print = [k for k in summary_counts.keys() if len(k.split(".")) <= depth] + units = "FLOPs" if type == "flops" else "B" + summary_str = [] + for k in sorted(keys_to_print, key=lambda x: len(x.split("."))): + if k == "Global" or k is None: + continue + spaces = " " * (len(k.split(".")) - 1) + scaled_val, power = self._nearest_power_of_10(summary_counts[k]) + formatted_val = f"{scaled_val:.{precision}f}{metric_units[power]}{units}" + summary_str.append(f"{spaces}{k}: {formatted_val}") + + return "\n".join(summary_str) + + def _count_io(self, func_packet, out, args, kwargs): + arg_sizes = self._get_io_sizes(args) + kwargs_sizes = self._get_io_sizes(kwargs.values()) + out_sizes = self._get_io_sizes(out) + arg_size, kwargs_size, out_size = ( + sum(arg_sizes), + sum(kwargs_sizes), + sum(out_sizes), + ) + return arg_size, kwargs_size, out_size + + def _count_flops(self, func_packet, out, args, kwargs): + if func_packet in self.flop_registry: + flop_count_func = self.flop_registry[func_packet] + flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator] + arg_size, kwarg_size, out_size = self._count_io( + func_packet, out, args, kwargs + ) + total_size = arg_size + kwarg_size + out_size + + for par in set(self.mod_tracker.parents): + if self.debug: + print(f"Counting flops for {par}, {func_packet}: {flop_count}") + print( + f"Counting io for {par}, {func_packet}: {sum([arg_size, kwarg_size, out_size])} = {arg_size} + {kwarg_size} + {out_size}" + ) + self.flop_counts[par][func_packet] += flop_count + self.io_counts[par][func_packet] += total_size + + return out + + +class 
PerformanceTimer: + """ + Context manager that records the latency, io, and flops of a torch operator / module. + + Timing is done using `time.perf_counter` and can be overridden to use a different + timer (see `CUDAPerformanceTimer`). + + IO and FLOPs are recorded using `PerformanceCounterMode`. + + Available attributes: + name: str + precision: int + display: bool + depth (int): passed to `PerformanceCounterMode` if displaying and determines depth of module tree to display. + **Note**: these attributes are primarily used for debugging when using the `PerformanceTimer` standalone. + The TransformerPerformanceCounter class is a higher-level API that should be used instead. + + """ + + def __init__(self, name, precision=1, display=False, depth=10): + self.name = name + self.precision = precision + self.display = display + self.depth = depth + self.perf_counter = PerformanceCounterMode(display=display, depth=depth) + + def __enter__(self): + self.start = time.perf_counter() + self.perf_counter.__enter__() + return self + + def _print_exit_msg(self): + gflops = round(self.total_flops / 1e9, self.precision) + ms = round(self.latency * 1e3, self.precision) + if self.display: + print(f"{self.name.upper()}: latency = {ms} ms, FLOPS = {gflops} GFLOPs") + + def __exit__(self, type, value, traceback): + self.end = time.perf_counter() + # Convert to ms + self.latency = self.end - self.start + self.perf_counter.__exit__(type, value, traceback) + if self.display: + self._print_exit_msg() + + @property + def total_flops(self): + return self.perf_counter.get_total_flops() + + @property + def total_io(self): + return self.perf_counter.get_total_io() + + @property + def flops_table(self): + return self.perf_counter.get_table() + + def get_summary_flop_counts(self): + return self.perf_counter.get_summary_flop_counts() + + def get_summary_io_counts(self): + return self.perf_counter.get_summary_io_counts() + + @property + def flop_counts(self): + return self.perf_counter.get_flop_counts() + + @property + def io_counts(self): + return self.perf_counter.get_io_counts() + + def get_pretty_summary(self, depth): + return self.perf_counter.pretty_summary_counts( + depth=depth if depth is not None else self.depth + ) + + +class CUDAPerformanceTimer(PerformanceTimer): + """ + `PerformanceTimer` that uses `cudaEvents` to record latency. 
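
Usage mirrors `PerformanceTimer`; a minimal sketch (with `a` and `b` assumed to be CUDA tensors):

    >>> with CUDAPerformanceTimer(name="matmul") as timer:
    ...     _ = torch.matmul(a, b)
    >>> timer.latency      # seconds, measured via CUDA event elapsed time
    >>> timer.total_flops  # aggregated by the underlying PerformanceCounterMode
    >>> timer.total_io     # bytes moved, also from PerformanceCounterMode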
+ """ + + def __enter__(self): + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self.start.record() + self.perf_counter = PerformanceCounterMode( + display=self.display, depth=self.depth + ) + self.perf_counter.__enter__() + return self + + def __exit__(self, type, value, traceback): + self.end.record() + torch.cuda.synchronize() + # Convert from ms to s + self.latency = self.start.elapsed_time(self.end) * 1e-3 + self.perf_counter.__exit__(type, value, traceback) + + if self.display: + self._print_exit_msg() + + +def to_nearest_power_of_10(x, precision=2): + # Dictionary mapping powers of 10 to their metric abbreviations + metric_units = {0: "", -6: "ยต", -3: "m", 6: "M", 9: "G", 12: "T"} + + # Determine the closest power of 10 + if x == 0: + return f"{x:.{precision}f}" + + power = int(math.floor(math.log10(abs(x)))) + # Adjust power to fit within the given metric units + powers = sorted(metric_units.keys()) + closest_power = min(powers, key=lambda p: abs(p - power)) + + # Calculate the value formatted to the closest power of 10 + value = x / 10**closest_power + + # Map the power to the metric unit + unit = metric_units.get(closest_power, f"e{closest_power}") + + return f"{value:,.{precision}f} {unit}" + + +class DictMixin: + """ + Enables dict-like interface to dataclasses. + """ + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(key) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __contains__(self, key): + return hasattr(self, key) + + def __iter__(self): + for key in self.__dict__: + yield key + + +def _get_property_methods(cls): + return [ + name for name, _ in inspect.getmembers(cls, lambda m: isinstance(m, property)) + ] + + +@dataclass +class PerformanceStats(DictMixin): + """ + Data struct that stores performance statistics. + + Attrs: + num_tokens (int): number of tokens processed + latency (float): latency in seconds + total_flops (int): total FLOPs + total_io (int): total data movement in bytes + flops_summary (Dict[str, int]): summary of FLOPs by module + io_summary (Dict[str, int]): summary of data movement in bytes by module + flop_counts (Dict[str, Dict[Any, int]]): FLOP counts by module and operation + io_counts (Dict[str, Dict[Any, int]]): data movement by module and operation + device_bandwidth (Optional[float]): device bandwidth in bytes per second + device_flops_per_s (Optional[float]): device FLOPs per second + + Additionally, the following derived properties are available: + token_throughput (float): number of tokens processed per second + achieved_flops_per_s (float): achieved FLOPs per second + achieved_bandwidth (float): achieved data movement in bytes per second + theoretical_io_latency (Optional[float]): theoretical I/O latency in seconds, set to None if + no device bandwidth is available. + theoretical_compute_latency (Optional[float]): theoretical compute latency in seconds, set to None if + no device FLOPs are available. 
+ """ + + label: str + num_tokens: int + latency: float + total_flops: int + total_io: int + flops_summary: Dict[str, int] + io_summary: Dict[str, int] + flop_counts: Dict[str, Dict[Any, int]] + io_counts: Dict[str, Dict[Any, int]] + device_bandwidth: Optional[float] = None + device_flops_per_s: Optional[float] = None + + @property + def token_throughput(self): + return self.num_tokens / self.latency + + @property + def achieved_flops_per_s(self): + return self.total_flops / self.latency + + @property + def achieved_bandwidth(self): + return self.total_io / self.latency + + @property + def theoretical_io_latency(self): + if self.device_bandwidth is not None: + return self.total_io / self.device_bandwidth + else: + warnings.warn( + "Device bandwidth is not specified. Please specify the device bandwidth to enable io latency calculation" + ) + return None + + @property + def theoretical_compute_latency(self): + if self.device_flops_per_s is not None: + return self.total_flops / self.device_flops_per_s + else: + warnings.warn( + "Device flops_per_s is not specified. Please specify the device throughput to enable compute latency calculation" + ) + return None + + @property + def bandwidth_utilization(self): + if self.device_bandwidth is not None: + return self.achieved_bandwidth / self.device_bandwidth + else: + warnings.warn( + "Device bandwidth is not specified. Please specify the device bandwidth to enable bandwidth utilization calculation" + ) + return None + + @property + def flops_utilization(self): + if self.device_flops_per_s is not None: + return self.achieved_flops_per_s / self.device_flops_per_s + else: + warnings.warn( + "Device flops_per_s is not specified. Please specify the device throughput to enable flops utilization calculation" + ) + return None + + def _format(self, value, suffix, precision=2, round=True): + if round: + return to_nearest_power_of_10(value, precision=precision) + suffix + return f"{value:.{precision}f} " + suffix + + def __str__(self): + txt = textwrap.dedent(f"""\ + {self.label}: + Latency = {self._format(self.latency, "s")} + Tokens + Total: {self.num_tokens} tokens + Throughput: {self.token_throughput:,.0f} tokens/s + IO + Total: {self._format(self.total_io, "B")} + Throughput: {self._format(self.achieved_bandwidth, "B/s")} + Theoretical Latency: {self._format(self.theoretical_io_latency, "s") if self.theoretical_io_latency is not None else "N/A"} + FLOPs + Total: {self._format(self.total_flops, "FLOPs")} + Throughput: {self._format(self.achieved_flops_per_s, "FLOPs/s")} + Theoretical Latency: {self._format(self.theoretical_compute_latency, "s") if self.theoretical_compute_latency is not None else "N/A"} + Utilization + Bandwidth: {self._format(self.bandwidth_utilization, round=False, precision=4, suffix="%") if self.bandwidth_utilization is not None else "N/A"} + FLOPs: {self._format(self.flops_utilization, round=False, precision=4, suffix="%") if self.flops_utilization is not None else "N/A"}""") + + return txt + + def to_dict(self): + d = asdict(self) + # Update dict with properties + props = _get_property_methods(self.__class__) + d.update({prop: getattr(self, prop) for prop in props}) + + return d + + +class TransformerPerformanceCounter: + """ + Context manager-like class for tracking performance across multiple calls + to a Transformer model. + + Provides properties for accessing performance stats for data movement and FLOPs for each context as well as + summary stats across all contexts. 
+ Additionally, if a device_spec is provided, theoretical peak bandwidth / FLOPs stats will be available. + + See `PerformanceStats` struct for description of tracked metrics. + + Example: + >>> manager = TransformerPerformanceCounter(device_spec=device_spec) + >>> with manager.count(label="prefill", num_tokens=encoded_prompt.numel()): + >>> out = model(encoded_prompt) + >>> manager.print_summary(labels=["prefill"], show=True) # print recorded stats for "prefill" context + >>> with manager.count(label="decode", num_tokens=1): + >>> out = model(out[-1]) + >>> manager.print_summary(labels=["decode"], show=True) # print recorded stats for "decode" context + >>> manager.print_summary(show=True) # print accumulated stats across all contexts + """ + + def __init__( + self, + depth=10, + timer_cls: PerformanceTimer = PerformanceTimer, + device_spec: Optional[DeviceSpec] = None, + ): + super().__init__() + self._counts: Dict[str, PerformanceStats] = {} + self._depth = depth + self.timer_cls = timer_cls + self.device_spec = device_spec + + @contextmanager + def count(self, label: str, num_tokens: int): + perf_timer = self.timer_cls(name=label, depth=self._depth) + perf_timer.__enter__() + try: + yield self + finally: + perf_timer.__exit__(None, None, None) + stats = PerformanceStats( + label=label, + num_tokens=num_tokens, + latency=perf_timer.latency, + total_flops=perf_timer.total_flops, + total_io=perf_timer.total_io, + flops_summary=perf_timer.get_summary_flop_counts(), + io_summary=perf_timer.get_summary_io_counts(), + flop_counts=perf_timer.flop_counts, + io_counts=perf_timer.io_counts, + device_bandwidth=self.device_spec.bandwidth + if self.device_spec is not None + else None, + device_flops_per_s=self.device_spec.flops_per_s + if self.device_spec is not None + else None, + ) + self._counts[label] = stats + + @property + def counts(self): + return self._counts + + def get_counts(self): + return self._counts + + @property + def total_flops(self): + return sum(count.total_flops for count in self._counts.values()) + + @property + def total_io(self): + return sum(count.total_io for count in self._counts.values()) + + @property + def total_tokens(self): + return sum(count.num_tokens for count in self._counts.values()) + + @property + def total_time(self): + return sum(count.latency for count in self._counts.values()) + + def _summarize_stat(self, key): + return { + label: getattr(self._counts[label], key) for label in self._counts.keys() + } + + @property + def flops_summary(self): + return self._summarize_stat(key="flops_summary") + + @property + def io_summary(self): + return self._summarize_stat(key="io_summary") + + @property + def flop_counts_summary(self): + return self._summarize_stat(key="flop_counts") + + @property + def io_counts_summary(self): + return self._summarize_stat(key="io_counts") + + @property + def stats_summary(self): + stats = PerformanceStats( + label="Performance Summary", + num_tokens=self.total_tokens, + latency=self.total_time, + total_flops=self.total_flops, + total_io=self.total_io, + flops_summary=self.flops_summary, + io_summary=self.io_summary, + flop_counts=self.flop_counts_summary, + io_counts=self.io_counts_summary, + device_bandwidth=self.device_spec.bandwidth + if self.device_spec is not None + else None, + device_flops_per_s=self.device_spec.flops_per_s + if self.device_spec is not None + else None, + ) + + return stats + + def print_summary(self, labels: list[str] = None, show: bool = False): + _print = partial(print, flush=True, end="\n") + # Delegate to __str__ of PerformanceStats for pretty
printing + if labels is None: + text = str(self.stats_summary) + if show: + _print(text) + return text + else: + txts = [] + for label in labels: + text = str(self._counts[label]) + if show: + _print(text) + txts.append(text) + return "\n".join(txts) + + def to_dict(self): + # Convert flop_counts from OpOverloadPackets to str + # Then delegate to PerformanceStats `to_dict`, which updates with derived metrics (property methods) + counts = deepcopy(self._counts) + for label, label_counts in counts.items(): + counts[label]["flop_counts"] = { + mod: {str(op): count for op, count in op_count.items()} + for mod, op_count in label_counts["flop_counts"].items() + } + counts[label]["io_counts"] = { + mod: {str(op): count for op, count in op_count.items()} + for mod, op_count in label_counts["io_counts"].items() + } + counts[label] = counts[label].to_dict() + + return counts + + def to_json(self, path: Union[str, Path] = None): + d = self.to_dict() + if path: + with open(path, "w") as f: + f.write(json.dumps(d, indent=2)) + return d diff --git a/torchao/profiler/utils.py b/torchao/profiler/utils.py new file mode 100644 index 0000000000..9276dd37b1 --- /dev/null +++ b/torchao/profiler/utils.py @@ -0,0 +1,44 @@ +import inspect + +import torch + +_HUGGINGFACE_CAUSAL_LM_BASE_CLASSES = [ + "causallm", + "pretrainedmodel", + "generationmixin", +] + + +def _get_all_base_classes(object): + return [cls.__name__.lower() for cls in inspect.getmro(object.__class__)] + + +def total_model_params( + model: torch.nn.Module, + exclude_embeddings: bool = True, + embedding_key: str = "tok_embeddings", +) -> int: + """ + Calculate total params of a HuggingFace CausalLM model or gpt-fast model + """ + num_params = sum(p.numel() for p in model.parameters()) + + # Exclude embeddings when calculating FLOP since they don't contribute to FLOP count + if exclude_embeddings: + # Not the cleanest, but check if any base class of the model is in _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES + if ( + len( + set(_get_all_base_classes(model)).intersection( + _HUGGINGFACE_CAUSAL_LM_BASE_CLASSES + ) + ) + > 0 + ): + num_params -= model.model.embed_tokens.weight.numel() + elif hasattr(model, embedding_key): + num_params -= getattr(model, embedding_key).weight.numel() + else: + raise ValueError( + f"Could not find embedding in model {model.__class__.__name__}, please specify embedding attribute key" + ) + return num_params
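
A minimal usage sketch (illustrative only, not part of the patch) of how the pieces above are meant to compose, based on the signatures shown in device_spec.py and performance_counter.py. It assumes a CUDA build of torch >= 2.5 and a GPU that CUDADeviceSpec recognizes (so flops_per_s and bandwidth can be auto-filled); the toy model, tensor shapes, token count, and the perf_stats.json path are placeholders.

import torch

from torchao.profiler.device_spec import CUDADeviceSpec
from torchao.profiler.performance_counter import (
    CUDAPerformanceTimer,
    TransformerPerformanceCounter,
)

# Only dtype is required on a recognized GPU; peak FLOPs / bandwidth are auto-filled.
device_spec = CUDADeviceSpec(dtype=torch.float16)
manager = TransformerPerformanceCounter(
    timer_cls=CUDAPerformanceTimer, device_spec=device_spec
)

# Toy stand-in for a transformer block (placeholder model and shapes).
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).to(device="cuda", dtype=torch.float16)
x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)

# Each `count` context records latency (here via CUDA events) and FLOPs / IO
# via PerformanceCounterMode, then stores a PerformanceStats entry under its label.
with manager.count(label="forward", num_tokens=x.shape[0] * x.shape[1]):
    with torch.no_grad():
        _ = model(x)

manager.print_summary(labels=["forward"], show=True)  # stats for the "forward" context
manager.print_summary(show=True)  # accumulated stats across all recorded contexts
stats = manager.to_json(path="perf_stats.json")  # dict incl. derived metrics, also written to disk

Passing CUDAPerformanceTimer as timer_cls swaps wall-clock timing for CUDA-event timing without changing how FLOPs and IO are counted.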
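
A second sketch, for total_model_params in torchao/profiler/utils.py. ToyLM is a hypothetical stand-in for a gpt-fast style model whose embedding lives under the default embedding_key of "tok_embeddings"; a HuggingFace CausalLM would instead be detected through its base classes and have model.model.embed_tokens excluded.

import torch

from torchao.profiler.utils import total_model_params


class ToyLM(torch.nn.Module):
    """Hypothetical gpt-fast style model: embedding weights live under `tok_embeddings`."""

    def __init__(self, vocab_size=1000, dim=64):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)
        self.output = torch.nn.Linear(dim, vocab_size, bias=False)


model = ToyLM()
# With exclude_embeddings=True the default embedding_key matches the attribute
# above, so the embedding weights are subtracted from the parameter count.
n_flop_params = total_model_params(model, exclude_embeddings=True)
n_all_params = total_model_params(model, exclude_embeddings=False)
assert n_all_params - n_flop_params == model.tok_embeddings.weight.numel()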