Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions backends/base_model_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
import asyncio
import pathlib
from loguru import logger
from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
)
from typing import Any, AsyncIterator
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
Expand All @@ -21,7 +15,7 @@ class BaseModelContainer(abc.ABC):

# Exposed model information
model_dir: pathlib.Path = pathlib.Path("models")
prompt_template: Optional[PromptTemplate] = None
prompt_template: PromptTemplate | None = None

# HF Model instance
hf_model: HFModel
Expand All @@ -34,7 +28,7 @@ class BaseModelContainer(abc.ABC):
# The bool is a master switch for accepting requests
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
active_job_ids: Dict[str, Any] = {}
active_job_ids: dict[str, Any] = {}
loaded: bool = False
load_lock: asyncio.Lock
load_condition: asyncio.Condition
Expand Down Expand Up @@ -98,7 +92,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
pass

@abc.abstractmethod
def encode_tokens(self, text: str, **kwargs) -> List[int]:
def encode_tokens(self, text: str, **kwargs) -> list[int]:
"""
Encodes a string of text into a list of token IDs.

Expand All @@ -113,7 +107,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]:
pass

@abc.abstractmethod
def decode_tokens(self, ids: List[int], **kwargs) -> str:
def decode_tokens(self, ids: list[int], **kwargs) -> str:
"""
Decodes a list of token IDs back into a string.

Expand All @@ -128,7 +122,7 @@ def decode_tokens(self, ids: List[int], **kwargs) -> str:
pass

@abc.abstractmethod
def get_special_tokens(self) -> Dict[str, Any]:
def get_special_tokens(self) -> dict[str, Any]:
"""
Gets special tokens used by the model/tokenizer.

Expand Down Expand Up @@ -164,7 +158,7 @@ async def wait_for_jobs(self, skip_wait: bool = False):
# Optional methods
async def load_loras(
self, lora_directory: pathlib.Path, **kwargs
) -> Dict[str, List[str]]:
) -> dict[str, list[str]]:
"""
Loads LoRA adapters. Base implementation does nothing or raises error.

Expand All @@ -184,7 +178,7 @@ async def load_loras(
],
}

def get_loras(self) -> List[Any]:
def get_loras(self) -> list[Any]:
"""
Gets the currently loaded LoRA adapters. Base implementation returns empty list.

Expand All @@ -200,9 +194,9 @@ async def generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> Dict[str, Any]:
abort_event: asyncio.Event | None = None,
mm_embeddings: MultimodalEmbeddingWrapper | None = None,
) -> dict[str, Any]:
"""
Generates a complete response for a given prompt and parameters.

