
Commit 5f34a50

njhill authored and joerunde committed
Initial gRPC server and TGIS proto API mapping layer
Signed-off-by: Joe Runde <[email protected]>
1 parent b35cc93 commit 5f34a50

10 files changed: +947 −12 lines changed


Makefile

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+
+target_path := "vllm/entrypoints/grpc/pb"
+
+gen-protos:
+	# Compile protos
+	pip install grpcio-tools==1.60.1 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
+	mkdir $(target_path) || true
+	python -m grpc_tools.protoc -Iproto --python_out=$(target_path) \
+		--grpc_python_out=$(target_path) --mypy_out=$(target_path) proto/generation.proto
+	find $(target_path)/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch $(target_path)/__init__.py
+
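The sed step above rewrites the absolute import that grpc_tools.protoc emits into a relative one, so the generated modules work as a package under vllm/entrypoints/grpc/pb once __init__.py is touched. A sketch of the effect, assuming the grpcio-tools naming convention for generation.proto (generation_pb2 / generation_pb2_grpc); illustrative only, not part of this commit:

# Effect of the sed rewrite on the generated generation_pb2_grpc.py
# (module names assumed from generation.proto; illustrative only).
#
# Emitted by grpc_tools.protoc:
#   import generation_pb2 as generation__pb2
# After sed ('s/^\(import.*pb2\)/from . \1/g'):
#   from . import generation_pb2 as generation__pb2
#
# With __init__.py in place, the stubs can then be imported as a package:
from vllm.entrypoints.grpc.pb import generation_pb2, generation_pb2_grpc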

proto/generation.proto

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
+/*
+Internal service interface for FMaaS completions
+*/
+
+syntax = "proto3";
+package fmaas;
+
+
+service GenerationService {
+  // Generates text given a text prompt, for one or more inputs
+  rpc Generate (BatchedGenerationRequest) returns (BatchedGenerationResponse) {}
+  // Generates text given a single input prompt, streaming the response
+  rpc GenerateStream (SingleGenerationRequest) returns (stream GenerationResponse) {}
+  // Tokenize text
+  rpc Tokenize (BatchedTokenizeRequest) returns (BatchedTokenizeResponse) {}
+  // Model info
+  rpc ModelInfo (ModelInfoRequest) returns (ModelInfoResponse) {}
+}
+
+// ============================================================================================================
+// Generation API
+
+enum DecodingMethod {
+  GREEDY = 0;
+  SAMPLE = 1;
+}
+
+message BatchedGenerationRequest {
+  string model_id = 1;
+  optional string prefix_id = 2;
+  repeated GenerationRequest requests = 3;
+
+  Parameters params = 10;
+}
+
+message SingleGenerationRequest {
+  string model_id = 1;
+  optional string prefix_id = 2;
+  GenerationRequest request = 3;
+
+  Parameters params = 10;
+}
+
+message BatchedGenerationResponse {
+  repeated GenerationResponse responses = 1;
+}
+
+message GenerationRequest {
+  string text = 2;
+}
+
+message GenerationResponse {
+  uint32 input_token_count = 6;
+  uint32 generated_token_count = 2;
+  string text = 4;
+  StopReason stop_reason = 7;
+  // The stop sequence encountered, iff stop_reason == STOP_SEQUENCE
+  string stop_sequence = 11;
+  // Random seed used, not applicable for greedy requests
+  uint64 seed = 10;
+
+  // Individual generated tokens and associated details, if requested
+  repeated TokenInfo tokens = 8;
+
+  // Input tokens and associated details, if requested
+  repeated TokenInfo input_tokens = 9;
+}
+
+message Parameters {
+  // The high level decoding approach
+  DecodingMethod method = 1;
+  // Parameters related to sampling, applicable only when method == SAMPLING
+  SamplingParameters sampling = 2;
+  // Parameters controlling when generation should stop
+  StoppingCriteria stopping = 3;
+  // Flags to control what is returned in the response
+  ResponseOptions response = 4;
+  // Parameters for conditionally penalizing/boosting
+  // candidate tokens during decoding
+  DecodingParameters decoding = 5;
+  // Truncate to this many input tokens. Can be used to avoid requests
+  // failing due to input being longer than configured limits.
+  // Zero means don't truncate.
+  uint32 truncate_input_tokens = 6;
+}
+
+message DecodingParameters {
+  message LengthPenalty {
+    // Start the decay after this number of tokens have been generated
+    uint32 start_index = 1;
+    // Factor of exponential decay
+    float decay_factor = 2;
+  }
+
+  // Default (0.0) means no penalty (equivalent to 1.0)
+  // 1.2 is a recommended value
+  float repetition_penalty = 1;
+
+  // Exponentially increases the score of the EOS token
+  // once start_index tokens have been generated
+  optional LengthPenalty length_penalty = 2;
+}
+
+
+message SamplingParameters {
+  // Default (0.0) means disabled (equivalent to 1.0)
+  float temperature = 1;
+  // Default (0) means disabled
+  uint32 top_k = 2;
+  // Default (0) means disabled (equivalent to 1.0)
+  float top_p = 3;
+  // Default (0) means disabled (equivalent to 1.0)
+  float typical_p = 4;
+
+  optional uint64 seed = 5;
+}
+
+message StoppingCriteria {
+  // Default (0) is currently 20
+  uint32 max_new_tokens = 1;
+  // Default (0) means no minimum
+  uint32 min_new_tokens = 2;
+  // Default (0) means no time limit
+  uint32 time_limit_millis = 3;
+  repeated string stop_sequences = 4;
+  // If not specified, default behavior depends on server setting
+  optional bool include_stop_sequence = 5;
+
+  //more to come
+}
+
+message ResponseOptions {
+  // Include input text
+  bool input_text = 1;
+  // Include list of individual generated tokens
+  // "Extra" token information is included based on the other flags below
+  bool generated_tokens = 2;
+  // Include list of input tokens
+  // "Extra" token information is included based on the other flags here,
+  // but only for decoder-only models
+  bool input_tokens = 3;
+  // Include logprob for each returned token
+  // Applicable only if generated_tokens == true and/or input_tokens == true
+  bool token_logprobs = 4;
+  // Include rank of each returned token
+  // Applicable only if generated_tokens == true and/or input_tokens == true
+  bool token_ranks = 5;
+  // Include top n candidate tokens at the position of each returned token
+  // The maximum value permitted is 5, but more may be returned if there is a tie
+  // for nth place.
+  // Applicable only if generated_tokens == true and/or input_tokens == true
+  uint32 top_n_tokens = 6;
+}
+
+enum StopReason {
+  // Possibly more tokens to be streamed
+  NOT_FINISHED = 0;
+  // Maximum requested tokens reached
+  MAX_TOKENS = 1;
+  // End-of-sequence token encountered
+  EOS_TOKEN = 2;
+  // Request cancelled by client
+  CANCELLED = 3;
+  // Time limit reached
+  TIME_LIMIT = 4;
+  // Stop sequence encountered
+  STOP_SEQUENCE = 5;
+  // Total token limit reached
+  TOKEN_LIMIT = 6;
+  // Decoding error
+  ERROR = 7;
+}
+
+message TokenInfo {
+  // uint32 id = 1; // TBD
+  string text = 2;
+  // The logprob (log of normalized probability), if requested
+  float logprob = 3;
+  // One-based rank relative to other tokens, if requested
+  uint32 rank = 4;
+
+  message TopToken {
+    // uint32 id = 1; // TBD
+    string text = 2;
+    float logprob = 3;
+  }
+
+  // Top N candidate tokens at this position, if requested
+  // May or may not include this token
+  repeated TopToken top_tokens = 5;
+}
+
+
+// ============================================================================================================
+// Tokenization API
+
+message BatchedTokenizeRequest {
+  string model_id = 1;
+  repeated TokenizeRequest requests = 2;
+  bool return_tokens = 3; //TBD
+}
+
+message BatchedTokenizeResponse {
+  repeated TokenizeResponse responses = 1;
+}
+
+message TokenizeRequest {
+  string text = 1;
+}
+
+message TokenizeResponse {
+  uint32 token_count = 1;
+  repeated string tokens = 2; // if include_tokens = true
+
+  // We'll possibly add more later
+}
+
+
+// ============================================================================================================
+// Model Info API
+
+message ModelInfoRequest {
+  string model_id = 1;
+}
+
+message ModelInfoResponse {
+  enum ModelKind {
+    DECODER_ONLY = 0;
+    ENCODER_DECODER = 1;
+  }
+
+  ModelKind model_kind = 1;
+  uint32 max_sequence_length = 2;
+  uint32 max_new_tokens = 3;
+}
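For orientation, a minimal client sketch against the GenerationService above. Only the message and field names come from generation.proto; the target address, generated-module location, and parameter values are assumptions, and the actual server added in this commit may be wired differently:

# Hypothetical client sketch for the TGIS-style GenerationService.
import grpc

from vllm.entrypoints.grpc.pb import generation_pb2, generation_pb2_grpc


def generate_once(target: str, model_id: str, prompt: str) -> generation_pb2.GenerationResponse:
    with grpc.insecure_channel(target) as channel:
        stub = generation_pb2_grpc.GenerationServiceStub(channel)
        request = generation_pb2.BatchedGenerationRequest(
            model_id=model_id,
            requests=[generation_pb2.GenerationRequest(text=prompt)],
            params=generation_pb2.Parameters(
                method=generation_pb2.DecodingMethod.GREEDY,
                stopping=generation_pb2.StoppingCriteria(max_new_tokens=64),
                response=generation_pb2.ResponseOptions(generated_tokens=True,
                                                        token_logprobs=True),
            ),
        )
        # Generate is unary-unary: one BatchedGenerationResponse covering the batch.
        response = stub.Generate(request)
        return response.responses[0]

GenerateStream is the server-streaming counterpart: a SingleGenerationRequest goes in, and the returned iterator yields GenerationResponse messages incrementally (for resp in stub.GenerateStream(request): ...).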

vllm/engine/llm_engine.py

Lines changed: 15 additions & 10 deletions
@@ -994,6 +994,21 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the minimum number of tokens has been generated yet;
+        # skip the stop string/token checks if not
+        if seq.get_output_len() < sampling_params.min_tokens:
+            return
+
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
@@ -1006,16 +1021,6 @@ def _check_stop(self, seq: Sequence,
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
         # Check if the sequence has generated the EOS token.
         if ((not sampling_params.ignore_eos)
                 and seq.get_last_token_id() == seq.eos_token_id):
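The net effect of this reordering: the max_model_len and max_tokens caps are now evaluated before any stop-string or EOS handling, and a new min_tokens gate suppresses those stop checks until enough output tokens exist. A simplified, self-contained sketch of the resulting control flow; names mirror the diff but this is an illustration, not the actual vLLM method:

# Simplified illustration of the stop-check ordering introduced above.
# Plain arguments stand in for the Sequence / SamplingParams attributes
# used by the real _check_stop; return values stand in for SequenceStatus.
from typing import List, Optional


def check_stop(total_len: int, output_len: int, output_text: str,
               max_model_len: int, max_tokens: int, min_tokens: int,
               stop_strings: List[str]) -> Optional[str]:
    # 1. Hard length caps apply unconditionally, even before min_tokens.
    if total_len > max_model_len or output_len == max_tokens:
        return "FINISHED_LENGTH_CAPPED"
    # 2. Until min_tokens outputs have been generated, skip all
    #    stop-string / EOS handling.
    if output_len < min_tokens:
        return None
    # 3. Only now check user-supplied stop sequences (EOS and stop token
    #    ids are handled analogously in the real method).
    for stop_str in stop_strings:
        if output_text.endswith(stop_str):
            return "FINISHED_STOPPED"
    return None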
