Skip to content
189 changes: 189 additions & 0 deletions examples/paul_graham_essay/XinferenceLocalDeployment.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7096589b-daaf-440a-b89d-b4956f2db4b2",
"metadata": {
"tags": []
},
"source": [
"# Using Xorbits Inference to Deploy Local LLMs - in 3 steps!\n"
]
},
{
"cell_type": "markdown",
"id": "d8cfbe6f-4c50-4c4f-90f9-03bb91201ef5",
"metadata": {},
"source": [
"## <span style=\"font-size: xx-large;;\">🤖 </span> Installing and Running Xorbits Inference (1/3)\n",
"\n",
"#### i. Run `pip install \"xinference[all]\"` in a terminal window\n",
"\n",
"#### ii. After installation is complete, restart this jupyter notebook\n",
"\n",
"#### iii. Run `xinference` in a new terminal window\n",
"\n",
"#### iv. You should see something similar to the following output:\n",
"\n",
"```\n",
"INFO:xinference:Xinference successfully started. Endpoint: http://127.0.0.1:9997\n",
"INFO:xinference.core.service:Worker 127.0.0.1:21561 has been added successfully\n",
"INFO:xinference.deploy.worker:Xinference worker successfully started.\n",
"```\n",
"\n",
"#### v. In the endpoint description, locate the endpoint port number after the colon. In the above case it is `9997`\n",
"\n",
"#### vi. Paste the endpoint port number in the following cell"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1217d40",
"metadata": {},
"outputs": [],
"source": [
"port = 9997 # replace with your endpoint port number"
]
},
{
"cell_type": "markdown",
"id": "c96ea31e",
"metadata": {},
"source": [
"## <span style=\"font-size: xx-large;;\">🚀 </span> Downloading and Launching Local Models (2/3)\n",
"\n",
"#### In this step, simply run the following code blocks\n",
"\n",
"#### Also, feel free to change the model size and quantization for different experiences!\n",
"\n",
"#### A complete list of supported models can be found in Xorbits Inference's [official GitHub page](https://github.com/xorbitsai/inference/blob/main/README.md)\n",
"\n",
"Here are the parameter options for vicuna-v1.3, ranked from the quickest and least space-consuming to the most resource-intensive but high-performing.\n",
"\n",
"model_size_in_billions: 7, 13, 33\n",
"\n",
"quantization: \"q2_K\", \"q3_K_L\", \"q3_K_M\", \"q3_K_S\", \"q4_0\", \"q4_1\", \"q4_K_M\", \"q4_K_S\", \"q5_0\", \"q5_1\", \"q5_K_M\", \"q5_K_S\", \"q6_K\", \"q8_0\"\n",
"\n",
"In order to achieve satisfactory results, it is recommended to use models above 13 billion in size"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7248be01",
"metadata": {},
"outputs": [],
"source": [
"# If Xinference can not be imported, you may need to restart jupyter notebook\n",
"from llama_index import (\n",
" ListIndex,\n",
" TreeIndex,\n",
" VectorStoreIndex,\n",
" KeywordTableIndex,\n",
" KnowledgeGraphIndex,\n",
" SimpleDirectoryReader,\n",
" ServiceContext,\n",
")\n",
"from llama_index.llms import Xinference\n",
"from xinference.client import RESTfulClient\n",
"from IPython.display import Markdown, display"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b48c6d7a-7a38-440b-8ecb-f43f9050ee54",
"metadata": {},
"outputs": [],
"source": [
"# Define a client to send commands to xinference\n",
"client = RESTfulClient(f\"http://localhost:{port}\")\n",
"\n",
"# Download and Launch a model, this may take a while the first time\n",
"model_uid = client.launch_model(\n",
" model_name = \"vicuna-v1.3\",\n",
" model_size_in_billions = 7,\n",
" model_format = \"ggmlv3\",\n",
" quantization = \"q2_K\",\n",
")\n",
"\n",
"llm = Xinference(endpoint = f\"http://localhost:{port}\", model_uid = model_uid)\n",
"service_context = ServiceContext.from_defaults(llm=llm)"
]
},
{
"cell_type": "markdown",
"id": "510ba348",
"metadata": {},
"source": [
"## <span style=\"font-size: xx-large;;\">🕺 </span> Index the Data and Start Chatting! (3/3)\n",
"\n",
"#### In this step, simply run the following code blocks\n",
"\n",
"#### Also, feel free to change the index that is used for different experiences\n",
"\n",
"#### A list of all available indexes can be found in Llama Index's [official Docs](https://gpt-index.readthedocs.io/en/latest/core_modules/data_modules/index/modules.html)\n",
"\n",
"Here are some available indexes that are imported:\n",
"\n",
"`ListIndex`, `TreeIndex`, `VetorStoreIndex`, `KeywordTableIndex`, `KnowledgeGraphIndex`\n",
"\n",
"The following code uses `VetorStoreIndex`. To change index, simply replace its name with another index"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "708b323e-d314-4b83-864b-22a1ead60de9",
"metadata": {},
"outputs": [],
"source": [
"# create index from the data\n",
"documents = SimpleDirectoryReader(\"../paul_graham_essay/data\").load_data()\n",
"\n",
"# change index name in the following line\n",
"index = VectorStoreIndex.from_documents(\n",
" documents=documents, service_context=service_context\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c2de13b-133f-404e-9661-2acafcdf2573",
"metadata": {},
"outputs": [],
"source": [
"# ask a question and display the answer\n",
"query_engine = index.as_query_engine()\n",
"\n",
"question = \"What did the author do after his time at Y Combinator?\"\n",
"\n",
"response = query_engine.query(question)\n",
"display(Markdown(f\"<b>{response}</b>\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions llama_index/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from llama_index.llms.palm import PaLM
from llama_index.llms.predibase import PredibaseLLM
from llama_index.llms.replicate import Replicate
from llama_index.llms.xinference import Xinference

__all__ = [
"OpenAI",
Expand All @@ -40,4 +41,5 @@
"CompletionResponseGen",
"CompletionResponseAsyncGen",
"LLMMetadata",
"Xinference",
]
153 changes: 153 additions & 0 deletions llama_index/llms/xinference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from typing import Any, Dict, Sequence

from llama_index.constants import DEFAULT_NUM_OUTPUTS
from llama_index.llms.base import (
ChatMessage,
ChatResponse,
ChatResponseGen,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
MessageRole,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.xinference_utils import (
xinference_message_to_history,
xinference_modelname_to_contextsize,
)

# an approximation of the ratio between llama and GPT2 tokens
TOKEN_RATIO = 2.5


class Xinference(CustomLLM):
def __init__(
self,
model_uid: str,
endpoint: str,
temperature: float = 1.0,
) -> None:
self.temperature = temperature
self.model_uid = model_uid
self.endpoint = endpoint

self._model_description = None
self._context_window = None
self._generator = None
self._client = None
self._model = None
self.load()

def load(self) -> None:
try:
from xinference.client import RESTfulClient
except ImportError:
raise ImportError(
"Could not import Xinference library."
'Please install Xinference with `pip install "xinference[all]"`'
)

self._client = RESTfulClient(self.endpoint)
self._generator = self._client.get_model(self.model_uid)
self._model_description = self._client.list_models()[self.model_uid]

self._model = self._model_description["model_name"]
self._context_window = xinference_modelname_to_contextsize(self._model)

@property
def metadata(self) -> LLMMetadata:
"""LLM metadata."""
return LLMMetadata(
context_window=int(self._context_window // TOKEN_RATIO),
num_output=DEFAULT_NUM_OUTPUTS,
model_name=self._model,
)

@property
def _model_kwargs(self) -> Dict[str, Any]:
base_kwargs = {
"temperature": self.temperature,
"max_length": self._context_window,
}
model_kwargs = {
**base_kwargs,
**self._model_description,
}
return model_kwargs

def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
return {"prompt": prompt, **self._model_kwargs, **kwargs}

def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
prompt = messages[-1].content if len(messages) > 0 else ""
history = [xinference_message_to_history(message) for message in messages[:-1]]
response_text = self._generator.chat(
prompt=prompt,
chat_history=history,
generate_config={"stream": False, "temperature": self.temperature},
)["choices"][0]["message"]["content"]
response = ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=response_text,
),
delta=None,
)
return response

def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
prompt = messages[-1].content if len(messages) > 0 else ""
history = [xinference_message_to_history(message) for message in messages[:-1]]
response_iter = self._generator.chat(
prompt=prompt,
chat_history=history,
generate_config={"stream": True, "temperature": self.temperature},
)

def gen() -> ChatResponseGen:
text = ""
for c in response_iter:
delta = c["choices"][0]["delta"].get("content", "")
text += delta
yield ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=text,
),
delta=delta,
)

return gen()

def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
response_text = self._generator.chat(
prompt=prompt,
chat_history=None,
generate_config={"stream": False, "temperature": self.temperature},
)["choices"][0]["message"]["content"]
response = CompletionResponse(
delta=None,
text=response_text,
)
return response

def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
response_iter = self._generator.chat(
prompt=prompt,
chat_history=None,
generate_config={"stream": True, "temperature": self.temperature},
)

def gen() -> CompletionResponseGen:
text = ""
for c in response_iter:
delta = c["choices"][0]["delta"].get("content", "")
text += delta
yield CompletionResponse(
delta=delta,
text=text,
)

return gen()
37 changes: 37 additions & 0 deletions llama_index/llms/xinference_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing_extensions import NotRequired, TypedDict

from llama_index.llms.base import ChatMessage

XINFERENCE_MODEL_SIZES = {
"baichuan": 2048,
"baichuan-chat": 2048,
"wizardlm-v1.0": 2048,
"vicuna-v1.3": 2048,
"orca": 2048,
"chatglm": 2048,
"chatglm2": 8192,
"llama-2-chat": 4096,
"llama-2": 4096,
}


class ChatCompletionMessage(TypedDict):
role: str
content: str
user: NotRequired[str]


def xinference_message_to_history(message: ChatMessage) -> ChatCompletionMessage:
return ChatCompletionMessage(role=str(message.role), content=message.content)


def xinference_modelname_to_contextsize(modelname: str) -> int:
context_size = XINFERENCE_MODEL_SIZES.get(modelname, None)

if context_size is None:
raise ValueError(
f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
"Known models are: " + ", ".join(XINFERENCE_MODEL_SIZES.keys())
)

return context_size
Loading