@@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
144144 0 ].embedding
145145 assert responses_float .data [1 ].embedding == responses_default .data [
146146 1 ].embedding
147+
148+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation(
        embedding_client: openai.AsyncOpenAI, model_name: str):
    """Verify prompt truncation for both text and pre-tokenized inputs.

    With ``truncate_prompt_tokens=10`` the server should return a single
    4096-dim embedding and report exactly 10 prompt tokens in usage,
    regardless of whether the input arrives as text or as token ids.
    """

    def _check(response):
        # Assertions shared by both request variants: one embedding of
        # the expected width, and usage reflecting the truncated length.
        assert response.id is not None
        assert len(response.data) == 1
        assert len(response.data[0].embedding) == 4096
        assert response.usage.completion_tokens == 0
        assert response.usage.prompt_tokens == 10
        assert response.usage.total_tokens == 10

    # Variant 1: plain-text input.
    text_response = await embedding_client.embeddings.create(
        model=model_name,
        input=["Como o Brasil pode fomentar o desenvolvimento de modelos de IA?"],
        extra_body={"truncate_prompt_tokens": 10})
    _check(text_response)

    # Variant 2: the same prompt supplied as token ids.
    token_ids = [
        1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
        9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
    ]
    token_response = await embedding_client.embeddings.create(
        model=model_name,
        input=token_ids,
        extra_body={"truncate_prompt_tokens": 10})
    _check(token_response)
187+
188+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation_invalid(
        embedding_client: openai.AsyncOpenAI, model_name: str):
    """An out-of-range ``truncate_prompt_tokens`` must be rejected.

    A value larger than the model context (8193) should produce an HTTP 400
    (``openai.BadRequestError``) whose message tells the user to pick a
    smaller truncation size.

    Note: the original version asserted on an ``embeddings`` variable after
    the call expected to raise, so those assertions never executed; the
    error payload is now checked via the captured exception instead.
    """
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]

    with pytest.raises(openai.BadRequestError) as exc_info:
        await embedding_client.embeddings.create(
            model=model_name,
            input=input_texts,
            extra_body={"truncate_prompt_tokens": 8193})

    # The server's error message is surfaced in the exception text.
    assert "truncate_prompt_tokens value is greater than max_model_len. " \
        "Please, select a smaller truncation size." in str(exc_info.value)
0 commit comments