feat: Add Xorbits Inference for local deployment #7151
Merged: logan-markewich merged 13 commits into run-llama:main from Bojun-Feng:feat/support_xinference on Aug 9, 2023
Commits (13):
3ea8fca add xinference class (Bojun-Feng)
385de76 add utils (Bojun-Feng)
799979a fix lint formatting (Bojun-Feng)
f54d94d add notebook demo (Bojun-Feng)
b68a162 fix lint formatting (Bojun-Feng)
96ab188 add model list to demo (Bojun-Feng)
8a6e076 fix lint in demo notebook (Bojun-Feng)
4d7be05 Merge branch 'main' into feat/support_xinference (logan-markewich)
75e0f7e add callbacks (logan-markewich)
108204a add to docs (logan-markewich)
5bcf00d changelog (logan-markewich)
109ab27 linting (logan-markewich)
0b473d7 fix tests (logan-markewich)
Files changed (showing changes from the first 4 commits):
examples/paul_graham_essay/XinferenceLocalDeployment.ipynb (189 additions, 0 deletions)
````json
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7096589b-daaf-440a-b89d-b4956f2db4b2",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Using Xorbits Inference to Deploy Local LLMs - in 3 steps!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8cfbe6f-4c50-4c4f-90f9-03bb91201ef5",
   "metadata": {},
   "source": [
    "## <span style=\"font-size: xx-large;\">🤖 </span> Installing and Running Xorbits Inference (1/3)\n",
    "\n",
    "#### i. Run `pip install \"xinference[all]\"` in a terminal window\n",
    "\n",
    "#### ii. After installation is complete, restart this Jupyter notebook\n",
    "\n",
    "#### iii. Run `xinference` in a new terminal window\n",
    "\n",
    "#### iv. You should see something similar to the following output:\n",
    "\n",
    "```\n",
    "INFO:xinference:Xinference successfully started. Endpoint: http://127.0.0.1:9997\n",
    "INFO:xinference.core.service:Worker 127.0.0.1:21561 has been added successfully\n",
    "INFO:xinference.deploy.worker:Xinference worker successfully started.\n",
    "```\n",
    "\n",
    "#### v. In the endpoint description, locate the endpoint port number after the colon. In the example above, it is `9997`\n",
    "\n",
    "#### vi. Paste the endpoint port number into the following cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1217d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "port = 9997  # replace with your endpoint port number"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c96ea31e",
   "metadata": {},
   "source": [
    "## <span style=\"font-size: xx-large;\">🚀 </span> Downloading and Launching Local Models (2/3)\n",
    "\n",
    "#### In this step, simply run the following code blocks\n",
    "\n",
    "#### Also, feel free to change the model size and quantization for different experiences!\n",
    "\n",
    "#### A complete list of supported models can be found on Xorbits Inference's [official GitHub page](https://github.com/xorbitsai/inference/blob/main/README.md)\n",
    "\n",
    "Here are the parameter options for vicuna-v1.3, ranked from the fastest and least space-consuming to the most resource-intensive but best-performing.\n",
    "\n",
    "model_size_in_billions: 7, 13, 33\n",
    "\n",
    "quantization: \"q2_K\", \"q3_K_L\", \"q3_K_M\", \"q3_K_S\", \"q4_0\", \"q4_1\", \"q4_K_M\", \"q4_K_S\", \"q5_0\", \"q5_1\", \"q5_K_M\", \"q5_K_S\", \"q6_K\", \"q8_0\"\n",
    "\n",
    "For satisfactory results, models of 13 billion parameters or larger are recommended"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7248be01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If Xinference cannot be imported, you may need to restart the Jupyter notebook\n",
    "from llama_index import (\n",
    "    ListIndex,\n",
    "    TreeIndex,\n",
    "    VectorStoreIndex,\n",
    "    KeywordTableIndex,\n",
    "    KnowledgeGraphIndex,\n",
    "    SimpleDirectoryReader,\n",
    "    ServiceContext,\n",
    ")\n",
    "from llama_index.llms import Xinference\n",
    "from xinference.client import RESTfulClient\n",
    "from IPython.display import Markdown, display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b48c6d7a-7a38-440b-8ecb-f43f9050ee54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a client to send commands to xinference\n",
    "client = RESTfulClient(f\"http://localhost:{port}\")\n",
    "\n",
    "# Download and launch a model; this may take a while the first time\n",
    "model_uid = client.launch_model(\n",
    "    model_name=\"vicuna-v1.3\",\n",
    "    model_size_in_billions=7,\n",
    "    model_format=\"ggmlv3\",\n",
    "    quantization=\"q2_K\",\n",
    ")\n",
    "\n",
    "llm = Xinference(endpoint=f\"http://localhost:{port}\", model_uid=model_uid)\n",
    "service_context = ServiceContext.from_defaults(llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "510ba348",
   "metadata": {},
   "source": [
    "## <span style=\"font-size: xx-large;\">🕺 </span> Index the Data and Start Chatting! (3/3)\n",
    "\n",
    "#### In this step, simply run the following code blocks\n",
    "\n",
    "#### Also, feel free to change the index that is used for different experiences\n",
    "\n",
    "#### A list of all available indexes can be found in LlamaIndex's [official docs](https://gpt-index.readthedocs.io/en/latest/core_modules/data_modules/index/modules.html)\n",
    "\n",
    "Here are some available indexes that are imported:\n",
    "\n",
    "`ListIndex`, `TreeIndex`, `VectorStoreIndex`, `KeywordTableIndex`, `KnowledgeGraphIndex`\n",
    "\n",
    "The following code uses `VectorStoreIndex`. To change the index, simply replace its name with another index class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "708b323e-d314-4b83-864b-22a1ead60de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create index from the data\n",
    "documents = SimpleDirectoryReader(\"../paul_graham_essay/data\").load_data()\n",
    "\n",
    "# change the index name in the following line\n",
    "index = VectorStoreIndex.from_documents(\n",
    "    documents=documents, service_context=service_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c2de13b-133f-404e-9661-2acafcdf2573",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ask a question and display the answer\n",
    "query_engine = index.as_query_engine()\n",
    "\n",
    "question = \"What did the author do after his time at Y Combinator?\"\n",
    "\n",
    "response = query_engine.query(question)\n",
    "display(Markdown(f\"<b>{response}</b>\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
````
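Beyond the query-engine flow in the notebook, the `Xinference` class added by this PR can also be driven directly. A minimal sketch, assuming a running xinference server on port 9997; the `model_uid` here is a placeholder (in practice it is returned by `client.launch_model(...)` as shown above):

```python
from llama_index.llms import Xinference
from llama_index.llms.base import ChatMessage, MessageRole

# Hypothetical uid; use the value returned by client.launch_model(...)
model_uid = "my-model-uid"

# Connects to the server on construction, so xinference must be running
llm = Xinference(endpoint="http://localhost:9997", model_uid=model_uid)

# One-shot completion: returns a CompletionResponse with the full text
print(llm.complete("What is the tallest mountain on Earth?").text)

# Streaming completion: each response carries the text so far, plus the
# newly generated chunk in `delta`
for partial in llm.stream_complete("Name three uses of a paperclip."):
    print(partial.delta, end="", flush=True)

# Multi-turn chat: all but the last message are sent as chat history
messages = [
    ChatMessage(role=MessageRole.USER, content="Hi, who are you?"),
    ChatMessage(role=MessageRole.ASSISTANT, content="I am an assistant."),
    ChatMessage(role=MessageRole.USER, content="What can you help with?"),
]
print(llm.chat(messages).message.content)
```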
llama_index/llms/xinference.py (153 additions, 0 deletions)
```python
from typing import Any, Dict, Sequence

from llama_index.constants import DEFAULT_NUM_OUTPUTS
from llama_index.llms.base import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.xinference_utils import (
    xinference_message_to_history,
    xinference_modelname_to_contextsize,
)

# an approximation of the ratio between llama and GPT2 tokens
TOKEN_RATIO = 2.5


class Xinference(CustomLLM):
    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        temperature: float = 1.0,
    ) -> None:
        self.temperature = temperature
        self.model_uid = model_uid
        self.endpoint = endpoint

        self._model_description = None
        self._context_window = None
        self._generator = None
        self._client = None
        self._model = None
        self.load()

    def load(self) -> None:
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            )

        self._client = RESTfulClient(self.endpoint)
        self._generator = self._client.get_model(self.model_uid)
        self._model_description = self._client.list_models()[self.model_uid]

        self._model = self._model_description["model_name"]
        self._context_window = xinference_modelname_to_contextsize(self._model)

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=int(self._context_window // TOKEN_RATIO),
            num_output=DEFAULT_NUM_OUTPUTS,
            model_name=self._model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "temperature": self.temperature,
            "max_length": self._context_window,
        }
        model_kwargs = {
            **base_kwargs,
            **self._model_description,
        }
        return model_kwargs

    def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        return {"prompt": prompt, **self._model_kwargs, **kwargs}

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = messages[-1].content if len(messages) > 0 else ""
        history = [xinference_message_to_history(message) for message in messages[:-1]]
        response_text = self._generator.chat(
            prompt=prompt,
            chat_history=history,
            generate_config={"stream": False, "temperature": self.temperature},
        )["choices"][0]["message"]["content"]
        response = ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response_text,
            ),
            delta=None,
        )
        return response

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = messages[-1].content if len(messages) > 0 else ""
        history = [xinference_message_to_history(message) for message in messages[:-1]]
        response_iter = self._generator.chat(
            prompt=prompt,
            chat_history=history,
            generate_config={"stream": True, "temperature": self.temperature},
        )

        def gen() -> ChatResponseGen:
            text = ""
            for c in response_iter:
                delta = c["choices"][0]["delta"].get("content", "")
                text += delta
                yield ChatResponse(
                    message=ChatMessage(
                        role=MessageRole.ASSISTANT,
                        content=text,
                    ),
                    delta=delta,
                )

        return gen()

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        response_text = self._generator.chat(
            prompt=prompt,
            chat_history=None,
            generate_config={"stream": False, "temperature": self.temperature},
        )["choices"][0]["message"]["content"]
        response = CompletionResponse(
            delta=None,
            text=response_text,
        )
        return response

    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        response_iter = self._generator.chat(
            prompt=prompt,
            chat_history=None,
            generate_config={"stream": True, "temperature": self.temperature},
        )

        def gen() -> CompletionResponseGen:
            text = ""
            for c in response_iter:
                delta = c["choices"][0]["delta"].get("content", "")
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                )

        return gen()
```
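Two details of the class above are easy to miss: the context window reported by `metadata` is the model's native size divided by `TOKEN_RATIO`, and streamed responses are cumulative. A small illustrative sketch in plain Python (not part of the PR):

```python
# Context-window arithmetic from the `metadata` property: llama-family
# tokenizers emit roughly 2.5x more tokens than GPT2's, so the native
# window is scaled down before being reported to llama_index.
TOKEN_RATIO = 2.5
native_context = 2048  # vicuna-v1.3, per XINFERENCE_MODEL_SIZES below
print(int(native_context // TOKEN_RATIO))  # -> 819

# Streaming contract of gen() in stream_chat()/stream_complete():
# `content`/`text` accumulate the full output, `delta` is only the new chunk.
text = ""
for delta in ["Hel", "lo", ", world"]:
    text += delta
    print(f"delta={delta!r} text={text!r}")
```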
llama_index/llms/xinference_utils.py (37 additions, 0 deletions)
```python
from typing_extensions import NotRequired, TypedDict

from llama_index.llms.base import ChatMessage

XINFERENCE_MODEL_SIZES = {
    "baichuan": 2048,
    "baichuan-chat": 2048,
    "wizardlm-v1.0": 2048,
    "vicuna-v1.3": 2048,
    "orca": 2048,
    "chatglm": 2048,
    "chatglm2": 8192,
    "llama-2-chat": 4096,
    "llama-2": 4096,
}


class ChatCompletionMessage(TypedDict):
    role: str
    content: str
    user: NotRequired[str]


def xinference_message_to_history(message: ChatMessage) -> ChatCompletionMessage:
    return ChatCompletionMessage(role=str(message.role), content=message.content)


def xinference_modelname_to_contextsize(modelname: str) -> int:
    context_size = XINFERENCE_MODEL_SIZES.get(modelname, None)

    if context_size is None:
        raise ValueError(
            f"Unknown model: {modelname}. Please provide a valid Xinference model name. "
            "Known models are: " + ", ".join(XINFERENCE_MODEL_SIZES.keys())
        )

    return context_size
```
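A short usage sketch for these helpers (hypothetical example, not from the PR; the module path matches the import in xinference.py above):

```python
from llama_index.llms.base import ChatMessage, MessageRole
from llama_index.llms.xinference_utils import (
    xinference_message_to_history,
    xinference_modelname_to_contextsize,
)

# Context-size lookup against the table above
print(xinference_modelname_to_contextsize("vicuna-v1.3"))  # 2048
print(xinference_modelname_to_contextsize("chatglm2"))  # 8192

# Convert a llama_index ChatMessage into the TypedDict shape that the
# Xinference chat endpoint expects (role coerced to a plain string)
message = ChatMessage(role=MessageRole.USER, content="Hello!")
print(xinference_message_to_history(message))

# Unknown model names raise a ValueError listing the known models
try:
    xinference_modelname_to_contextsize("not-a-real-model")
except ValueError as err:
    print(err)
```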