diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6cf1a6056f..4fdc92e028 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,11 @@
# ChangeLog
-## [0.7.22]
+## Unreleased
+
+### New Features
+- Added Xorbits inference for local deployments (#7151)
+
+## [0.7.22] - 2023-08-08
### New Features
- add ensemble retriever notebook (#7190)
diff --git a/docs/core_modules/model_modules/llms/modules.md b/docs/core_modules/model_modules/llms/modules.md
index 658c5db7dc..8ecd7a3ccf 100644
--- a/docs/core_modules/model_modules/llms/modules.md
+++ b/docs/core_modules/model_modules/llms/modules.md
@@ -81,3 +81,11 @@ maxdepth: 1
---
/examples/llm/llama_api.ipynb
```
+
+## Xorbits Inference
+```{toctree}
+---
+maxdepth: 1
+---
+/examples/llm/XinferenceLocalDeployment.ipynb
+```
diff --git a/docs/examples/llm/XinferenceLocalDeployment.ipynb b/docs/examples/llm/XinferenceLocalDeployment.ipynb
new file mode 100644
index 0000000000..846216d052
--- /dev/null
+++ b/docs/examples/llm/XinferenceLocalDeployment.ipynb
@@ -0,0 +1,209 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "7096589b-daaf-440a-b89d-b4956f2db4b2",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Using Xorbits Inference to Deploy Local LLMs - in 3 steps!\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "d8cfbe6f-4c50-4c4f-90f9-03bb91201ef5",
+ "metadata": {},
+ "source": [
+ "## 🤖 Installing and Running Xorbits Inference (1/3)\n",
+ "\n",
+ "#### i. Run `pip install \"xinference[all]\"` in a terminal window\n",
+ "\n",
+ "#### ii. After installation is complete, restart this jupyter notebook\n",
+ "\n",
+ "#### iii. Run `xinference` in a new terminal window\n",
+ "\n",
+ "#### iv. You should see something similar to the following output:\n",
+ "\n",
+ "```\n",
+ "INFO:xinference:Xinference successfully started. Endpoint: http://127.0.0.1:9997\n",
+ "INFO:xinference.core.service:Worker 127.0.0.1:21561 has been added successfully\n",
+ "INFO:xinference.deploy.worker:Xinference worker successfully started.\n",
+ "```\n",
+ "\n",
+ "#### v. In the endpoint description, locate the endpoint port number after the colon. In the above case it is `9997`\n",
+ "\n",
+ "#### vi. Paste the endpoint port number in the following cell"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5d520d56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "port = 9997 # replace with your endpoint port number"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "93139076",
+ "metadata": {},
+ "source": [
+ "## 🚀 Downloading and Launching Local Models (2/3)\n",
+ "\n",
+ "#### In this step, simply run the following code blocks\n",
+ "\n",
+ "#### Also, feel free to change the model configuration for different experiences!\n",
+ "\n",
+ "#### The latest list of supported models can be found in Xorbits Inference's [official GitHub page](https://github.com/xorbitsai/inference/blob/main/README.md)\n",
+ "\n",
+ "##### Here are the parameter options for vicuna-v1.3, ranked from the least space-consuming to the most resource-intensive but high-performing:\n",
+ "\n",
+ "model_size_in_billions: `7`, `13`, `33`\n",
+ "\n",
+ "quantization: `q2_K`, `q3_K_L`, `q3_K_M`, `q3_K_S`, `q4_0`, `q4_1`, `q4_K_M`, `q4_K_S`, `q5_0`, `q5_1`, `q5_K_M`, `q5_K_S`, `q6_K`, `q8_0`\n",
+ "\n",
+ "##### Here are a few of the supported models:\n",
+ "\n",
+ "| Name | Type | Language | Format | Size (in billions) | Quantization |\n",
+ "|---------------|------------------|----------|---------|--------------------|-----------------------------------------|\n",
+ "| baichuan | Foundation Model | en, zh | ggmlv3 | 7 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |\n",
+ "| llama-2-chat | RLHF Model | en | ggmlv3 | 7, 13, 70 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |\n",
+ "| chatglm | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |\n",
+ "| chatglm2 | SFT Model | en, zh | ggmlv3 | 6 | 'q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0' |\n",
+ "| wizardlm-v1.0 | SFT Model | en | ggmlv3 | 7, 13, 33 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |\n",
+ "| wizardlm-v1.1 | SFT Model | en | ggmlv3 | 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |\n",
+ "| vicuna-v1.3 | SFT Model | en | ggmlv3 | 7, 13 | 'q2_K', 'q3_K_L', ... , 'q6_K', 'q8_0' |\n",
+ "\n",
+ "\n",
+ "In order to achieve satisfactory results, it is recommended to use models above 13 billion in size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fd1d259c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# If Xinference can not be imported, you may need to restart jupyter notebook\n",
+ "from llama_index import (\n",
+ " ListIndex,\n",
+ " TreeIndex,\n",
+ " VectorStoreIndex,\n",
+ " KeywordTableIndex,\n",
+ " KnowledgeGraphIndex,\n",
+ " SimpleDirectoryReader,\n",
+ " ServiceContext,\n",
+ ")\n",
+ "from llama_index.llms import Xinference\n",
+ "from xinference.client import RESTfulClient\n",
+ "from IPython.display import Markdown, display"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b48c6d7a-7a38-440b-8ecb-f43f9050ee54",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define a client to send commands to xinference\n",
+ "client = RESTfulClient(f\"http://localhost:{port}\")\n",
+ "\n",
+ "# Download and Launch a model, this may take a while the first time\n",
+ "model_uid = client.launch_model(\n",
+ " model_name=\"llama-2-chat\",\n",
+ " model_size_in_billions=7,\n",
+ " model_format=\"ggmlv3\",\n",
+ " quantization=\"q2_K\",\n",
+ " n_ctx=4096,\n",
+ ")\n",
+ "\n",
+ "llm = Xinference(endpoint=f\"http://localhost:{port}\", model_uid=model_uid)\n",
+ "service_context = ServiceContext.from_defaults(llm=llm)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "094a02b7",
+ "metadata": {},
+ "source": [
+ "## 🕺 Index the Data and Start Chatting! (3/3)\n",
+ "\n",
+ "#### In this step, simply run the following code blocks\n",
+ "\n",
+ "#### Also, feel free to change the index that is used for different experiences\n",
+ "\n",
+ "#### A list of all available indexes can be found in Llama Index's [official Docs](https://gpt-index.readthedocs.io/en/latest/core_modules/data_modules/index/modules.html)\n",
+ "\n",
+ "Here are some available indexes that are imported:\n",
+ "\n",
+ "`ListIndex`, `TreeIndex`, `VetorStoreIndex`, `KeywordTableIndex`, `KnowledgeGraphIndex`\n",
+ "\n",
+ "The following code uses `VetorStoreIndex`. To change index, simply replace its name with another index"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "708b323e-d314-4b83-864b-22a1ead60de9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create index from the data\n",
+ "documents = SimpleDirectoryReader(\"../data/paul_graham\").load_data()\n",
+ "\n",
+ "# change index name in the following line\n",
+ "index = VectorStoreIndex.from_documents(\n",
+ " documents=documents, service_context=service_context\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2c2de13b-133f-404e-9661-2acafcdf2573",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "# ask a question and display the answer\n",
+ "query_engine = index.as_query_engine()\n",
+ "\n",
+ "question = \"What did the author do after his time at Y Combinator?\"\n",
+ "\n",
+ "response = query_engine.query(question)\n",
+ "display(Markdown(f\"{response}\"))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py
index e1cca2bb28..3dce9f036c 100644
--- a/llama_index/llms/__init__.py
+++ b/llama_index/llms/__init__.py
@@ -19,6 +19,7 @@
from llama_index.llms.palm import PaLM
from llama_index.llms.predibase import PredibaseLLM
from llama_index.llms.replicate import Replicate
+from llama_index.llms.xinference import Xinference
__all__ = [
"OpenAI",
@@ -40,4 +41,5 @@
"CompletionResponseGen",
"CompletionResponseAsyncGen",
"LLMMetadata",
+ "Xinference",
]
diff --git a/llama_index/llms/xinference.py b/llama_index/llms/xinference.py
new file mode 100644
index 0000000000..e15d6255a4
--- /dev/null
+++ b/llama_index/llms/xinference.py
@@ -0,0 +1,186 @@
+from typing import Any, Dict, Optional, Sequence
+
+from llama_index.callbacks import CallbackManager
+from llama_index.constants import DEFAULT_NUM_OUTPUTS
+from llama_index.llms.base import (
+ ChatMessage,
+ ChatResponse,
+ ChatResponseGen,
+ CompletionResponse,
+ CompletionResponseGen,
+ LLMMetadata,
+ MessageRole,
+ llm_chat_callback,
+ llm_completion_callback,
+)
+from llama_index.llms.custom import CustomLLM
+from llama_index.llms.xinference_utils import (
+ xinference_message_to_history,
+ xinference_modelname_to_contextsize,
+)
+
+# an approximation of the ratio between llama and GPT2 tokens
+TOKEN_RATIO = 2.5
+
+
+class Xinference(CustomLLM):
+ def __init__(
+ self,
+ model_uid: str,
+ endpoint: str,
+ temperature: float = 1.0,
+ callback_manager: Optional[CallbackManager] = None,
+ ) -> None:
+ self.temperature = temperature
+ self.model_uid = model_uid
+ self.endpoint = endpoint
+ self.callback_manager = callback_manager or CallbackManager([])
+
+ self._model_description = None
+ self._context_window = None
+ self._generator = None
+ self._client = None
+ self._model = None
+ self.load()
+
+ def load(self) -> None:
+ try:
+ from xinference.client import RESTfulClient
+ except ImportError:
+ raise ImportError(
+ "Could not import Xinference library."
+ 'Please install Xinference with `pip install "xinference[all]"`'
+ )
+
+ self._client = RESTfulClient(self.endpoint)
+
+ try:
+ assert isinstance(self._client, RESTfulClient)
+ except AssertionError:
+ raise RuntimeError(
+ "Could not create RESTfulClient instance."
+ "Please make sure Xinference endpoint is running at the correct port."
+ )
+
+ self._generator = self._client.get_model(self.model_uid)
+ self._model_description = self._client.list_models()[self.model_uid]
+
+ try:
+ assert self._generator is not None
+ assert self._model_description is not None
+ except AssertionError:
+ raise RuntimeError(
+ "Could not get model from endpoint."
+ "Please make sure Xinference endpoint is running at the correct port."
+ )
+
+ self._model = self._model_description["model_name"]
+ self._context_window = xinference_modelname_to_contextsize(self._model)
+
+ @property
+ def metadata(self) -> LLMMetadata:
+ """LLM metadata."""
+ assert isinstance(self._context_window, int)
+ return LLMMetadata(
+ context_window=int(self._context_window // TOKEN_RATIO),
+ num_output=DEFAULT_NUM_OUTPUTS,
+ model_name=self._model,
+ )
+
+ @property
+ def _model_kwargs(self) -> Dict[str, Any]:
+ assert self._context_window is not None
+ base_kwargs = {
+ "temperature": self.temperature,
+ "max_length": self._context_window,
+ }
+ model_kwargs = {
+ **base_kwargs,
+ **self._model_description,
+ }
+ return model_kwargs
+
+ def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
+ return {"prompt": prompt, **self._model_kwargs, **kwargs}
+
+ @llm_chat_callback()
+ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+ assert self._generator is not None
+ prompt = messages[-1].content if len(messages) > 0 else ""
+ history = [xinference_message_to_history(message) for message in messages[:-1]]
+ response_text = self._generator.chat(
+ prompt=prompt,
+ chat_history=history,
+ generate_config={"stream": False, "temperature": self.temperature},
+ )["choices"][0]["message"]["content"]
+ response = ChatResponse(
+ message=ChatMessage(
+ role=MessageRole.ASSISTANT,
+ content=response_text,
+ ),
+ delta=None,
+ )
+ return response
+
+ @llm_chat_callback()
+ def stream_chat(
+ self, messages: Sequence[ChatMessage], **kwargs: Any
+ ) -> ChatResponseGen:
+ assert self._generator is not None
+ prompt = messages[-1].content if len(messages) > 0 else ""
+ history = [xinference_message_to_history(message) for message in messages[:-1]]
+ response_iter = self._generator.chat(
+ prompt=prompt,
+ chat_history=history,
+ generate_config={"stream": True, "temperature": self.temperature},
+ )
+
+ def gen() -> None:
+ text = ""
+ for c in response_iter:
+ delta = c["choices"][0]["delta"].get("content", "")
+ text += delta
+ yield ChatResponse(
+ message=ChatMessage(
+ role=MessageRole.ASSISTANT,
+ content=text,
+ ),
+ delta=delta,
+ )
+
+ return gen()
+
+ @llm_completion_callback()
+ def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+ assert self._generator is not None
+ response_text = self._generator.chat(
+ prompt=prompt,
+ chat_history=None,
+ generate_config={"stream": False, "temperature": self.temperature},
+ )["choices"][0]["message"]["content"]
+ response = CompletionResponse(
+ delta=None,
+ text=response_text,
+ )
+ return response
+
+ @llm_completion_callback()
+ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+ assert self._generator is not None
+ response_iter = self._generator.chat(
+ prompt=prompt,
+ chat_history=None,
+ generate_config={"stream": True, "temperature": self.temperature},
+ )
+
+ def gen() -> CompletionResponseGen:
+ text = ""
+ for c in response_iter:
+ delta = c["choices"][0]["delta"].get("content", "")
+ text += delta
+ yield CompletionResponse(
+ delta=delta,
+ text=text,
+ )
+
+ return gen()
diff --git a/llama_index/llms/xinference_utils.py b/llama_index/llms/xinference_utils.py
new file mode 100644
index 0000000000..7c682d5bea
--- /dev/null
+++ b/llama_index/llms/xinference_utils.py
@@ -0,0 +1,38 @@
+from typing import Optional
+from typing_extensions import NotRequired, TypedDict
+
+from llama_index.llms.base import ChatMessage
+
+XINFERENCE_MODEL_SIZES = {
+ "baichuan": 2048,
+ "baichuan-chat": 2048,
+ "wizardlm-v1.0": 2048,
+ "vicuna-v1.3": 2048,
+ "orca": 2048,
+ "chatglm": 2048,
+ "chatglm2": 8192,
+ "llama-2-chat": 4096,
+ "llama-2": 4096,
+}
+
+
+class ChatCompletionMessage(TypedDict):
+ role: str
+ content: Optional[str]
+ user: NotRequired[str]
+
+
+def xinference_message_to_history(message: ChatMessage) -> ChatCompletionMessage:
+ return ChatCompletionMessage(role=message.role, content=message.content)
+
+
+def xinference_modelname_to_contextsize(modelname: str) -> int:
+ context_size = XINFERENCE_MODEL_SIZES.get(modelname, None)
+
+ if context_size is None:
+ raise ValueError(
+ f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
+ "Known models are: " + ", ".join(XINFERENCE_MODEL_SIZES.keys())
+ )
+
+ return context_size
diff --git a/tests/llms/test_xinference.py b/tests/llms/test_xinference.py
new file mode 100644
index 0000000000..d4315ed70a
--- /dev/null
+++ b/tests/llms/test_xinference.py
@@ -0,0 +1,176 @@
+from typing import List, Dict, Any, Union, Iterator, Generator, Mapping, Sequence
+
+import pytest
+from llama_index.llms.base import (
+ ChatMessage,
+ ChatResponse,
+ MessageRole,
+ CompletionResponse,
+)
+from llama_index.llms.xinference import Xinference
+
+mock_chat_history: List[ChatMessage] = [
+ ChatMessage(
+ role=MessageRole.USER,
+ message="mock_chat_history_0",
+ ),
+ ChatMessage(
+ role=MessageRole.ASSISTANT,
+ message="mock_chat_history_1",
+ ),
+ ChatMessage(
+ role=MessageRole.USER,
+ message="mock_chat_history_2",
+ ),
+]
+
+mock_chat: Dict[str, Any] = {
+ "id": "test_id",
+ "object": "chat.completion",
+ "created": 0,
+ "model": "test_model",
+ "choices": [
+ {
+ "index": 0,
+ "message": {"role": "assistant", "content": "test_response"},
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+}
+
+mock_chat_stream: List[Dict[str, Any]] = [
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": 1,
+ "object": "chat.completion.chunk",
+ "choices": [
+ {"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}
+ ],
+ },
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": 1,
+ "object": "chat.completion.chunk",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"content": "test_response_stream"},
+ "finish_reason": None,
+ }
+ ],
+ },
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": 1,
+ "object": "chat.completion.chunk",
+ "choices": [{"index": 0, "delta": {"content": " "}, "finish_reason": "length"}],
+ },
+]
+
+
+def mock_chat_stream_iterator() -> Generator:
+ for i in mock_chat_stream:
+ yield i
+
+
+class MockXinferenceModel:
+ def chat(
+ self,
+ prompt: str,
+ chat_history: List[Mapping[str, Any]],
+ generate_config: Dict[str, Any],
+ ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
+ assert isinstance(prompt, str)
+ if chat_history is not None:
+ for chat_item in chat_history:
+ assert "role" in chat_item
+ assert isinstance(chat_item["role"], str)
+ assert "content" in chat_item
+ assert isinstance(chat_item["content"], str)
+
+ if "stream" in generate_config and generate_config["stream"] is True:
+ return mock_chat_stream_iterator()
+ else:
+ return mock_chat
+
+
+class MockRESTfulClient:
+ def get_model(self) -> MockXinferenceModel:
+ return MockXinferenceModel()
+
+
+class MockXinference(Xinference):
+ def load(self) -> None:
+ self._client = MockRESTfulClient() # type: ignore[assignment]
+
+ assert self._client is not None
+ self._generator = self._client.get_model()
+
+
+def test_init() -> None:
+ dummy = MockXinference(
+ model_uid="uid",
+ endpoint="endpoint",
+ )
+ assert dummy.model_uid == "uid"
+ assert dummy.endpoint == "endpoint"
+ assert isinstance(dummy._client, MockRESTfulClient)
+
+
+@pytest.mark.parametrize("chat_history", [mock_chat_history, tuple(mock_chat_history)])
+def test_chat(chat_history: Sequence[ChatMessage]) -> None:
+ dummy = MockXinference("uid", "endpoint")
+ response = dummy.chat(chat_history)
+ assert isinstance(response, ChatResponse)
+ assert response.delta is None
+ assert response.message.role == MessageRole.ASSISTANT
+ assert response.message.content == "test_response"
+
+
+@pytest.mark.parametrize("chat_history", [mock_chat_history, tuple(mock_chat_history)])
+def test_stream_chat(chat_history: Sequence[ChatMessage]) -> None:
+ dummy = MockXinference("uid", "endpoint")
+ response_gen = dummy.stream_chat(chat_history)
+ total_text = ""
+ for i, res in enumerate(response_gen):
+ assert i < len(mock_chat_stream)
+ assert isinstance(res, ChatResponse)
+ assert isinstance(mock_chat_stream[i]["choices"], List)
+ assert isinstance(mock_chat_stream[i]["choices"][0], Dict)
+ assert isinstance(mock_chat_stream[i]["choices"][0]["delta"], Dict)
+ assert res.delta == mock_chat_stream[i]["choices"][0]["delta"].get(
+ "content", ""
+ )
+ assert res.message.role == MessageRole.ASSISTANT
+
+ total_text += mock_chat_stream[i]["choices"][0]["delta"].get("content", "")
+ assert total_text == res.message.content
+
+
+def test_complete() -> None:
+ messages = "test_input"
+ dummy = MockXinference("uid", "endpoint")
+ response = dummy.complete(messages)
+ assert isinstance(response, CompletionResponse)
+ assert response.delta is None
+ assert response.text == "test_response"
+
+
+def test_stream_complete() -> None:
+ message = "test_input"
+ dummy = MockXinference("uid", "endpoint")
+ response_gen = dummy.stream_complete(message)
+ total_text = ""
+ for i, res in enumerate(response_gen):
+ assert i < len(mock_chat_stream)
+ assert isinstance(res, CompletionResponse)
+ assert res.delta == mock_chat_stream[i]["choices"][0]["delta"].get(
+ "content", ""
+ )
+
+ total_text += mock_chat_stream[i]["choices"][0]["delta"].get("content", "")
+ assert total_text == res.text