Expand All @@ -225,9 +219,9 @@ async def stream_generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> AsyncIterator[Dict[str, Any]]:
abort_event: asyncio.Event | None = None,
mm_embeddings: MultimodalEmbeddingWrapper | None = None,
) -> AsyncIterator[dict[str, Any]]:
"""
Generates a response iteratively (streaming) for a given prompt.

Expand Down
7 changes: 3 additions & 4 deletions backends/exllamav2/grammar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import traceback
import typing
from functools import lru_cache
from typing import List
from typing import Any

import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
Expand All @@ -16,7 +15,7 @@
class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""

filters: List[ExLlamaV2Filter]
filters: list[ExLlamaV2Filter]

def __init__(self):
self.filters = []
Expand Down Expand Up @@ -123,7 +122,7 @@ def __init__(self, nonterminal: str, kbnf_string: str):
self.kbnf_string = kbnf_string

# Return the entire input string as the extracted string
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
def extract(self, input_str: str) -> tuple[str, Any] | None:
return "", input_str

@property
Expand Down
88 changes: 44 additions & 44 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from itertools import zip_longest
from loguru import logger
from typing import Dict, List, Optional

from backends.base_model_container import BaseModelContainer
from backends.exllamav2.grammar import (
Expand Down Expand Up @@ -58,45 +57,45 @@ class ExllamaV2Container(BaseModelContainer):
# Model directories
model_dir: pathlib.Path = pathlib.Path("models")
draft_model_dir: pathlib.Path = pathlib.Path("models")
prompt_template: Optional[PromptTemplate] = None
prompt_template: PromptTemplate | None = None

# HF model instance
hf_model: HFModel

# Exl2 vars
config: Optional[ExLlamaV2Config] = None
model: Optional[ExLlamaV2] = None
cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None
config: ExLlamaV2Config | None = None
model: ExLlamaV2 | None = None
cache: ExLlamaV2Cache | None = None
tokenizer: ExLlamaV2Tokenizer | None = None
generator: ExLlamaV2DynamicGeneratorAsync | None = None
prompt_template: PromptTemplate | None = None
paged: bool = True

# Draft model vars
use_draft_model: bool = False
draft_config: Optional[ExLlamaV2Config] = None
draft_model: Optional[ExLlamaV2] = None
draft_cache: Optional[ExLlamaV2Cache] = None
draft_config: ExLlamaV2Config | None = None
draft_model: ExLlamaV2 | None = None
draft_cache: ExLlamaV2Cache | None = None

# Internal config vars
cache_size: int = None
cache_mode: str = "FP16"
draft_cache_mode: str = "FP16"
max_batch_size: Optional[int] = None
max_batch_size: int | None = None

# GPU split vars
gpu_split: List[float] = []
draft_gpu_split: List[float] = []
gpu_split: list[float] = []
draft_gpu_split: list[float] = []
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 * 1024**2]
autosplit_reserve: list[float] = [96 * 1024**2]
use_tp: bool = False

# Vision vars
use_vision: bool = False
vision_model: Optional[ExLlamaV2VisionTower] = None
vision_model: ExLlamaV2VisionTower | None = None

# Load synchronization
active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {}
active_job_ids: dict[str, ExLlamaV2DynamicJobAsync | None] = {}
loaded: bool = False
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
Expand Down Expand Up @@ -130,7 +129,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()

# Set vision state and error if vision isn't supported on the current model
# Set vision state and error if vision isn't supported on the current model
self.use_vision = unwrap(kwargs.get("vision"), False)
if self.use_vision and not self.config.vision_model_type:
raise ValueError(
Expand Down Expand Up @@ -185,12 +184,12 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
gpu_split = unwrap(kwargs.get("gpu_split"), [])
gpu_device_list = list(range(0, gpu_count))

# Set GPU split options
# Set GPU split options
if gpu_count == 1:
self.gpu_split_auto = False
logger.info("Disabling GPU split because one GPU is in use.")
else:
# Set tensor parallel
# Set tensor parallel
if use_tp:
self.use_tp = True

Expand Down Expand Up @@ -233,7 +232,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
# Hardcode max output length to 16
self.config.max_output_len = 16

# Set max batch size to the config override
# Set max batch size to the config override
self.max_batch_size = unwrap(kwargs.get("max_batch_size"))

# Check whether the user's configuration supports flash/paged attention
Expand Down Expand Up @@ -262,7 +261,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
# Grab user-set max seq len
user_max_seq_len = kwargs.get("max_seq_len")

# Set k/v cache size
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
user_cache_size = coalesce(kwargs.get("cache_size"), user_max_seq_len, 4096)
Expand All @@ -272,8 +271,9 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
self.config.max_seq_len = unwrap(
user_max_seq_len, min(hf_model.hf_config.max_position_embeddings, 4096)
)
self.cache_size = self.config.max_seq_len

# Set the rope scale
# Set the rope scale
self.config.scale_pos_emb = unwrap(
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
Expand Down Expand Up @@ -322,15 +322,15 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
self.config.max_input_len = chunk_size
self.config.max_attention_size = chunk_size**2

# Set user-configured draft model values
# Set user-configured draft model values
if self.use_draft_model:
self.draft_config.max_seq_len = self.config.max_seq_len

self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
)

# Set draft rope alpha. Follows same behavior as model rope alpha.
# Set draft rope alpha. Follows same behavior as model rope alpha.
# Use the max_position_embeddings of the model
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
if draft_rope_alpha == "auto":
Expand All @@ -341,7 +341,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
else:
self.draft_config.scale_alpha_value = draft_rope_alpha

# Set draft cache mode
# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")

# Catch exllamav3 draft_cache_mode
Expand Down Expand Up @@ -750,9 +750,9 @@ async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))

loras_to_load: List[ExLlamaV2Lora] = []
success: List[str] = []
failure: List[str] = []
loras_to_load: list[ExLlamaV2Lora] = []
success: list[str] = []
failure: list[str] = []

for lora in loras:
lora_name = lora.get("name")
Expand Down Expand Up @@ -836,7 +836,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
await self.generator.close()
self.generator = None

# Set all model state variables to False
# Set all model state variables to False
self.loaded = False

gc.collect()
Expand Down Expand Up @@ -869,7 +869,7 @@ def encode_tokens(self, text: str, **kwargs):
.tolist()
)

def decode_tokens(self, ids: List[int], **kwargs):
def decode_tokens(self, ids: list[int], **kwargs):
"""Wrapper to decode tokens from a list of IDs"""

ids = torch.tensor([ids])
Expand Down Expand Up @@ -908,8 +908,8 @@ async def generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
abort_event: asyncio.Event | None = None,
mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
"""Generate a response to a prompt."""
generations = []
Expand Down Expand Up @@ -969,8 +969,8 @@ async def stream_generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
abort_event: asyncio.Event | None = None,
mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
try:
# Wait for load lock to be freed before processing
Expand Down Expand Up @@ -1136,15 +1136,15 @@ def assign_gen_params(
"top_k, top_p, and typical to 1.0, 1, 0, and 0."
)

# Set banned tokens
# Set banned tokens
if params.banned_tokens:
gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)

# Set allowed tokens
# Set allowed tokens
if params.allowed_tokens:
gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)

# Set logit bias
# Set logit bias
if params.logit_bias:
# Create a vocab tensor if it doesn't exist for token biasing
if gen_settings.token_bias is None:
Expand Down Expand Up @@ -1242,8 +1242,8 @@ async def generate_gen(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
abort_event: asyncio.Event | None = None,
mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
"""
Create generator function for prompt completion.
Expand All @@ -1261,7 +1261,7 @@ async def generate_gen(
grammar_handler,
)

# Set banned strings
# Set banned strings
banned_strings = params.banned_strings
if banned_strings and len(grammar_handler.filters) > 0:
logger.warning(
Expand All @@ -1271,7 +1271,7 @@ async def generate_gen(

banned_strings = []

# Set CFG scale and negative prompt
# Set CFG scale and negative prompt
cfg_scale = params.cfg_scale
negative_prompt = None
if cfg_scale not in [None, 1.0]:
Expand Down Expand Up @@ -1301,15 +1301,15 @@ async def generate_gen(
stop_conditions = params.stop
ban_eos_token = params.ban_eos_token

# Set add_bos_token for generation
# Set add_bos_token for generation
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

# Fetch EOS tokens from the HF model if they exist
eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]

# Ban the EOS token if specified. If not, append to stop conditions
# as well.
# Set this below logging to avoid polluting the stop strings array
# Set this below logging to avoid polluting the stop strings array
if ban_eos_token:
gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
else:
Expand Down
Loading
Loading