33 changes: 27 additions & 6 deletions vllm_spyre/platform.py
@@ -175,6 +175,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm_spyre.v1.core.scheduler."\
"StaticBatchingSpyreScheduler")

# Hardcode some things for granite-3.3-8b-instruct
if cls.is_granite_3_8b(vllm_config.model_config):
cls.configure_granite_3_8b(vllm_config)

# To disable any paged attention ops in the base scheduler, we:
# - Set the block size (in tokens) to the maximum sequence length
# so that the scheduler thinks an entire sequence will fit in
@@ -188,15 +192,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config.max_num_batched_tokens = (
model_config.max_model_len * scheduler_config.max_num_seqs)
else:
# TODO: ideally, this would be user-configurable from CLI/engine
# args instead of with the internal env var, but that requires a
# way to detect if value set by vllm or by the user
if (chunk_len := os.getenv("VLLM_DT_CHUNK_LEN")) is None:
os.environ["VLLM_DT_CHUNK_LEN"] = \
str(scheduler_config.max_num_batched_tokens)
else:
try:
chunk_len_int = int(chunk_len)
except (ValueError, TypeError) as e:
raise ValueError(
"VLLM_DT_CHUNK_LEN must be an integer") from e
scheduler_config.max_num_batched_tokens = chunk_len_int
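# Note: a user-exported VLLM_DT_CHUNK_LEN therefore takes precedence
# over --max-num-batched-tokens; the scheduler config is overwritten
# to match it here.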

assert scheduler_config.max_num_batched_tokens % \
cls._block_size == 0, ("`max_num_batched_tokens` must"
f" be divisible by the block size ({cls._block_size}) "
"to enable chunked prefill. It was set to "
f"`{scheduler_config.max_num_batched_tokens}`. Please "
"set `--max-num-batched-tokens` to a number that satisfy "
"this constraint.")
os.environ["VLLM_DT_CHUNK_LEN"] = \
str(scheduler_config.max_num_batched_tokens)
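# The env var is (re)written after validation so it always carries the
# final value. E.g., if the block size is 64, 4096 passes the check
# above (4096 % 64 == 0) while 4000 would be rejected.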

logger.info(
"Overriding configurations based on warmup shapes. "
Expand All @@ -221,10 +237,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
max(vllm_config.scheduler_config.max_num_seqs, 2))

# Hardcode some things for granite-3.3-8b-instruct
if cls.is_granite_3_8b(vllm_config.model_config):
cls.configure_granite_3_8b(vllm_config)

if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"):
# max product of batch size x tkv supported by the Spyre compiler
default_max_batch_tkv_limit = \
@@ -599,6 +611,15 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
vllm_config.cache_config.num_gpu_blocks_override,
blocks_override)

# hard-coded value for max_num_batched_tokens with chunked prefill
if envs_spyre.VLLM_SPYRE_USE_CHUNKED_PREFILL \
and envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn" \
and os.getenv("VLLM_DT_CHUNK_LEN") is None:
logger.info("Model granite-3.3-8b-instruct and tensor " \
"parallel size 4 with chunked prefill detected. Setting " \
"--max-num-batched-tokens 4096")
vllm_config.scheduler_config.max_num_batched_tokens = 4096
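# Exporting VLLM_DT_CHUNK_LEN explicitly skips this 4096 override, per
# the os.getenv check above.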

@classmethod
def is_granite_3_8b(cls, model_config: ModelConfig):
"""Returns true if we have a model that looks like
Expand Down
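
For reference, a minimal standalone sketch of the chunk-length resolution added above (not part of the diff; block_size=64 is an assumed default, the real value comes from cls._block_size):

import os

def resolve_chunk_len(max_num_batched_tokens: int, block_size: int = 64) -> int:
    # Mirrors the hunk above: an explicit VLLM_DT_CHUNK_LEN wins,
    # otherwise it is derived from --max-num-batched-tokens.
    chunk_len = os.getenv("VLLM_DT_CHUNK_LEN")
    if chunk_len is None:
        os.environ["VLLM_DT_CHUNK_LEN"] = str(max_num_batched_tokens)
    else:
        try:
            max_num_batched_tokens = int(chunk_len)
        except ValueError as e:
            raise ValueError("VLLM_DT_CHUNK_LEN must be an integer") from e
    # Chunked prefill requires the token budget to be block-aligned.
    assert max_num_batched_tokens % block_size == 0, (
        f"max_num_batched_tokens ({max_num_batched_tokens}) must be "
        f"divisible by the block size ({block_size})")
    os.environ["VLLM_DT_CHUNK_LEN"] = str(max_num_batched_tokens)
    return max_num_batched_tokens

# e.g. resolve_chunk_len(2048) returns 2048 and sets VLLM_DT_CHUNK_LEN=2048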