Merged
8 changes: 8 additions & 0 deletions vllm/config.py
@@ -3124,6 +3124,14 @@ def _get_and_verify_max_len(
     # derived length from the HF model config.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
+        if current_platform.is_tpu():
+            logger.warning(
+                "--max-model-len is not specified; "
+                "currently using the model's default length %s, "
+                "which might be too large. "
+                "Please set --max-model-len based on your "
+                "request input and output lengths to avoid "
+                "unnecessary performance degradation.", max_model_len)
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
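For reference, a minimal self-contained sketch of the fallback-and-warn behaviour this hunk adds; `resolve_max_model_len` and the `is_tpu` flag are illustrative stand-ins, not vLLM APIs.

# Sketch (not vLLM code): fall back to the HF-derived length and warn TPU users.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

def resolve_max_model_len(max_model_len, derived_max_model_len, is_tpu):
    """If the user gave no --max-model-len, use the derived default and warn on TPU."""
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
        if is_tpu:
            logger.warning(
                "--max-model-len is not specified; using the model's default "
                "length %s, which might be too large.", max_model_len)
    return max_model_len

print(resolve_max_model_len(None, 131072, is_tpu=True))   # warns, returns 131072
print(resolve_max_model_len(4096, 131072, is_tpu=True))   # returns 4096, no warning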
32 changes: 29 additions & 3 deletions vllm/engine/arg_utils.py
@@ -1432,8 +1432,8 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
         # as the platform that vLLM is running on (e.g. the case of scaling
         # vLLM with Ray) and has no GPUs. In this case we use the default
         # values for non-H100/H200 GPUs.
+        from vllm.platforms import current_platform
         try:
-            from vllm.platforms import current_platform
             device_memory = current_platform.get_device_total_memory()
         except Exception:
             # This is only used to set default_max_num_batched_tokens
@@ -1454,11 +1454,37 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
             }
             default_max_num_seqs = 256

+        # tpu specific default values.
+        if current_platform.is_tpu():
+            default_max_num_batched_tokens_tpu = {
+                UsageContext.LLM_CLASS: {
+                    'V6E': 2048,
+                    'V5E': 1024,
+                    'V5P': 512,
+                },
+                UsageContext.OPENAI_API_SERVER: {
+                    'V6E': 1024,
+                    'V5E': 512,
+                    'V5P': 256,
+                }
Comment on lines +1465 to +1469

Member: I do still worry about these smaller sizes for multimodal models, since this is smaller than the usual image size in tokens, which will cause errors for the user. Maybe we can expand this for multimodal in a separate PR.

Contributor Author: Thanks. For multimodal, is it required that we put all image tokens within one batch?

Member: Yes, we need to be able to put the largest single mm item in a single token batch. If we split the mm item, then we need slices that force recompilation on TPU.

Contributor Author: Got it, thanks for the explanation.
+            }
+
         use_context_value = usage_context.value if usage_context else None
         if (self.max_num_batched_tokens is None
                 and usage_context in default_max_num_batched_tokens):
-            self.max_num_batched_tokens = default_max_num_batched_tokens[
-                usage_context]
+            if current_platform.is_tpu():
+                chip_name = current_platform.get_device_name()
+                if chip_name in default_max_num_batched_tokens_tpu[
+                        usage_context]:
+                    self.max_num_batched_tokens = \
+                        default_max_num_batched_tokens_tpu[
+                            usage_context][chip_name]
+                else:
+                    self.max_num_batched_tokens = \
+                        default_max_num_batched_tokens[usage_context]
+            else:
+                self.max_num_batched_tokens = default_max_num_batched_tokens[
+                    usage_context]
         logger.debug(
             "Setting max_num_batched_tokens to %d for %s usage context.",
             self.max_num_batched_tokens, use_context_value)
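To make the selection order concrete, here is a hedged standalone sketch of the chip-aware lookup added above. `pick_default_batched_tokens` is an illustrative helper, not a vLLM function, and the generic `DEFAULTS` numbers are placeholders (the real generic table sits in the elided part of the method); only the TPU table mirrors the diff.

# Standalone sketch (not vLLM code) of the chip-aware default selection.
from enum import Enum

class UsageContext(Enum):
    LLM_CLASS = "LLM_CLASS"
    OPENAI_API_SERVER = "OPENAI_API_SERVER"

DEFAULTS = {  # placeholder generic values, not vLLM's actual defaults
    UsageContext.LLM_CLASS: 8192,
    UsageContext.OPENAI_API_SERVER: 2048,
}
TPU_DEFAULTS = {  # mirrors the values introduced in this PR
    UsageContext.LLM_CLASS: {'V6E': 2048, 'V5E': 1024, 'V5P': 512},
    UsageContext.OPENAI_API_SERVER: {'V6E': 1024, 'V5E': 512, 'V5P': 256},
}

def pick_default_batched_tokens(usage_context, is_tpu, chip_name=None):
    """Prefer a per-chip TPU default; otherwise fall back to the generic table."""
    if is_tpu and chip_name in TPU_DEFAULTS[usage_context]:
        return TPU_DEFAULTS[usage_context][chip_name]
    return DEFAULTS[usage_context]

# A known chip gets the tuned TPU default; an unknown chip falls back.
print(pick_default_batched_tokens(UsageContext.OPENAI_API_SERVER, True, 'V6E'))
print(pick_default_batched_tokens(UsageContext.OPENAI_API_SERVER, True, 'V4'))

As the review thread notes, any such default still has to be large enough to hold the biggest single multimodal item in one token batch, since splitting an mm item forces recompilation on TPU.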
4 changes: 3 additions & 1 deletion vllm/platforms/tpu.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Optional, Union

 import torch
+from tpu_info import device

 import vllm.envs as envs
 from vllm.inputs import ProcessorInputs, PromptType
@@ -54,7 +55,8 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        return "tpu"
+        chip_type, _ = device.get_local_chips()
+        return f"TPU {chip_type.name}"

     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
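A hedged sketch of querying the local chip via the tpu_info package, mirroring the call used in get_device_name() above; it assumes a TPU host, and treating the second return value of device.get_local_chips() as the local chip count is an assumption (the diff simply ignores it).

# Sketch, assuming a TPU host with the `tpu_info` package installed.
from tpu_info import device

chip_type, chip_count = device.get_local_chips()
print(f"TPU {chip_type.name}")  # e.g. a string like "TPU V6E"; exact form depends on tpu_info's chip enum
print(chip_count)               # assumed here to be the number of local chips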