+import os
 from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -50,6 +52,9 @@ def from_broadcasted_tensor_dict(
 
 class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
+    # Neuron has an upper limit on top_k.
+    _MAX_NEURON_SAMPLING_TOP_K = 256
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -76,6 +81,34 @@ def __init__(
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
+        # When NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
+        # on-device sampling is turned off.
+        self._on_device_sampling_disabled = int(
+            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
+
+        # Neuron needs to update its sampling parameters when request IDs
+        # change across batches. This variable stores the previous batch's
+        # request IDs so we can determine whether an update is needed.
+        self._previous_batch_request_ids: List[str] = []
+
+        if not self._on_device_sampling_disabled:
+            logger.warning(
+                "On-device sampling is turned on in Neuron by default; only "
+                "top_k, top_p, and temperature are currently supported "
+                "sampling parameters. To turn off on-device sampling, set "
+                "the environment variable "
+                "NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
+            self.model_config.neuron_sampling_params = GenerationConfig(
+                max_length=self.scheduler_config.max_model_len,
+                do_sample=True,
+                per_batch_line=True,
+                top_k=[self._MAX_NEURON_SAMPLING_TOP_K]
+                * self.scheduler_config.max_num_seqs,
+                top_p=[1.0] * self.scheduler_config.max_num_seqs,
+                temperature=[1.0] * self.scheduler_config.max_num_seqs,
+                dynamic=True,
+                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
+
     def load_model(self) -> None:
         if find_spec("transformers_neuronx") is not None:
             self.model = get_neuron_model(
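A hardware-free sketch of the two __init__ additions above; the names and the
max_num_seqs value are illustrative stand-ins, not part of the patch. The flag
is parsed with int(), so any non-zero value of
NEURON_ON_DEVICE_SAMPLING_DISABLED turns on-device sampling off, and the
default GenerationConfig carries one top_k/top_p/temperature entry per
schedulable sequence.

    import os

    # Any non-zero value disables on-device sampling; unset or "0" leaves it on.
    on_device_sampling_disabled = int(
        os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

    _MAX_NEURON_SAMPLING_TOP_K = 256
    max_num_seqs = 2  # assumed scheduler setting, for illustration only

    # Shape of the per-sequence defaults handed to GenerationConfig above.
    default_top_k = [_MAX_NEURON_SAMPLING_TOP_K] * max_num_seqs  # [256, 256]
    default_top_p = [1.0] * max_num_seqs                         # [1.0, 1.0]
    default_temperature = [1.0] * max_num_seqs                   # [1.0, 1.0]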
@@ -215,7 +248,7 @@ def prepare_model_input(
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -227,12 +260,49 @@ def execute_model(
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
+        if not self._on_device_sampling_disabled:
+            # If the request IDs changed in the current iteration, update
+            # the on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(sampling_metadata)
+                self._previous_batch_request_ids = current_batch_request_ids
+
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
                                    sampling_metadata=sampling_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs)
 
+    def _update_neuron_sampling_params(self,
+                                       sampling_metadata: SamplingMetadata):
+        # Update the Neuron sampling parameters (GenerationConfig in Neuron).
+        current_sampling_params = self.model_config.neuron_sampling_params
+        assert current_sampling_params is not None, (
+            "Failed to update sampling_params; "
+            f"current sampling params is {current_sampling_params}")
+
+        top_k = current_sampling_params.top_k
+        top_p = current_sampling_params.top_p
+        temperature = current_sampling_params.temperature
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
+            top_k[index] = self._convert_to_neuron_top_k(
+                sequence_group_to_sample.sampling_params.top_k)
+            top_p[index] = sequence_group_to_sample.sampling_params.top_p
+            temperature[index] = \
+                sequence_group_to_sample.sampling_params.temperature
+
+        self.model.model.update_generation_config(current_sampling_params)
+
+    def _convert_to_neuron_top_k(self, top_k: int) -> int:
+        # vLLM uses top_k < 0 to mean "no top-k"; Neuron expresses that as
+        # its maximum supported value, which also caps larger user values.
+        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+            return self._MAX_NEURON_SAMPLING_TOP_K
+        return top_k
+
     @torch.inference_mode()
     def execute_model(
         self,
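Two behaviors from this hunk can be exercised without Neuron hardware: the
top_k clamp (vLLM uses -1 for "no top-k", which Neuron expresses as its
maximum supported value) and the refresh-only-when-the-batch-changes pattern
around _update_neuron_sampling_params. A self-contained sketch with stand-in
names:

    _MAX_NEURON_SAMPLING_TOP_K = 256

    def convert_to_neuron_top_k(top_k: int) -> int:
        # Mirrors _convert_to_neuron_top_k above.
        if top_k < 0 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
            return _MAX_NEURON_SAMPLING_TOP_K
        return top_k

    assert convert_to_neuron_top_k(-1) == 256    # "no top-k" maps to the cap
    assert convert_to_neuron_top_k(1000) == 256  # clamped to the Neuron cap
    assert convert_to_neuron_top_k(40) == 40     # in range, passed through

    # Sampling parameters are pushed to the device only when the batch's
    # request IDs change, mirroring prepare_model_input above.
    updates, previous_ids = [], []
    for batch in (["req-0", "req-1"], ["req-0", "req-1"], ["req-2"]):
        if batch != previous_ids:
            updates.append(list(batch))  # stands in for the device update
            previous_ids = batch
    assert len(updates) == 2  # the repeated batch triggers no update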
@@ -253,9 +323,13 @@ def execute_model(
                 device=self.device),
         )
 
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
-                                           model_input.sampling_metadata)
+        # Compute the logits only if on-device sampling is turned off, since
+        # on-device sampling outputs token IDs directly.
+        if self._on_device_sampling_disabled:
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+        else:
+            logits = hidden_states
 
         # Sample the next token.
         output = self.model.sample(
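Because the flag is read once in __init__, on-device sampling must be disabled
in the environment before the engine is constructed in order to take the
compute_logits path above. A hedged sketch: the vllm.LLM entry point exists,
but the model name is illustrative and device="neuron" assumes the engine
forwards a device argument.

    import os

    # Must be set before NeuronModelRunner.__init__ reads it.
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "1"

    from vllm import LLM

    llm = LLM(model="openlm-research/open_llama_3b",  # illustrative model
              device="neuron")  # assumed engine argument
    outputs = llm.generate(["Hello from Neuron!"])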