#!/usr/bin/env python
"""Benchmark harness for GLiNER inference-time packing."""
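# Example invocation (the script path is illustrative; all flags are defined in
# _parse_args below):
#   python benchmark_infer_packing.py --model roberta-base --scenario mixed_tail \
#       --batch_size 64 --device cuda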

from __future__ import annotations

import argparse
import json
import time
from dataclasses import asdict, dataclass
from typing import Dict, List

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from gliner.infer_packing import InferencePackingConfig, pack_requests


@dataclass
class BenchmarkStats:
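    """Throughput and padding metrics for one collation strategy."""
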
    tokens_per_s: float
    examples_per_s: float
    padding_ratio: float


def _format_table(result: Dict[str, object]) -> str:
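    """Render baseline vs. packed stats as a small fixed-width text table."""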
    lines = []
    header = f"{'mode':<10} {'tokens/s':>15} {'examples/s':>15} {'padding':>12}"
    lines.append(header)
    lines.append("-" * len(header))
    for mode in ("baseline", "packed"):
        stats: BenchmarkStats = result[mode]  # type: ignore[assignment]
        lines.append(
            f"{mode:<10} {stats.tokens_per_s:>15.2e} {stats.examples_per_s:>15.2f} {stats.padding_ratio:>12.2%}"
        )
    lines.append("")
    lines.append(f"Speedup (tokens/s): {result['speedup_tokens_per_s']:.2f}x")
    return "\n".join(lines)


def _parse_args() -> argparse.Namespace:
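    """Define and parse the benchmark's command-line flags."""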
    parser = argparse.ArgumentParser(description="Benchmark GLiNER inference packing.")
    parser.add_argument("--model", type=str, default="roberta-base", help="Model name or path")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=64, help="Number of requests per batch")
    parser.add_argument(
        "--scenario",
        type=str,
        default="short_uniform",
        choices=["short_uniform", "short_zipf", "mixed_tail", "flat_long"],
        help="Length distribution scenario",
    )
    parser.add_argument("--device", type=str, default="cpu", help="Device to benchmark on")
    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup iterations")
    parser.add_argument("--iters", type=int, default=100, help="Number of timed iterations")
    return parser.parse_args()


def _generate_lengths(args: argparse.Namespace) -> List[int]:
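    """Draw per-request token counts for the selected --scenario.

    short_uniform: lengths uniform in [8, 64]; short_zipf: Zipf-distributed
    lengths clipped to [8, 128]; mixed_tail: one long (256-token) request plus
    short ones; flat_long: every request at 256 tokens. All values are further
    capped by --max_length.
    """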
    batch = args.batch_size
    max_length = args.max_length

    if args.scenario == "short_uniform":
        rng = np.random.default_rng(1337)
        values = rng.integers(8, min(64, max_length) + 1, size=batch)
        return values.astype(int).tolist()
    if args.scenario == "short_zipf":
        rng = np.random.default_rng(2024)
        lengths = rng.zipf(1.2, size=batch)
        clipped = np.clip(lengths, 8, min(128, max_length))
        return clipped.astype(int).tolist()
    if args.scenario == "mixed_tail":
        rng = np.random.default_rng(314)
        longs = [min(256, max_length)]
        if batch > 1:
            shorts = rng.integers(8, min(48, max_length) + 1, size=batch - 1)
            return longs + shorts.astype(int).tolist()
        return longs
    if args.scenario == "flat_long":
        return [min(256, max_length)] * batch

    raise ValueError(f"Unsupported scenario: {args.scenario}")


def _build_requests(lengths: List[int], vocab_size: int, pad_token_id: int) -> List[Dict[str, List[int]]]:
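    """Build synthetic requests with deterministic token ids, never emitting the pad id."""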
    requests: List[Dict[str, List[int]]] = []
    token = 0
    for length in lengths:
        # At least one token per request; length additionally capped at vocab_size - 1.
        actual_len = max(1, min(int(length), vocab_size - 1))
        sequence: List[int] = []
        for _ in range(actual_len):
            # Cycle deterministically through the vocabulary, skipping the pad id
            # so padding can be identified unambiguously in the collated batches.
            value = token % vocab_size
            if value == pad_token_id:
                value = (value + 1) % vocab_size
            sequence.append(value)
            token += 1
        requests.append({"input_ids": sequence})
    return requests


def _collate_baseline(requests: List[Dict[str, List[int]]], pad_token_id: int) -> Dict[str, torch.Tensor]:
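    """Pad every request to the batch's longest sequence (the conventional collation)."""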
    max_len = max(len(req["input_ids"]) for req in requests)
    batch = len(requests)
    input_ids = torch.full((batch, max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((batch, max_len), dtype=torch.long)
    for row, req in enumerate(requests):
        tokens = req["input_ids"]
        length = len(tokens)
        input_ids[row, :length] = torch.tensor(tokens, dtype=torch.long)
        attention_mask[row, :length] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def _measure(
    model: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    *,
    warmup: int,
    iters: int,
    device: torch.device,
) -> float:
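    """Return total wall-clock seconds for `iters` forward passes after `warmup` untimed ones."""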
    with torch.inference_mode():
        for _ in range(max(0, warmup)):
            model(**inputs)
        if device.type == "cuda":
            # Drain queued GPU work so warmup does not leak into the timed region.
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(max(1, iters)):
            model(**inputs)
        if device.type == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter() - start


def main() -> None:
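    """Benchmark padded vs. packed collation for one configuration and print JSON plus a summary table."""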
    args = _parse_args()
    if args.max_length <= 0:
        raise ValueError("--max_length must be positive")
    if args.batch_size <= 0:
        raise ValueError("--batch_size must be positive")

    device = torch.device(args.device)
    torch.manual_seed(1337)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(1337)
        torch.backends.cudnn.deterministic = True

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.to(device)
    model.eval()

    # Resolve a usable pad token id and the vocabulary size for synthetic data.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id or 0
    vocab_size = getattr(tokenizer, "vocab_size", len(tokenizer))
    if vocab_size is None:
        vocab_size = len(tokenizer)
    vocab_size = int(vocab_size)
    if vocab_size <= 1:
        raise ValueError("Tokenizer vocabulary size must exceed 1")
    lengths = _generate_lengths(args)
    lengths = [min(length, args.max_length) for length in lengths]
    requests = _build_requests(lengths, vocab_size, pad_token_id)
    real_tokens = sum(len(req["input_ids"]) for req in requests)

    # Baseline: pad every request to the longest sequence in the batch.
    baseline_inputs = _collate_baseline(requests, pad_token_id)
    baseline_inputs = {k: v.to(device) for k, v in baseline_inputs.items()}

    # Packed collation produced by gliner.infer_packing.
    cfg = InferencePackingConfig(
        max_length=args.max_length,
        sep_token_id=tokenizer.sep_token_id,
        streams_per_batch=1,
    )
    packed = pack_requests(requests, cfg, pad_token_id)
    mask_dtype = baseline_inputs["attention_mask"].dtype
    packed_inputs = {
        "input_ids": packed.input_ids.to(device),
        "attention_mask": packed.pair_attention_mask.to(device=device, dtype=mask_dtype),
    }

    warmup = args.warmup
    iters = args.iters

    baseline_time = _measure(model, baseline_inputs, warmup=warmup, iters=iters, device=device)
    packed_time = _measure(model, packed_inputs, warmup=warmup, iters=iters, device=device)

    # Throughput counts only real (non-padding) tokens so both modes are comparable.
    padded_tokens = baseline_inputs["input_ids"].size(1) * len(requests)
    baseline_stats = BenchmarkStats(
        tokens_per_s=(real_tokens * iters) / baseline_time,
        examples_per_s=(len(requests) * iters) / baseline_time,
        padding_ratio=1.0 - (real_tokens / padded_tokens) if padded_tokens else 0.0,
    )

    packed_tokens = packed.input_ids.size(1) * packed.input_ids.size(0)
    packed_stats = BenchmarkStats(
        tokens_per_s=(real_tokens * iters) / packed_time,
        examples_per_s=(len(requests) * iters) / packed_time,
        padding_ratio=1.0 - (real_tokens / packed_tokens) if packed_tokens else 0.0,
    )

    result = {
        "device": device.type,
        "model": args.model,
        "scenario": args.scenario,
        "batch_size": args.batch_size,
        "max_length": args.max_length,
        "baseline": baseline_stats,
        "packed": packed_stats,
        "speedup_tokens_per_s": packed_stats.tokens_per_s / baseline_stats.tokens_per_s,
        "streams": packed.input_ids.size(0),
    }

    json_payload = {
        **{k: v for k, v in result.items() if k not in {"baseline", "packed"}},
        "baseline": asdict(baseline_stats),
        "packed": asdict(packed_stats),
    }

    print(json.dumps(json_payload, indent=2))
    print()
    print(_format_table(result))


if __name__ == "__main__":
    main()