
Commit 655a5e4

Introduce LLM class for offline inference (#115)

1 parent f746ced commit 655a5e4

File tree: 9 files changed (+221, -80 lines)


cacheflow/__init__.py

Lines changed: 5 additions & 9 deletions
@@ -1,19 +1,15 @@
+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
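
With these re-exports, everything needed for offline inference is importable from the package root. A quick smoke test of the new surface (a sketch; run from a checkout with cacheflow installed):

# Sketch: verify the names re-exported by cacheflow/__init__.py after this change.
import cacheflow

print(cacheflow.__all__)
# ['LLM', 'SamplingParams', 'RequestOutput', 'LLMServer', 'ServerArgs',
#  'initialize_cluster']
assert hasattr(cacheflow, "LLM") and hasattr(cacheflow, "ServerArgs")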

cacheflow/config.py

Lines changed: 5 additions & 1 deletion
@@ -3,6 +3,8 @@
 import torch
 from transformers import AutoConfig, PretrainedConfig

+_GiB = 1 << 30
+

 class ModelConfig:

@@ -70,7 +72,7 @@ def __init__(
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB

         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

     # Verify the dtype.
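
The GiB conversion that create_server_configs_from_args used to apply to args.swap_space now lives inside CacheConfig: callers keep passing swap_space in GiB and the config stores swap_space_bytes. A small sketch of the arithmetic (keyword usage of the constructor is an assumption; only the three parameters visible in the hunk are used):

# Sketch: swap_space is passed in GiB and stored as bytes on the config.
from cacheflow.config import CacheConfig

_GiB = 1 << 30  # same constant the module now defines
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.95, swap_space=4)
assert cache_config.swap_space_bytes == 4 * _GiB  # 4 GiB -> 4294967296 bytes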

cacheflow/entrypoints/fastapi_server.py

Lines changed: 3 additions & 4 deletions
@@ -12,8 +12,7 @@

 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

@@ -116,10 +115,10 @@ async def generate_stream(request: Request):
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
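
The entrypoint now registers its flags and builds configs through ServerArgs instead of the removed free functions. Here is the same two-step wiring in isolation (a sketch; the model name is an example that must be resolvable by Hugging Face, and the script's extra --host/--port flags are omitted):

# Sketch: CLI flags -> ServerArgs -> (ModelConfig, CacheConfig, ParallelConfig,
# SchedulerConfig), as the FastAPI entrypoint now does it.
import argparse

from cacheflow.server.arg_utils import ServerArgs

parser = argparse.ArgumentParser()
parser = ServerArgs.add_cli_args(parser)                     # --model, --dtype, --seed, ...
args = parser.parse_args(["--model", "facebook/opt-125m"])   # example model name

server_configs = ServerArgs.from_cli_args(args).create_server_configs()
parallel_config = server_configs[2]                          # same indexing as the script
print([type(cfg).__name__ for cfg in server_configs])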

cacheflow/entrypoints/llm.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+from typing import List, Optional
+
+from tqdm import tqdm
+
+from cacheflow.outputs import RequestOutput
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.llm_server import LLMServer
+from cacheflow.utils import Counter
+
+
+class LLM:
+
+    def __init__(
+        self,
+        model: str,
+        tensor_parallel_size: int = 1,
+        dtype: str = "default",
+        seed: int = 0,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        server_args = ServerArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            seed=seed,
+            **kwargs,
+        )
+        self.llm_server = LLMServer.from_server_args(server_args)
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: Optional[SamplingParams] = None,
+        use_tqdm: bool = True,
+    ) -> List[RequestOutput]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        # Initialize tqdm.
+        if use_tqdm:
+            pbar = tqdm(total=len(prompts), desc="Processed prompts")
+
+        # Add requests to the server.
+        for prompt in prompts:
+            request_id = str(next(self.request_counter))
+            self.llm_server.add_request(request_id, prompt, sampling_params)
+
+        # Run the server.
+        outputs: List[RequestOutput] = []
+        while self.llm_server.has_unfinished_requests():
+            step_outputs = self.llm_server.step()
+            for output in step_outputs:
+                if output.done:
+                    outputs.append(output)
+                    if use_tqdm:
+                        pbar.update(1)
+        if use_tqdm:
+            pbar.close()
+        return outputs
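
The new LLM class is the offline counterpart to the FastAPI server: it enqueues every prompt as a request, then steps the underlying LLMServer until all requests finish, returning one RequestOutput per prompt. A minimal usage sketch (assumes a CUDA machine; the model name is an example, and the temperature field of SamplingParams is assumed, only n is confirmed by this diff):

# Sketch: batched offline generation with the new LLM entrypoint.
from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, seed=0)

prompts = [
    "Hello, my name is",
    "The capital of France is",
]
sampling_params = SamplingParams(n=1, temperature=0.8)

outputs = llm.generate(prompts, sampling_params)  # shows a tqdm progress bar
for output in outputs:  # RequestOutputs are returned in finish order
    print(output)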

cacheflow/outputs.py

Lines changed: 5 additions & 5 deletions
@@ -35,16 +35,16 @@ def __init__(
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
         self.done = done

-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   seq_group.is_finished())

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "

cacheflow/server/arg_utils.py

Lines changed: 97 additions & 54 deletions
@@ -1,74 +1,117 @@
 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple

 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster

-_GiB = 1 << 30

+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False

-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(
+            self.model, self.download_dir, self.use_np_weights,
+            self.use_dummy_weights, self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+) -> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+    parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
+    parser.add_argument('--use-dummy-weights', action='store_true',
+                        help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                        choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--use-ray', action='store_true',
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
+                        help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
+                        help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
+    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
+                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                        help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                        help='random seed')
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
+                        help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
+                        help='maximum number of batched tokens per iteration')
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
+                        help='maximum number of sequences per iteration')
+    parser.add_argument('--disable-log-stats', action='store_true',
+                        help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(
-        args.model, args.download_dir, args.use_np_weights,
-        args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
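
ServerArgs replaces the three free functions with a dataclass that owns the defaults (the CLI flags now read their defaults from its fields), clamps max_num_seqs in __post_init__, and builds the four config objects. A sketch of the programmatic path (creating the configs loads the Hugging Face model config, so the example model must be downloadable or cached):

# Sketch: constructing ServerArgs directly instead of going through argparse.
from cacheflow.server.arg_utils import ServerArgs

server_args = ServerArgs(
    model="facebook/opt-125m",      # example model name
    tensor_parallel_size=1,
    swap_space=4,                   # GiB; converted to bytes inside CacheConfig
    max_num_batched_tokens=2560,
    max_num_seqs=4096,              # __post_init__ clamps this to 2560
)
assert server_args.max_num_seqs == 2560

model_config, cache_config, parallel_config, scheduler_config = (
    server_args.create_server_configs())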

cacheflow/server/llm_server.py

Lines changed: 16 additions & 2 deletions
@@ -12,6 +12,8 @@
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ def __init__(
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ def _init_cache(self) -> None:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )

         # Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@ def _init_cache(self) -> None:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
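
from_server_args folds the cluster setup that initialize_server_from_args used to perform into the server class itself, so the FastAPI entrypoint and the new LLM class share one construction path. A sketch of driving the server loop by hand, mirroring what LLM.generate does internally (example model name, GPU assumed):

# Sketch: build the server from ServerArgs and step it manually until the
# single request finishes, mirroring the loop inside LLM.generate.
from cacheflow import SamplingParams, ServerArgs
from cacheflow.server.llm_server import LLMServer

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request("0", "Hello, my name is", SamplingParams())

while server.has_unfinished_requests():
    for request_output in server.step():
        if request_output.done:
            print(request_output)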
