Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xinference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
XINFERENCE_ENV_SSE_PING_ATTEMPTS_SECONDS = "XINFERENCE_SSE_PING_ATTEMPTS_SECONDS"
XINFERENCE_ENV_MAX_TOKENS = "XINFERENCE_MAX_TOKENS"
XINFERENCE_ENV_ALLOWED_IPS = "XINFERENCE_ALLOWED_IPS"
XINFERENCE_ENV_BATCH_SIZE = "XINFERENCE_BATCH_SIZE"
XINFERENCE_ENV_BATCH_TIMEOUT = "XINFERENCE_BATCH_TIMEOUT"


def get_xinference_home() -> str:
Expand Down Expand Up @@ -112,3 +114,5 @@ def get_xinference_home() -> str:
else None
)
XINFERENCE_ALLOWED_IPS = os.getenv(XINFERENCE_ENV_ALLOWED_IPS)
XINFERENCE_BATCH_SIZE = int(os.getenv(XINFERENCE_ENV_BATCH_SIZE, "32"))
XINFERENCE_BATCH_TIMEOUT = float(os.getenv(XINFERENCE_ENV_BATCH_TIMEOUT, "0.003"))
2 changes: 1 addition & 1 deletion xinference/core/tests/test_restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def test_restful_api_for_embedding(setup):
assert len(embedding_res["data"][0]["embedding"]) == model_spec.dimensions
assert "model_replica" in embedding_res
assert embedding_res["model_replica"] is not None
assert embedding_res["model"] == payload["model"]
assert embedding_res["model"] == "unknown"

