FEAT: Support Ctransformers #289
Merged
Commits (61, all by RayJi01):
- 77d4954 fix rebase
- 3822252 small edit
- 09d9e8e fix lint
- f218a7d fix model family match
- 99b6117 fix model family match
- c4cf3b9 small edit on ctransformer.py
- 721a540 small edit on ctransformer.py
- 87333ef add model-type list
- 115db12 edit llama cpp
- e291cb1 edit llama cpp
- 9b33537 edit llama cpp
- 2b1159f edit ctransformer cpp
- 9346de1 edit ctransformer cpp
- 2edc687 edit ctransformer cpp
- 8189c23 edit ctransformer cpp
- 153ae6c edit ctransformer cpp
- ff93f98 edit ctransformer cpp
- 86335c4 edit ctransformer util
- ce6bb78 edit ctransformer util
- 3db942b edit ctransformer util
- 72a77f3 edit ctransformer util
- fef3a69 Ctransformer Pipeline is clear and ready to serve
- 2345dae fix rebase
- ae6e5db fix setup issue2.
- 4db685b fix setup issue3
- 5ae9ec7 fix import CI checking
- 6330d95 fix import CI checking 2
- fd6b70a fix import CI checking 3
- 22ce526 fix import CI checking 4
- e1711d5 fix import CI checking 5
- ae04adb add test to c-transformers
- 7641b30 fix test typing issue from importing Ctransforemrs
- 30575ed add pytest-mock to setup.cfg
- 4b508a0 fix ctransformers import issue1
- 98669fd fix ctransformers import by add ctransformers to dev
- 31c4054 fix ctransformers test issue2
- 6472c84 fix ctransformers test issue2
- 5b7d837 refactor toward suggestions.
- ac2799c fix test dependency
- 1c7ae94 update logger error
- 58ca677 update logger error
- 8955272 update logger error
- f909238 update logger error
- e76c576 fix issue in ctransformersutil and delete redundant logger.
- cfb267e refactor toward suggestions.
- 9bca70c fix part of issues from suggestions.
- cb66cd3 fix part of issues from suggestions.
- 5ffa5fe fix test issue by remove mock autoconfig.
- f019ace refactor toward suggestions.
- 4f962f1 remove parameterize in test, only test for q4_0
- 973685b try smaller gpt-2 model for test ctransformer
- e09350d add CTRANSFORME_SUPPORT_MODEL constant (can be expanded)
- 016c318 fix cuda branch error.
- 2524747 fix test cuda error.
- 740f31a add GPU check to make sure only supported model can initialize Cuda.
- 25d3dd4 add GPU check to make sure only supported model can initialize Cuda.
- a9e3006 try adding mpt model.
- ef9a84d remove mpt for this pr.
- 0f05689 refactor toward suggestions.
- 360f7e7 fix lint issue.
- bfecc8a remove prompt style for generate model.
The PR adds the following new module (277 lines):

```python
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import TYPE_CHECKING, Iterator, Optional, Sequence, TypedDict, Union

if TYPE_CHECKING:
    from ctransformers import AutoConfig

from ....types import Completion, CompletionChunk
from ..core import LLM
from ..llm_family import LLMFamilyV1, LLMSpecV1
from .ctransformers_util import generate_stream

logger = logging.getLogger(__name__)

# All models supported by ctransformers, mapped to their ctransformers model type.
# Strictly follow this naming format when adding a new model to a model family.
MODEL_TYPE_FOR_CTRANSFORMERS = {
    "gpt-2": "gpt2",
    "gpt-j": "gptj",
    "gpt4all-j": "gptj",
    "gpt-neox": "gpt_neox",
    "stablelm": "gpt_neox",
    "llama": "llama",
    "llama-2": "llama",
    "mpt": "mpt",
    "dolly-v2": "dolly-v2",
    "replit": "replit",
    "starcoder": "starcoder",
    "starchat": "starcoder",
    "falcon": "falcon",
}

# These two constants are subject to change with future development
# and ctransformers updates.
CTRANSFORMERS_SUPPORTED_MODEL = ["starcoder", "gpt-2", "mpt"]

CTRANSFORMERS_GPU_SUPPORT = ["llama", "llama-2", "mpt", "falcon"]

# Model size (in billions of parameters) -> number of layers to offload to GPU.
SIZE_TO_GPU_LAYERS = {
    3: 26,
    7: 32,
    13: 40,
    30: 60,
    65: 80,
}


class CtransformersModelConfig(TypedDict, total=False):
    n_ctx: int
    n_gpu_layers: int


class CtransformersGenerateConfig(TypedDict, total=False):
    max_tokens: Optional[int]
    top_k: Optional[int]
    top_p: Optional[float]
    temperature: Optional[float]
    repetition_penalty: Optional[float]
    last_n_tokens: Optional[int]
    seed: Optional[int]
    batch_size: Optional[int]
    threads: Optional[int]
    stop: Optional[Sequence[str]]
    stream: Optional[bool]
    reset: Optional[bool]


def _has_cuda_device():
    from xorbits._mars.resource import cuda_count

    return cuda_count() > 0


class CtransformersModel(LLM):
    def __init__(
        self,
        model_uid: str,
        model_family: "LLMFamilyV1",
        model_spec: "LLMSpecV1",
        quantization: str,
        model_path: str,
        ctransformers_model_config: Optional[CtransformersModelConfig],
    ):
        super().__init__(model_uid, model_family, model_spec, quantization, model_path)

        self._model_type = None
        # Pick the GPU-layer preset whose size is closest to this model's size.
        closest_size = min(
            SIZE_TO_GPU_LAYERS.keys(),
            key=lambda x: abs(x - model_spec.model_size_in_billions),
        )

        self._model_family = model_family
        self._model_uid = model_uid
        self._llm = None

        self._gpu_layers = SIZE_TO_GPU_LAYERS[closest_size]
        self._ctransformer_model_config = self._sanitize_model_config(
            model_path, ctransformers_model_config
        )

    def _sanitize_model_config(
        self, model_path, ctransformers_model_config: Optional[CtransformersModelConfig]
    ) -> "AutoConfig":
        try:
            from ctransformers import AutoConfig, Config
        except ImportError:
            error_message = (
                "Failed to import 'AutoConfig' and 'Config' from module 'ctransformers'"
            )

            installation_guide = [
                "Please make sure 'ctransformers' is installed. ",
                "You can find the installation command in the repository: "
                "https://github.com/marella/ctransformers",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        # If a customized model config is provided, apply it.
        ctransformers_model_config_returned = Config()
        potential_gpu_layers = None
        if ctransformers_model_config:
            potential_context_length = ctransformers_model_config.pop("n_ctx", None)
            potential_gpu_layers = ctransformers_model_config.pop("n_gpu_layers", None)

            # Only override the library defaults when values are explicitly given.
            if potential_context_length is not None:
                ctransformers_model_config_returned.context_length = (
                    potential_context_length
                )
            if potential_gpu_layers is not None:
                ctransformers_model_config_returned.gpu_layers = potential_gpu_layers

        # If the user does not set gpu layers, choose a value based on the system,
        # when applicable.
        if potential_gpu_layers is None:
            if self._model_family.model_name not in CTRANSFORMERS_GPU_SUPPORT:
                ctransformers_model_config_returned.gpu_layers = -1
            elif self._is_darwin_and_apple_silicon():
                ctransformers_model_config_returned.gpu_layers = 1
            elif _has_cuda_device():
                ctransformers_model_config_returned.gpu_layers = self._gpu_layers

        return AutoConfig(ctransformers_model_config_returned)

    def _sanitize_generate_config(
        self,
        ctransformers_generate_config: Optional[CtransformersGenerateConfig],
    ) -> CtransformersGenerateConfig:
        # If no generate config is provided, start from an empty one.
        if ctransformers_generate_config is None:
            ctransformers_generate_config = CtransformersGenerateConfig()

        # On our system, threads must be set to 4; all other parameters,
        # if unspecified, fall back to their defaults at generation time.
        ctransformers_generate_config.setdefault("threads", 4)

        return ctransformers_generate_config

    def load(self):
        try:
            from ctransformers import AutoModelForCausalLM
        except ImportError:
            error_message = "Failed to import module 'ctransformers'"

            installation_guide = [
                "Please make sure 'ctransformers' is installed. ",
                "You can find the installation command in the repository: "
                "https://github.com/marella/ctransformers",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        model_path = os.path.join(
            self.model_path,
            self.model_spec.model_file_name_template.format(
                quantization=self.quantization
            ),
        )

        self._model_type = self._determine_model_type()
        self._llm = AutoModelForCausalLM.from_pretrained(
            model_path_or_repo_id=model_path,
            model_type=self._model_type,
            config=self._ctransformer_model_config,
        )

    @classmethod
    def match(cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1) -> bool:
        if llm_spec.model_format != "ggmlv3":
            return False
        if llm_family.model_name not in CTRANSFORMERS_SUPPORTED_MODEL:
            return False
        if "generate" not in llm_family.model_ability:
            return False
        return True

    def _determine_model_type(self):
        if self._model_family.model_name not in MODEL_TYPE_FOR_CTRANSFORMERS:
            raise ValueError(
                f"The model {self._model_family.model_name} is not supported; "
                f"please check the model name."
            )
        return MODEL_TYPE_FOR_CTRANSFORMERS[self._model_family.model_name]

    def generate(
        self, prompt: str, generate_config_raw: CtransformersGenerateConfig
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        def generator_wrapper(
            _prompt: str,
            _max_new_tokens: Union[int, None],
            _generate_config: CtransformersGenerateConfig,
        ) -> Iterator[CompletionChunk]:
            assert self._model_uid is not None
            for _completion_chunk, _ in generate_stream(
                model=self._model_uid,
                model_ref=self._llm,
                prompt=_prompt,
                max_new_tokens=_max_new_tokens,
                **_generate_config,
            ):
                yield _completion_chunk

        generate_config = self._sanitize_generate_config(generate_config_raw)
        max_new_tokens = generate_config.pop("max_tokens", None)

        logger.debug(
            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
        )

        stream_or_not = generate_config.get("stream", False)
        if stream_or_not:
            return generator_wrapper(
                _prompt=prompt,
                _max_new_tokens=max_new_tokens,
                _generate_config=generate_config,
            )
        else:
            assert self.model_uid is not None
            # Consume the stream and build a Completion from the final chunk
            # and its usage statistics.
            completion_chunk = None
            completion_usage = None
            for completion_chunk, completion_usage in generate_stream(
                model=self.model_uid,
                model_ref=self._llm,
                prompt=prompt,
                max_new_tokens=max_new_tokens,
                **generate_config,
            ):
                pass

            assert completion_chunk is not None
            assert completion_usage is not None

            completion = Completion(
                id=completion_chunk["id"],
                object=completion_chunk["object"],
                created=completion_chunk["created"],
                model=completion_chunk["model"],
                choices=completion_chunk["choices"],
                usage=completion_usage,
            )

            logger.debug(
                "Generated, completion: %s, generate config: %s",
                completion,
                generate_config,
            )

            return completion
```
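For orientation, here is a minimal sketch of how `CtransformersModel` might be driven directly. It is illustrative only: the uid, family/spec objects, and paths are hypothetical placeholders (in practice the xinference runtime constructs the model from a registered `LLMFamilyV1`/`LLMSpecV1`), and the `["choices"][0]["text"]` access assumes the chunk/completion layout implied by `generate()` above.

```python
# Hedged usage sketch for CtransformersModel; every concrete value below is a
# hypothetical placeholder, not something shipped with this PR.

# Assume `family` is an LLMFamilyV1 with model_name="gpt-2" (one of
# CTRANSFORMERS_SUPPORTED_MODEL) and model_ability containing "generate",
# and `spec` is an LLMSpecV1 with model_format="ggmlv3" whose
# model_file_name_template resolves the quantized file name.
model = CtransformersModel(
    model_uid="demo-gpt-2",                      # hypothetical uid
    model_family=family,
    model_spec=spec,
    quantization="q4_0",
    model_path="/path/to/ggml-model-dir",        # directory holding the .bin file
    ctransformers_model_config={"n_ctx": 1024},  # optional overrides
)
model.load()

# Non-streaming: generate() drains generate_stream and returns one Completion.
completion = model.generate("def fibonacci(n):", {"max_tokens": 64, "stream": False})
print(completion["choices"][0]["text"])

# Streaming: generate() returns an iterator of CompletionChunk instead.
for chunk in model.generate("def fibonacci(n):", {"max_tokens": 64, "stream": True}):
    print(chunk["choices"][0]["text"], end="", flush=True)
```

Note that `stream` alone selects between the two return types, and that when CUDA is available the offloaded layer count comes from `SIZE_TO_GPU_LAYERS`, keyed by the size closest to `model_spec.model_size_in_billions` (e.g. a 13B model gets 40 layers).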