
Commit 07e6f3f

Merge pull request #1: Add inference-time sequence packing support
2 parents 4552dda + 7125fc3, commit 07e6f3f

File tree: 7 files changed (+988, -8 lines changed)
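For orientation (not part of the diff itself): inference-time sequence packing concatenates several short requests into a single input row and restricts attention with a block-diagonal pairwise mask, so the model spends far less compute on padding tokens. A minimal illustrative sketch of that idea, using toy token IDs rather than this PR's actual helpers:

```python
# Illustrative sketch only, not code from this PR: pack two toy requests into one
# stream and build a pairwise mask that keeps tokens of different requests apart.
import torch

requests = [[101, 7, 8, 102], [101, 9, 102]]  # two short token sequences
packed_ids = torch.tensor([tok for seq in requests for tok in seq])
segments = torch.tensor(
    [i for i, seq in enumerate(requests) for _ in seq]
)  # which request each packed token came from

# pair_mask[i, j] == 1 only when tokens i and j belong to the same request,
# i.e. a block-diagonal attention mask over the packed stream.
pair_mask = (segments[:, None] == segments[None, :]).long()
print(packed_ids.shape, pair_mask.shape)  # torch.Size([7]) torch.Size([7, 7])
```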

bench/bench_infer_packing.py

Lines changed: 227 additions & 0 deletions
```python
#!/usr/bin/env python
"""Benchmark harness for GLiNER inference-time packing."""

from __future__ import annotations

import argparse
import json
import time
from dataclasses import asdict, dataclass
from typing import Dict, List

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from gliner.infer_packing import InferencePackingConfig, pack_requests


@dataclass
class BenchmarkStats:
    tokens_per_s: float
    examples_per_s: float
    padding_ratio: float


def _format_table(result: Dict[str, object]) -> str:
    lines = []
    header = f"{'mode':<10} {'tokens/s':>15} {'examples/s':>15} {'padding':>12}"
    lines.append(header)
    lines.append("-" * len(header))
    for mode in ("baseline", "packed"):
        stats: BenchmarkStats = result[mode]  # type: ignore[assignment]
        lines.append(
            f"{mode:<10} {stats.tokens_per_s:>15.2e} {stats.examples_per_s:>15.2f} {stats.padding_ratio:>11.2%}"
        )
    lines.append("")
    lines.append(f"Speedup (tokens/s): {result['speedup_tokens_per_s']:.2f}x")
    return "\n".join(lines)


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Benchmark GLiNER inference packing.")
    parser.add_argument("--model", type=str, default="roberta-base", help="Model name or path")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=64, help="Number of requests per batch")
    parser.add_argument(
        "--scenario",
        type=str,
        default="short_uniform",
        choices=["short_uniform", "short_zipf", "mixed_tail", "flat_long"],
        help="Length distribution scenario",
    )
    parser.add_argument("--device", type=str, default="cpu", help="Device to benchmark on")
    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup iterations")
    parser.add_argument("--iters", type=int, default=100, help="Number of timed iterations")
    return parser.parse_args()


def _generate_lengths(args: argparse.Namespace) -> List[int]:
    batch = args.batch_size
    max_length = args.max_length

    if args.scenario == "short_uniform":
        rng = np.random.default_rng(1337)
        values = rng.integers(8, min(64, max_length) + 1, size=batch)
        return values.astype(int).tolist()
    if args.scenario == "short_zipf":
        rng = np.random.default_rng(2024)
        lengths = rng.zipf(1.2, size=batch)
        clipped = np.clip(lengths, 8, min(128, max_length))
        return clipped.astype(int).tolist()
    if args.scenario == "mixed_tail":
        rng = np.random.default_rng(314)
        longs = [min(256, max_length)]
        if batch > 1:
            shorts = rng.integers(8, min(48, max_length) + 1, size=batch - 1)
            return longs + shorts.astype(int).tolist()
        return longs
    if args.scenario == "flat_long":
        return [min(256, max_length)] * batch

    raise ValueError(f"Unsupported scenario: {args.scenario}")


def _build_requests(lengths: List[int], vocab_size: int, pad_token_id: int) -> List[Dict[str, List[int]]]:
    requests: List[Dict[str, List[int]]] = []
    token = 0
    for length in lengths:
        actual_len = max(1, min(int(length), vocab_size - 1))
        sequence: List[int] = []
        for _ in range(actual_len):
            value = token % vocab_size
            if value == pad_token_id:
                value = (value + 1) % vocab_size
            sequence.append(value)
            token += 1
        requests.append({"input_ids": sequence})
    return requests


def _collate_baseline(requests: List[Dict[str, List[int]]], pad_token_id: int) -> Dict[str, torch.Tensor]:
    max_len = max(len(req["input_ids"]) for req in requests)
    batch = len(requests)
    input_ids = torch.full((batch, max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((batch, max_len), dtype=torch.long)
    for row, req in enumerate(requests):
        tokens = req["input_ids"]
        length = len(tokens)
        input_ids[row, :length] = torch.tensor(tokens, dtype=torch.long)
        attention_mask[row, :length] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def _measure(
    model: AutoModel,
    inputs: Dict[str, torch.Tensor],
    *,
    warmup: int,
    iters: int,
    device: torch.device,
) -> float:
    with torch.inference_mode():
        for _ in range(max(0, warmup)):
            model(**inputs)
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(max(1, iters)):
            model(**inputs)
        if device.type == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter() - start


def main() -> None:
    args = _parse_args()
    if args.max_length <= 0:
        raise ValueError("--max_length must be positive")
    if args.batch_size <= 0:
        raise ValueError("--batch_size must be positive")

    device = torch.device(args.device)
    torch.manual_seed(1337)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(1337)
        torch.backends.cudnn.deterministic = True

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.to(device)
    model.eval()

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id or 0
    vocab_size = getattr(tokenizer, "vocab_size", len(tokenizer))
    if vocab_size is None:
        vocab_size = len(tokenizer)
    vocab_size = int(vocab_size)
    if vocab_size <= 1:
        raise ValueError("Tokenizer vocabulary size must exceed 1")
    lengths = _generate_lengths(args)
    lengths = [min(length, args.max_length) for length in lengths]
    requests = _build_requests(lengths, vocab_size, pad_token_id)
    real_tokens = sum(len(req["input_ids"]) for req in requests)

    baseline_inputs = _collate_baseline(requests, pad_token_id)
    baseline_inputs = {k: v.to(device) for k, v in baseline_inputs.items()}

    cfg = InferencePackingConfig(
        max_length=args.max_length,
        sep_token_id=tokenizer.sep_token_id,
        streams_per_batch=1,
    )
    packed = pack_requests(requests, cfg, pad_token_id)
    mask_dtype = baseline_inputs["attention_mask"].dtype
    packed_inputs = {
        "input_ids": packed.input_ids.to(device),
        "attention_mask": packed.pair_attention_mask.to(device=device, dtype=mask_dtype),
    }

    warmup = args.warmup
    iters = args.iters

    baseline_time = _measure(model, baseline_inputs, warmup=warmup, iters=iters, device=device)
    packed_time = _measure(model, packed_inputs, warmup=warmup, iters=iters, device=device)

    padded_tokens = baseline_inputs["input_ids"].size(1) * len(requests)
    baseline_stats = BenchmarkStats(
        tokens_per_s=(real_tokens * iters) / baseline_time,
        examples_per_s=(len(requests) * iters) / baseline_time,
        padding_ratio=1.0 - (real_tokens / padded_tokens) if padded_tokens else 0.0,
    )

    packed_tokens = packed.input_ids.size(1) * packed.input_ids.size(0)
    packed_stats = BenchmarkStats(
        tokens_per_s=(real_tokens * iters) / packed_time,
        examples_per_s=(len(requests) * iters) / packed_time,
        padding_ratio=1.0 - (real_tokens / packed_tokens) if packed_tokens else 0.0,
    )

    result = {
        "device": device.type,
        "model": args.model,
        "scenario": args.scenario,
        "batch_size": args.batch_size,
        "max_length": args.max_length,
        "baseline": baseline_stats,
        "packed": packed_stats,
        "speedup_tokens_per_s": packed_stats.tokens_per_s / baseline_stats.tokens_per_s,
        "streams": packed.input_ids.size(0),
    }

    json_payload = {
        **{k: v for k, v in result.items() if k not in {"baseline", "packed"}},
        "baseline": asdict(baseline_stats),
        "packed": asdict(packed_stats),
    }

    print(json.dumps(json_payload, indent=2))
    print()
    print(_format_table(result))


if __name__ == "__main__":
    main()
```
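As a usage note, the flags defined above suggest invocations such as `python bench/bench_infer_packing.py --model roberta-base --scenario mixed_tail --batch_size 64 --device cpu`; the script prints a JSON payload followed by the baseline-vs-packed table, including the tokens/s speedup.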

gliner/__init__.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -1,9 +1,21 @@
 __version__ = "0.2.22"
 
 from .model import GLiNER
+from .infer_packing import (
+    InferencePackingConfig,
+    PackedBatch,
+    pack_requests,
+    unpack_spans,
+)
 from .config import GLiNERConfig
 # from .multitask import (GLiNERClassifier, GLiNERQuestionAnswerer, GLiNEROpenExtractor,
 #                         GLiNERRelationExtractor, GLiNERSummarizer, GLiNERSquadEvaluator,
 #                         GLiNERDocREDEvaluator)
 
-__all__ = ["GLiNER"]
+__all__ = [
+    "GLiNER",
+    "InferencePackingConfig",
+    "PackedBatch",
+    "pack_requests",
+    "unpack_spans",
+]
```
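The re-export above makes the packing helpers importable from the package root. A minimal sketch of how a caller might use that entry point, mirroring the call pattern in bench/bench_infer_packing.py (field names such as `pair_attention_mask` are taken from that benchmark; `unpack_spans` is also exported but its usage is not shown in this diff, so it is omitted here):

```python
# Minimal sketch mirroring bench/bench_infer_packing.py; not an official example.
import torch
from transformers import AutoModel, AutoTokenizer

from gliner import InferencePackingConfig, pack_requests  # now exported at top level

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained("roberta-base").eval()

texts = ["first short request", "another short request"]
requests = [{"input_ids": tokenizer(t)["input_ids"]} for t in texts]

cfg = InferencePackingConfig(
    max_length=512,
    sep_token_id=tokenizer.sep_token_id,
    streams_per_batch=1,
)
packed = pack_requests(requests, cfg, tokenizer.pad_token_id)

with torch.inference_mode():
    outputs = model(
        input_ids=packed.input_ids,
        attention_mask=packed.pair_attention_mask.to(dtype=torch.long),
    )
```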
