@@ -32,7 +32,7 @@
 import numpy as np
 import torch
 
-from litserve import LitServer
+from litserve import LitAPI, LitServer
 
 
 class EmbeddingRequest(BaseModel):
@@ -129,44 +129,32 @@ def __init__(self): |
         self.add_endpoint("/v1/embeddings", self.embeddings_endpoint, ["POST"])
         self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"])
 
-    def setup(self, server: "LitServer"):
+    def pre_setup(self, lit_api: "LitAPI"):
         from litserve import LitAPI
 
-        super().setup(server)
-
-        lit_api = server.lit_api
+        if inspect.isgeneratorfunction(lit_api.predict):
+            raise ValueError(
+                "You are using yield in your predict method, which is used for streaming. "
+                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings "
+                "is not a sequential operation. "
+                "Please consider replacing yield with return in predict.\n"
+                + EMBEDDING_API_EXAMPLE
+            )
 
-        if isinstance(lit_api, LitAPI):
-            self._check_lit_api(lit_api)
-        elif isinstance(lit_api, list):
-            for api in lit_api:
-                self._check_lit_api(api)
+        is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__
+        if not is_encode_response_original and inspect.isgeneratorfunction(lit_api.encode_response):
+            raise ValueError(
+                "You are using yield in your encode_response method, which is used for streaming. "
+                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings "
+                "is not a sequential operation. "
+                "Please consider replacing yield with return in encode_response.\n"
+                + EMBEDDING_API_EXAMPLE
+            )
 
+    def setup(self, server: "LitServer"):
+        super().setup(server)
         print("OpenAI Embedding Spec is ready.")
 
-    def _check_lit_api(self, api):
-        from litserve import LitAPI
-
-        if isinstance(api.spec, OpenAIEmbeddingSpec):
-            if inspect.isgeneratorfunction(api.predict):
-                raise ValueError(
-                    "You are using yield in your predict method, which is used for streaming.",
-                    "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
-                    "is not a sequential operation.",
-                    "Please consider replacing yield with return in predict.\n",
-                    EMBEDDING_API_EXAMPLE,
-                )
-
-            is_encode_response_original = api.encode_response.__code__ is LitAPI.encode_response.__code__
-            if not is_encode_response_original and inspect.isgeneratorfunction(api.encode_response):
-                raise ValueError(
-                    "You are using yield in your encode_response method, which is used for streaming.",
-                    "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
-                    "is not a sequential operation.",
-                    "Please consider replacing yield with return in encode_response.\n",
-                    EMBEDDING_API_EXAMPLE,
-                )
-
     def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> List[str]:
         return request.input
 
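
For context, a minimal sketch of an embedding server that passes the new `pre_setup` validation: `predict` returns its batch of embeddings rather than yielding them. This assumes litserve's public API (`ls.LitAPI`, `ls.LitServer` with a `spec` argument, and `OpenAIEmbeddingSpec` exported from `litserve.specs`); the `EmbeddingAPI` class and the random-projection stand-in model are hypothetical.

```python
from typing import List

import numpy as np

import litserve as ls
from litserve.specs import OpenAIEmbeddingSpec


class EmbeddingAPI(ls.LitAPI):
    def setup(self, device):
        # Hypothetical stand-in for a real embedding model.
        self.model = lambda texts: np.random.rand(len(texts), 384)

    def predict(self, inputs: List[str]) -> List[List[float]]:
        # Return the whole batch at once; a `yield` here would trip the
        # generator check in OpenAIEmbeddingSpec.pre_setup.
        return self.model(inputs).tolist()


if __name__ == "__main__":
    server = ls.LitServer(EmbeddingAPI(), spec=OpenAIEmbeddingSpec())
    server.run(port=8000)
```

With `pre_setup`, the spec validates each `LitAPI` directly when it is attached, instead of `setup` having to unpack `server.lit_api` and branch on whether it holds a single API or a list.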