Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.platforms import current_platform


class BaseLogitsProcessor:

Expand Down Expand Up @@ -91,7 +93,14 @@ def __call__(self, input_ids: List[int],
allowed_tokens = allowed_tokens.masked_select(
allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)
scores.add_(mask)
if current_platform.is_hpu():
# Workaround for HPU bug where add_() raise RuntimeError:
# synNodeCreateWithId failed for node: strided_insert
# with synStatus 1 [Invalid argument], hopefully it will
# be fixed in the future releases of the HPU runtime.
scores = scores.add(mask)
else:
scores.add_(mask)
return scores


Expand Down