# test multiple
payload = {
Expand Down
12 changes: 10 additions & 2 deletions xinference/deploy/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def run_in_subprocess(
if parent_conn.poll(timeout=XINFERENCE_HEALTH_CHECK_TIMEOUT):
msg = parent_conn.recv()
if msg != READY:
raise RuntimeError(f"Start service process failed during startup:\n{msg}")
raise RuntimeError(
f"Start service process failed during startup:\n{msg}" # noqa: E231
)
else:
logger.info(
"No response from process after %s seconds", XINFERENCE_HEALTH_CHECK_TIMEOUT
Expand All @@ -157,7 +159,7 @@ def main(
# which will raise error after sub pool is created
multiprocessing.set_start_method("spawn")

supervisor_address = f"{host}:{get_next_port()}"
supervisor_address = f"{host}:{get_next_port()}" # noqa: E231
local_cluster = run_in_subprocess(
supervisor_address, metrics_exporter_host, metrics_exporter_port, logging_conf
)
Expand All @@ -181,3 +183,9 @@ def main(
)
finally:
local_cluster.kill()


if __name__ == "__main__":
from .cmdline import local

local()
103 changes: 103 additions & 0 deletions xinference/model/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022-2025 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import inspect
import logging
import types

from xoscar.batch import _ExtensibleWrapper

from ..constants import XINFERENCE_BATCH_SIZE, XINFERENCE_BATCH_TIMEOUT

logger = logging.getLogger(__name__)


class BatchMixin:
    """Mixin that transparently batches concurrent calls to an xoscar
    extensible method.

    The wrapped method is replaced on the instance by an async facade that
    enqueues the call and awaits its individual result. A single background
    task drains the queue, groups requests that arrive within
    ``batch_timeout`` seconds (up to roughly ``batch_size`` items), executes
    them as one batch call, and dispatches each result back to its waiting
    caller.
    """

    # Subclasses may override these; defaults come from env-configurable
    # constants.
    allow_batch = True
    batch_size = XINFERENCE_BATCH_SIZE
    batch_timeout = XINFERENCE_BATCH_TIMEOUT

    def __init__(self, func: _ExtensibleWrapper):
        # Queue of ((args, kwargs), future) pairs produced by the facade.
        self._queue: asyncio.Queue = asyncio.Queue()
        self._func = func
        self._func_name = func.func.__name__
        # Shadow the public method on this instance with the batching facade.
        setattr(self, self._func_name, types.MethodType(self._wrap_method(), self))

        self._is_process_batch_running = False
        # Strong reference to the consumer task: the event loop keeps only
        # weak references to tasks, so without holding one here the task
        # could be garbage-collected and silently disappear mid-execution.
        self._process_batch_task = None

    def _ensure_process_batch_running(self):
        """Lazily start the queue-consumer task (once, on first call)."""
        if self._is_process_batch_running:
            return

        # create asyncio task to process batch; keep a reference (see __init__)
        self._process_batch_task = asyncio.create_task(self._process_batch())
        self._is_process_batch_running = True

    def _get_batch_size(self, *args, **kwargs) -> int:
        """Return how many logical items a single call contributes to a
        batch. Must be implemented by the subclass."""
        raise NotImplementedError

    async def _process_batch(self):
        """Consume the queue forever, grouping calls into batches."""
        while True:
            # Wait until at least one item is available
            (first_args, first_kwargs), first_future = await self._queue.get()

            delays = [self._func.delay(*first_args, **first_kwargs)]
            size = self._get_batch_size(*first_args, **first_kwargs)
            futures = [first_future]

            # Try to gather more items into the same batch within a short timeout window
            while size <= self.batch_size:
                try:
                    # Wait for a new request for a short time window (e.g. 3ms)
                    # This allows batching multiple requests that arrive close in time.
                    (args, kwargs), future = await asyncio.wait_for(
                        self._queue.get(), timeout=self.batch_timeout
                    )
                    size += self._get_batch_size(*args, **kwargs)
                    delays.append(self._func.delay(*args, **kwargs))
                    futures.append(future)
                except asyncio.TimeoutError:
                    # No new items arrived within the timeout window,
                    # stop collecting and start processing the current batch.
                    break

            logger.debug("Calling batch %s with %d size", self._func_name, size)

            try:
                results = self._func.batch(*delays)
                if inspect.isawaitable(results):
                    results = await results
            except Exception as e:  # Handle errors for the entire batch
                for fut in futures:
                    # A caller may have been cancelled while waiting; calling
                    # set_exception() on a done future raises InvalidStateError
                    # and would kill this consumer task for every other caller.
                    if not fut.done():
                        fut.set_exception(e)
            else:
                if len(results) != len(futures):
                    # Don't assert here: an AssertionError would terminate this
                    # consumer task and strand every queued caller. Fail only
                    # this batch and keep consuming.
                    err = RuntimeError(
                        f"#results should be equal to #futures, "
                        f"got {len(results)} and {len(futures)}"
                    )
                    for fut in futures:
                        if not fut.done():
                            fut.set_exception(err)
                else:
                    # Deliver the results to the corresponding waiting callers
                    for fut, result in zip(futures, results):
                        if not fut.done():
                            fut.set_result(result)

    def _wrap_method(self):
        """Build the async facade that replaces the wrapped method."""

        async def _replaced_async_method(model, *args, **kwargs):
            self._ensure_process_batch_running()
            loop = asyncio.get_running_loop()
            fut = loop.create_future()
            await self._queue.put(((args, kwargs), fut))
            return await fut

        return _replaced_async_method
107 changes: 106 additions & 1 deletion xinference/model/embedding/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

import abc
import gc
import inspect
import logging
import os
from abc import abstractmethod
from collections import defaultdict
from typing import Annotated, Dict, List, Literal, Optional, Union

from xoscar import extensible

from ..._compat import ROOT_KEY, BaseModel, ErrorWrapper, Field, ValidationError
from ...device_utils import empty_cache
from ...types import Embedding
from ...utils import make_hashable
from ..core import VirtualEnvSettings
from ..utils import ModelInstanceInfoMixin
from .embed_family import match_embedding
Expand Down Expand Up @@ -240,7 +245,7 @@ def _text_length(text):
return sum([len(t) for t in text]) # Sum of length of individual strings

@abstractmethod
def create_embedding(
def _create_embedding(
self,
sentences: Union[str, List[str]],
**kwargs,
Expand All @@ -260,6 +265,106 @@ def create_embedding(
The resulted Embedding vector that can be easily consumed by machine learning models and algorithms.
"""

@extensible
def create_embedding(
    self,
    sentences: Union[str, List[str]],
    **kwargs,
):
    """Public single-call embedding entry point.

    Simply forwards to :meth:`_create_embedding`; the ``@extensible``
    decorator additionally exposes ``.delay``/``.batch`` hooks so that
    calls can be grouped by the batching machinery.
    """
    result = self._create_embedding(sentences, **kwargs)
    return result

@create_embedding.batch  # type: ignore
def create_embedding(self, args_list, kwargs_list):
    """Batch implementation of ``create_embedding``.

    Receives the positional/keyword argument lists of several queued calls,
    groups calls whose extra kwargs are identical (so they can share one
    underlying model invocation), runs each group through
    ``_create_embedding`` once, and returns one ``Embedding`` per original
    call, in the original call order.
    """
    grouped = defaultdict(
        lambda: {"sentences": [], "offsets": [], "kwargs": None, "indices": []}
    )

    # 1. Group by kwargs hash
    for i, (args, kwargs) in enumerate(zip(args_list, kwargs_list)):
        sentences, extra_kwargs = self._extract_sentences_kwargs(args, kwargs)
        if isinstance(sentences, str):
            sentences = [sentences]

        key = make_hashable(extra_kwargs)
        group = grouped[key]
        group["kwargs"] = extra_kwargs

        current_offset = len(group["sentences"])
        group["offsets"].append((current_offset, len(sentences)))
        group["sentences"].extend(sentences)
        group["indices"].append(i)  # remember original position

    results_with_index = []

    # 2. Process each group separately
    for key, group in grouped.items():
        sentences = group["sentences"]
        kwargs = group["kwargs"]
        offsets = group["offsets"]
        indices = group["indices"]

        embedding_list = self._create_embedding(sentences, **kwargs)
        # NOTE(review): "model" is only present when callers pass it via
        # kwargs; otherwise the response's model field degrades to
        # "unknown" — confirm this is intended (see the test that was
        # changed to expect "unknown").
        model_uid = kwargs.get("model", "unknown")

        # 3. Split and attach original index
        for (offset, n), idx in zip(offsets, indices):
            # NOTE(review): slices keep the batched call's item indices;
            # if EmbeddingData carries an ``index`` field, per-caller
            # indices are not renumbered from 0 — verify downstream.
            data = embedding_list["data"][offset : offset + n]
            result = Embedding(
                object="list",
                model=model_uid,
                model_replica=self._model_uid,
                data=data,
                # BUG FIX: previously one usage dict counting the WHOLE
                # group's sentences was shared by every caller; report each
                # caller's own sentence count instead. (This still counts
                # sentences, not tokens — presumably a placeholder.)
                usage={"total_tokens": n},
            )
            results_with_index.append((idx, result))

    # 4. Sort by original call order
    results_with_index.sort(key=lambda x: x[0])
    results = [r for _, r in results_with_index]
    return results

def _extract_sentences_kwargs(self, args, kwargs):
"""
Extract the 'sentences' argument and remaining kwargs from (*args, **kwargs)
for a given function.

This uses inspect.signature(func).bind_partial() to automatically match
both positional and keyword arguments, while handling bound methods
(functions with 'self' as the first parameter).

Args:
func: The target function whose parameters define how to bind args/kwargs.
args: The positional arguments passed to the function.
kwargs: The keyword arguments passed to the function.

Returns:
A tuple (sentences, extra_kwargs), where:
- sentences: The extracted 'sentences' argument (never None).
- extra_kwargs: Remaining keyword arguments excluding 'sentences'.

Raises:
KeyError: If 'sentences' argument is not found.
TypeError: If args/kwargs do not match the function signature.
"""
sig = inspect.signature(self._create_embedding)
bound = sig.bind_partial(*args, **kwargs)
bound.apply_defaults()

if "sentences" not in bound.arguments:
raise KeyError("'sentences' argument not found in args/kwargs")

sentences = bound.arguments["sentences"]
extra_kwargs = {k: v for k, v in kwargs.items() if k != "sentences"}
return sentences, extra_kwargs

def _get_batch_size(self, *args, **kwargs) -> int:
sentences = self._extract_sentences_kwargs(args, kwargs)[0]
if isinstance(sentences, list):
return len(sentences)
else:
return 1

def convert_ids_to_tokens(
self,
batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
Expand Down
24 changes: 15 additions & 9 deletions xinference/model/embedding/flag/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@
except ImportError:
flag_installed = False

from ....constants import XINFERENCE_BATCH_SIZE, XINFERENCE_BATCH_TIMEOUT
from ....device_utils import get_available_device
from ....types import Embedding, EmbeddingData, EmbeddingUsage
from ...batch import BatchMixin
from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1

FLAG_EMBEDDER_MODEL_LIST = support_native_bge_model_list() if flag_installed else []
logger = logging.getLogger(__name__)


class FlagEmbeddingModel(EmbeddingModel):
class FlagEmbeddingModel(EmbeddingModel, BatchMixin):
def __init__(
self,
model_uid: str,
Expand All @@ -47,15 +49,19 @@ def __init__(
return_sparse: bool = False,
**kwargs,
):
super().__init__(
model_uid,
model_path,
model_family,
quantization,
device,
**kwargs,
EmbeddingModel.__init__(
self, model_uid, model_path, model_family, quantization, device, **kwargs
)
BatchMixin.__init__(self, self.create_embedding) # type: ignore
self._return_sparse = return_sparse
if "batch_size" in kwargs:
self.batch_size = int(
self._kwargs.pop("batch_size") or XINFERENCE_BATCH_SIZE
)
if "batch_timeout" in kwargs:
self.batch_timeout = float(
self._kwargs.pop("batch_timeout") or XINFERENCE_BATCH_TIMEOUT
)

def load(self):
# add truncate_dim args hint
Expand Down Expand Up @@ -105,7 +111,7 @@ def load(self):
)
self._tokenizer = self._model.tokenizer

def create_embedding(
def _create_embedding(
self,
sentences: Union[str, List[str]],
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions xinference/model/embedding/flag/tests/test_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


# todo Refer to the return format of sentence_transformer
def test_embedding_model_with_flag():
async def test_embedding_model_with_flag():
model_path = None
try:
model_path = CacheManager(TEST_MODEL_SPEC).cache()
Expand All @@ -53,10 +53,10 @@ def test_embedding_model_with_flag():
input_text = "what is the capital of China?"

# test sparse and dense
r = model.create_embedding(input_text, **{"return_sparse": True})
r = await model.create_embedding(input_text, **{"return_sparse": True})
assert len(r["data"]) == 1

r = model.create_embedding(input_text)
r = await model.create_embedding(input_text)
assert len(r["data"][0]["embedding"]) == 384

# input is a lit
Expand All @@ -67,10 +67,10 @@ def test_embedding_model_with_flag():
"sorting algorithms",
]
# test sparse and dense
r = model.create_embedding(input_texts, **{"return_sparse": True})
r = await model.create_embedding(input_texts, **{"return_sparse": True})
assert len(r["data"]) == 4

r = model.create_embedding(input_texts)
r = await model.create_embedding(input_texts)
for d in r["data"]:
assert len(d["embedding"]) == 384
finally:
Expand Down
Loading
Loading