-
Notifications
You must be signed in to change notification settings - Fork 24
Added embedding service #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
|
|
||
|
|
||
| import asyncio | ||
| import gc | ||
| import logging | ||
| from typing import Any, AsyncIterator, Dict, List, Union | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch import Tensor | ||
|
|
||
| from transformers import AutoTokenizer | ||
| from optimum.intel import OVModelForFeatureExtraction | ||
|
|
||
| from src.server.models.optimum import TokenizerConfig | ||
|
|
||
| from typing import Any, AsyncIterator, Dict, Optional | ||
|
|
||
| from src.server.model_registry import ModelLoadConfig, ModelRegistry | ||
|
|
||
|
|
||
|
|
||
|
|
||
class Optimum_EMB:
    """Embedding worker backed by an OpenVINO feature-extraction model.

    Wraps an ``optimum.intel`` ``OVModelForFeatureExtraction`` together with its
    Hugging Face tokenizer and exposes an async generator that yields pooled
    (and, for torch tensors, L2-normalized) sentence embeddings.
    """

    def __init__(self, load_config: ModelLoadConfig):
        # Model and tokenizer are populated by load_model().  Initialize them
        # (including self.model, which the original left unset) to None so
        # unload_model() can be called safely in any order without raising
        # AttributeError.
        self.model = None
        self.model_path = None
        self.encoder_tokenizer = None
        self.load_config = load_config

    def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Pool a batch of hidden states down to one vector per sequence.

        Selects the hidden state of each sequence's last *real* (non-padding)
        token.  With left padding every sequence ends at position -1; with
        right padding the last real token is at ``attention_mask.sum() - 1``.

        Args:
            last_hidden_states: Tensor of shape (batch, seq_len, hidden).
            attention_mask: Tensor of shape (batch, seq_len); 1 = real token.

        Returns:
            Tensor of shape (batch, hidden).
        """
        # If every row's final position is a real token, padding (if any) is on
        # the left, so position -1 is always the last real token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def generate_type(self, tok_config: TokenizerConfig):
        """Route a tokenizer config to the embedding generator.

        Kept to mirror the stream/non-stream routing abstraction used by the
        text-generation workers; embeddings never stream, so this simply
        forwards to :meth:`generate_embeddings`.

        Args:
            tok_config: Tokenizer configuration for the request.

        Returns:
            Async iterator yielding the embeddings payload.
        """
        return self.generate_embeddings(tok_config)

    def prepare_inputs(self):
        # BUG FIX: original was missing ``self`` and would raise TypeError when
        # called on an instance.  Placeholder for future input preprocessing.
        pass

    async def generate_embeddings(self, tok_config: TokenizerConfig) -> AsyncIterator[Union[Dict[str, Any], str]]:
        """Tokenize, embed, pool, and yield the embeddings for a request.

        Args:
            tok_config: TokenizerConfig whose fields are forwarded verbatim to
                the Hugging Face tokenizer ``__call__``.

        Yields:
            Nested Python lists (``Tensor.tolist()``) of pooled embeddings.
        """
        # Tokenize the input texts; every tokenizer knob comes straight from
        # the request-level TokenizerConfig.
        batch_dict = self.encoder_tokenizer(
            text=tok_config.text,
            text_pair=tok_config.text_pair,
            text_target=tok_config.text_target,
            text_pair_target=tok_config.text_pair_target,
            add_special_tokens=tok_config.add_special_tokens,
            padding=tok_config.padding,
            truncation=tok_config.truncation,
            max_length=tok_config.max_length,
            stride=tok_config.stride,
            is_split_into_words=tok_config.is_split_into_words,
            pad_to_multiple_of=tok_config.pad_to_multiple_of,
            padding_side=tok_config.padding_side,
            return_tensors=tok_config.return_tensors,
            return_token_type_ids=tok_config.return_token_type_ids,
            return_attention_mask=tok_config.return_attention_mask,
            return_overflowing_tokens=tok_config.return_overflowing_tokens,
            return_special_tokens_mask=tok_config.return_special_tokens_mask,
            return_offsets_mapping=tok_config.return_offsets_mapping,
            return_length=tok_config.return_length,
            verbose=tok_config.verbose,
        )
        # BatchEncoding.to returns the same object; assign for clarity.
        batch_dict = batch_dict.to(self.model.device)
        outputs = self.model(**batch_dict)
        embeddings = self.last_token_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        # Normalize embeddings (only meaningful when we actually have torch
        # tensors, i.e. return_tensors == "pt").
        if tok_config.return_tensors == "pt":
            embeddings = F.normalize(embeddings, p=2, dim=1)
        yield embeddings.tolist()

    async def generate_stream(self):
        # BUG FIX: original was missing ``self``.  Embeddings do not stream;
        # placeholder kept for interface parity with generation workers.
        pass

    def collect_metrics(self, tok_config: TokenizerConfig, perf_metrics) -> Dict[str, Any]:
        # TODO: populate token/latency metrics; currently unimplemented.
        pass

    def load_model(self, loader: ModelLoadConfig):
        """Load model using a ModelLoadConfig configuration and cache the tokenizer.

        Args:
            loader: ModelLoadConfig containing model_path, device, engine, and
                runtime_config.
        """
        # export=False: the path must already contain an exported OpenVINO IR.
        self.model = OVModelForFeatureExtraction.from_pretrained(
            loader.model_path,
            device=loader.device,
            export=False,
        )
        self.encoder_tokenizer = AutoTokenizer.from_pretrained(loader.model_path)
        logging.info(f"Model loaded successfully: {loader.model_name}")

    async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool:
        """Unregister model from registry and free memory resources.

        Args:
            registry: ModelRegistry to unregister from.
            model_name: Model name previously registered with the registry.

        Returns:
            True if the model was found and unregistered, else False.
        """
        removed = await registry.register_unload(model_name)

        if self.model is not None:
            del self.model
            self.model = None

        if self.encoder_tokenizer is not None:
            del self.encoder_tokenizer
            self.encoder_tokenizer = None

        gc.collect()
        logging.info(f"[{self.load_config.model_name}] weights and tokenizer unloaded and memory cleaned up")
        return removed
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,7 +8,7 @@ | |
| import time | ||
| import uuid | ||
| import traceback | ||
| from typing import Any, AsyncIterator, List, Optional, Dict | ||
| from typing import Any, AsyncIterator, List, Optional, Dict, Union | ||
|
|
||
| from pydantic import BaseModel | ||
| from fastapi import Depends, FastAPI, HTTPException, Request | ||
|
|
@@ -21,6 +21,7 @@ | |
| from src.server.worker_registry import WorkerRegistry | ||
| from src.server.models.openvino import OV_KokoroGenConfig | ||
| from src.server.models.ov_genai import OVGenAI_GenConfig, OVGenAI_WhisperGenConfig | ||
| from src.server.models.optimum import TokenizerConfig | ||
|
|
||
| #===============================================================# | ||
| # Logging | ||
|
|
@@ -151,6 +152,15 @@ class OpenAIKokoroRequest(BaseModel): | |
| language: Optional[str] = None | ||
| response_format: Optional[str] = "wav" | ||
|
|
||
# https://platform.openai.com/docs/api-reference/embeddings
class EmbeddingsRequest(BaseModel):
    """OpenAI-compatible request body for POST /v1/embeddings.

    Mirrors the OpenAI embeddings API fields, plus a server-specific
    ``config`` escape hatch carrying a full TokenizerConfig.
    """
    model: str
    input: Union[str, List[str], List[List[str]]]
    dimensions: Optional[int] = None
    encoding_format: Optional[str] = "float"  # not implemented
    # BUG FIX: original had a trailing comma (`= None,`), which made the
    # field default the tuple (None,) instead of None.
    user: Optional[str] = None  # not implemented
    # end of openai api
    config: Optional[TokenizerConfig] = None
|
|
||
| @app.get("/v1/models", dependencies=[Depends(verify_api_key)]) | ||
| async def openai_list_models(): | ||
|
|
@@ -336,3 +346,53 @@ async def openai_audio_speech(request: OpenAIKokoroRequest): | |
| except Exception as exc: | ||
| raise HTTPException(status_code=500, detail=f"Speech synthesis failed: {str(exc)}") | ||
|
|
||
@app.post("/v1/embeddings", dependencies=[Depends(verify_api_key)])
async def embeddings(request: EmbeddingsRequest):
    """OpenAI-compatible embeddings endpoint.

    Builds a TokenizerConfig from the request (or uses the caller-supplied
    one), dispatches to the embedding worker, and wraps the result in an
    OpenAI-style response envelope.

    Raises:
        HTTPException: 400 on ValueError from the worker, 500 on any other
            failure.
    """
    try:
        tok_config = TokenizerConfig(
            text=request.input
        )

        # An explicit config wins, but still inherits the input text if the
        # caller left config.text empty.
        if request.config:
            tok_config = request.config
            if not tok_config.text:
                tok_config.text = request.input

        # BUG FIX: `dimensions` defaults to None, and `None > 0` raises
        # TypeError — guard against None before comparing.
        # NOTE(review): OpenAI's `dimensions` selects the output embedding
        # width, not the tokenizer max_length — confirm this mapping is
        # intentional.
        if not tok_config.max_length and request.dimensions is not None and request.dimensions > 0:
            tok_config.max_length = request.dimensions

        model_name = request.model
        created_ts = int(time.time())
        request_id = f"ov-{uuid.uuid4().hex[:24]}"

        result = await _workers.embed(model_name, tok_config)
        data = result.get("data", None)
        metrics = result.get("metrics", {}) or {}

        prompt_tokens = metrics.get("input_token", 0)
        total_tokens = metrics.get("total_token", prompt_tokens)

        # NOTE(review): OpenAI returns one data entry per input item; this
        # packs the whole batch under a single index 0 — confirm clients
        # accept this shape or split `data` per input.
        response = {
            "id": request_id,
            "object": "list",
            "created": created_ts,
            "model": model_name,
            "data": [
                {
                    "index": 0,
                    "object": "embedding",
                    "embedding": data
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens,
            },
        }
        return response
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {str(exc)}")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generate_type is used to keep the API abstraction intact for the stream-bool behavior. Since we don't stream, you can probably call generate_embeddings directly, unless you think this affects task coverage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll short circuit it. I was just hacking my way through and trying not to deviate too much from a working example until I figured out which direction was up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the infer_emb function is already pointing to generate_embeddings. I'll just remove generate_type.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if we make it pooling_type and maybe add other pooling strategies later? I was just reading about mean pooling vs. last-token pooling and it sounds interesting, and possibly easy to implement — I don't fully understand how it works yet. Maybe use pooling_type as a knob, like generate_type, to keep the async serving abstraction intact: we configure pooling_type from a request header, and since the generate_embeddings call probably won't change much, we make it easy to tinker with more techniques in the future. What do you think?