Skip to content

Commit f6fb119

Browse files
Restore int64 sampling (vllm-project#35)
1 parent 14d294d commit f6fb119

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vllm/model_executor/sampling_metadata.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from vllm.sampling_params import SamplingParams, SamplingType
99
from vllm.sequence import SequenceData, SequenceGroupMetadata
1010
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
11-
maybe_expand_dim, is_hpu)
11+
maybe_expand_dim)
1212

1313
_SAMPLING_EPS = 1e-5
1414
_SEED_0_REPLACEMENT = 3403598558
@@ -501,19 +501,19 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
501501
sample_indices_t = torch.tensor(
502502
sample_indices,
503503
device="cpu",
504-
dtype=torch.int,
504+
dtype=torch.long,
505505
pin_memory=pin_memory,
506506
)
507507
prompt_tensor = torch.tensor(
508508
prompt_padded_tokens,
509509
device="cpu",
510-
dtype=torch.int,
510+
dtype=torch.long,
511511
pin_memory=pin_memory,
512512
)
513513
output_tensor = torch.tensor(
514514
output_padded_tokens,
515515
device="cpu",
516-
dtype=torch.int,
516+
dtype=torch.long,
517517
pin_memory=pin_memory,
518518
)
519519
# need to transpose and make contiguous to
@@ -522,7 +522,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
522522
sampling_seeds_t = torch.tensor(
523523
sampling_seeds,
524524
device="cpu",
525-
dtype=torch.int,
525+
dtype=torch.long,
526526
pin_memory=pin_memory,
527527
).T.contiguous()
528528

@@ -571,7 +571,7 @@ def _get_sequence_seeds(
571571
else:
572572
generator = random.Random(str((seed, ) + extra_entropy))
573573
randint_fn = generator.randint
574-
lo, hi = torch.iinfo(torch.int).min, torch.iinfo(torch.int).max
574+
lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
575575
# If the user/random sets seed = 0 but request should
576576
# have sampling, we need to change it to something
577577
# else. We use a constant in that case.

0 commit comments

Comments
 (0)