181 changes: 178 additions & 3 deletions vllm/benchmarks/datasets.py
@@ -20,15 +20,17 @@
import logging
import math
import random
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator, Mapping
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Any, cast

import numpy as np
from datasets import load_dataset
from PIL import Image
@@ -63,6 +65,8 @@
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser

from rich.progress import Progress

logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
@@ -1288,6 +1292,119 @@
        )
        return samples

# -----------------------------------------------------------------------------
# Project Gutenberg Dataset Implementation
# -----------------------------------------------------------------------------


class GutenbergDataset(BenchmarkDataset):
    """
    Implements the Project Gutenberg dataset. Loads book texts from a
    HuggingFace dataset, cleans them, and chunks them into fixed-length
    token prompts.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        """Load data from HuggingFace datasets (assumes an "en" split)."""
        self.data = load_dataset(self.dataset_path, split="en", streaming=True)
        self.data = self.data.shuffle(seed=self.random_seed)

    def clean_gutenberg_text(self, text: str) -> str:
        """
        Basic cleaning for Project Gutenberg text:
        - Extract content between "*** START OF ..." and "*** END OF ..."
          markers (if available)
        - Normalize whitespace and multiple newlines
        - Trim leading/trailing spaces
        """
        # Extract content inside START/END markers
        start_match = re.search(
            r"\*\*\* START OF.*?\*\*\*", text, re.IGNORECASE | re.DOTALL
        )
        end_match = re.search(
            r"\*\*\* END OF.*?\*\*\*", text, re.IGNORECASE | re.DOTALL
        )

        if start_match and end_match:
            content = text[start_match.end() : end_match.start()]
        else:
            # Use entire raw text if markers are missing
            content = text

        # Normalize newlines
        content = content.replace("\r\n", "\n").replace("\r", "\n")
        # Reduce consecutive blank lines and spaces
        content = re.sub(r"\n{2,}", "\n\n", content)
        content = re.sub(r"[ \t]{2,}", " ", content)
        # Strip spaces
        content = content.strip()
        return content
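    # Illustrative sketch of the cleaning above (a hypothetical snippet, not
    # text from a real Gutenberg file; the marker wording varies per book):
    #
    #   raw = ("junk\r\n*** START OF THE PROJECT GUTENBERG EBOOK X ***\r\n"
    #          "Call me   Ishmael.\n\n\n\n"
    #          "*** END OF THE PROJECT GUTENBERG EBOOK X ***\nlicense junk")
    #   clean_gutenberg_text(raw)  # -> "Call me Ishmael."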

    def chunk_tokens(
        self, input_ids: list[int], input_len: int = 4000
    ) -> list[list[int]]:
        """
        Split a list of token IDs into chunks of at most `input_len` tokens.
        """
        return [
            input_ids[i : i + input_len]
            for i in range(0, len(input_ids), input_len)
        ]
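    # chunk_tokens example:
    #   chunk_tokens([1, 2, 3, 4, 5, 6, 7], input_len=3)
    #   -> [[1, 2, 3], [4, 5, 6], [7]]
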
    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        input_len: int | None,
        output_len: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list:
        if input_len is None:
            # --gutenberg-input-len defaults to None; fall back to
            # chunk_tokens' default chunk size in that case.
            input_len = 4000
        samples = []
        ind = 0

        pbar = Progress()
        task_id = pbar.add_task(
            description="Preparing input prompts...",
            total=num_requests,
        )
        pbar.start()
        for book in self.data:
            if len(samples) >= num_requests:
                break
            text = self.clean_gutenberg_text(book["text"])
            input_ids = tokenizer(text).input_ids
            chunks = self.chunk_tokens(input_ids, input_len)
            # Drop the last chunk, which is usually shorter than input_len.
            for i in range(0, len(chunks) - 1):
                if len(samples) >= num_requests:
                    break
                prompt = tokenizer.decode(chunks[i])
                prompt_len = len(chunks[i])
                samples.append(
                    SampleRequest(
                        prompt=prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                        multi_modal_data=None,
                        request_id=request_id_prefix + str(ind),
                    )
                )
                pbar.update(task_id=task_id, advance=1)
                ind += 1

        completed = pbar.tasks[task_id].completed
        if completed < num_requests:
            logger.info(
                "Not enough compatible requests (%d/%d); oversampling...",
                completed,
                num_requests,
            )
            pbar.update(task_id=task_id, advance=(num_requests - completed))
        pbar.stop()
        self.maybe_oversample_requests(
            samples, num_requests, request_id_prefix, no_oversample
        )
        return samples
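    # Minimal usage sketch. The dataset ID below is an assumption; any HF
    # dataset with a "text" column and an "en" split (e.g.
    # "manu/project_gutenberg") should work:
    #
    #   from transformers import AutoTokenizer
    #
    #   tok = AutoTokenizer.from_pretrained("gpt2")
    #   ds = GutenbergDataset(random_seed=0,
    #                         dataset_path="manu/project_gutenberg")
    #   reqs = ds.sample(tokenizer=tok, num_requests=8,
    #                    input_len=2048, output_len=128)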


class _ValidateDatasetArgs(argparse.Action):
"""Argparse action to validate dataset name and path compatibility."""
@@ -1333,6 +1450,7 @@
"custom",
"prefix_repetition",
"spec_bench",
"gutenberg"
],
help="Name of the dataset to benchmark on.",
)
@@ -1346,7 +1464,7 @@
        type=str,
        default=None,
        action=_ValidateDatasetArgs,
        help="Path to the sharegpt/sonnet/gutenberg dataset. "
        "Or the huggingface dataset ID if using HF dataset.",
    )
    parser.add_argument(
@@ -1367,6 +1485,13 @@

    # group for dataset specific arguments
    custom_group = parser.add_argument_group("custom dataset options")
    custom_group.add_argument(
        "--custom-input-len",
        type=int,
        default=None,
        help="Number of input tokens per request, used only for the custom "
        "dataset. Prompts shorter than this are skipped; longer prompts are "
        "truncated to this length.",
    )
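    # Hypothetical example: a custom dataset is a file whose records carry a
    # "prompt" field; with --custom-input-len each kept prompt is cut to that
    # token count (file name and serve subcommand are assumptions):
    #   vllm bench serve --dataset-name custom --dataset-path ./prompts.jsonl \
    #     --custom-input-len 1024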
    custom_group.add_argument(
        "--custom-output-len",
        type=int,
@@ -1417,6 +1542,22 @@
"from the ShareGPT dataset.",
)

    gutenberg_group = parser.add_argument_group("gutenberg dataset options")
    gutenberg_group.add_argument(
        "--gutenberg-input-len",
        type=int,
        default=None,
        help="Number of input tokens per request; each book is chunked into "
        "prompts of this length. Defaults to 4000 when unset.",
    )
    gutenberg_group.add_argument(
        "--gutenberg-output-len",
        type=int,
        default=None,
        help="Number of output tokens to request for each prompt.",
    )
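    # Hypothetical invocation (the dataset ID is an assumption; any HF
    # dataset with an "en" split and a "text" column should work):
    #   vllm bench serve --dataset-name gutenberg \
    #     --dataset-path manu/project_gutenberg \
    #     --gutenberg-input-len 4000 --gutenberg-output-len 128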

    blazedit_group = parser.add_argument_group("blazedit dataset options")
    blazedit_group.add_argument(
        "--blazedit-min-distance",
@@ -1644,6 +1785,7 @@
        input_requests = dataset.sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            input_len=args.custom_input_len,
            output_len=args.custom_output_len,
            skip_chat_template=args.skip_chat_template,
            request_id_prefix=args.request_id_prefix,
@@ -1828,6 +1970,16 @@
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        ),
        "gutenberg": lambda: GutenbergDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        ).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            input_len=args.gutenberg_input_len,
            output_len=args.gutenberg_output_len,
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        ),
        "burstgpt": lambda: BurstGPTDataset(
            random_seed=args.seed,
            dataset_path=args.dataset_path,
@@ -1974,6 +2126,7 @@
        num_requests: int,
        lora_path: str | None = None,
        max_loras: int | None = None,
        input_len: int | None = None,
        output_len: int | None = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
@@ -1991,11 +2144,27 @@
            num_requests,
        )

        pbar = Progress()
        task_id = pbar.add_task(
            description="Preparing input prompts...", total=num_requests
        )
        pbar.start()
        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break

            prompt = item["prompt"]
            if prompt is None:
                continue

            prompt_ids = tokenizer(prompt).input_ids
            prompt_len = len(prompt_ids)
            if input_len is not None:
                # Skip prompts shorter than the requested input length and
                # truncate longer ones to exactly input_len tokens.
                if prompt_len < input_len:
                    continue
                prompt = tokenizer.decode(prompt_ids[:input_len])
                prompt_len = input_len

            # apply template
            if not skip_chat_template:
@@ -2004,8 +2173,8 @@
                    add_generation_prompt=True,
                    tokenize=False,
                )

            # Recompute the token count; the chat template may add tokens.
            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
@@ -2014,6 +2183,12 @@
                    request_id=request_id_prefix + str(i),
                )
            )
            pbar.update(task_id=task_id, advance=1)

        completed = pbar.tasks[task_id].completed
        if completed < num_requests:
            logger.info(
                "Not enough compatible requests (%d/%d); oversampling...",
                completed,
                num_requests,
            )
            pbar.update(task_id=task_id, advance=(num_requests - completed))
        pbar.stop()
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix, no_oversample
        )