@@ -68,9 +68,9 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
6868}
6969
7070void sampling_from_logits (TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
71- bool deterministic,
72- Optional<TensorView> maybe_seed_arr, uint64_t seed_val ,
73- Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
71+ bool deterministic, Optional<TensorView> maybe_seed_arr,
72+ uint64_t seed_val, Optional<TensorView> maybe_offset_arr ,
73+ uint64_t offset_val) {
7474 CHECK_INPUT (logits);
7575 CHECK_DIM (2 , logits); // logits: (batch_size, vocab_size)
7676 CHECK_MAYBE_INPUT_TYPES (maybe_indices, dl_int32, dl_int64);
@@ -89,20 +89,20 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
8989 maybe_indices.has_value () ? static_cast <IdType*>(maybe_indices.value ().data_ptr ())
9090 : nullptr ,
9191 batch_size, vocab_size, deterministic,
92- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
92+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
93+ : nullptr ,
9394 seed_val,
94- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
95- offset_val ,
96- stream);
95+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
96+ : nullptr ,
97+ offset_val, stream);
9798 TVM_FFI_ICHECK (status == cudaSuccess)
9899 << " SamplingFromLogits failed with error code " << cudaGetErrorString (status);
99100 return true ;
100101 });
101102}
102103
103104void sampling_from_probs (TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
104- bool deterministic,
105- Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
105+ bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
106106 Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
107107 CHECK_INPUT (probs);
108108 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
@@ -122,11 +122,12 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
122122 maybe_indices.has_value () ? static_cast <IdType*>(maybe_indices.value ().data_ptr ())
123123 : nullptr ,
124124 batch_size, vocab_size, deterministic,
125- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
125+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
126+ : nullptr ,
126127 seed_val,
127- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
128- offset_val ,
129- stream);
128+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
129+ : nullptr ,
130+ offset_val, stream);
130131 TVM_FFI_ICHECK (status == cudaSuccess)
131132 << " SamplingFromProbs failed with error code " << cudaGetErrorString (status);
132133 return true ;
@@ -136,9 +137,9 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
136137void top_p_sampling_from_probs (TensorView probs, TensorView output,
137138 Optional<TensorView> maybe_indices,
138139 Optional<TensorView> maybe_top_p_arr, double top_p_val,
139- bool deterministic,
140- Optional<TensorView> maybe_seed_arr, uint64_t seed_val ,
141- Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
140+ bool deterministic, Optional<TensorView> maybe_seed_arr,
141+ uint64_t seed_val, Optional<TensorView> maybe_offset_arr ,
142+ uint64_t offset_val) {
142143 CHECK_INPUT (probs);
143144 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
144145 CHECK_MAYBE_INPUT_TYPES (maybe_indices, dl_int32, dl_int64);
@@ -160,11 +161,12 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
160161 : nullptr ,
161162 has_top_p_arr ? static_cast <float *>(maybe_top_p_arr.value ().data_ptr ()) : nullptr ,
162163 batch_size, top_p_val, vocab_size, deterministic,
163- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
164+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
165+ : nullptr ,
164166 seed_val,
165- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
166- offset_val ,
167- stream);
167+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
168+ : nullptr ,
169+ offset_val, stream);
168170 TVM_FFI_ICHECK (status == cudaSuccess)
169171 << " TopPSamplingFromProbs failed with error code " << cudaGetErrorString (status);
170172 return true ;
@@ -174,9 +176,9 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
174176void top_k_sampling_from_probs (TensorView probs, TensorView output,
175177 Optional<TensorView> maybe_indices,
176178 Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
177- bool deterministic,
178- Optional<TensorView> maybe_seed_arr, uint64_t seed_val ,
179- Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
179+ bool deterministic, Optional<TensorView> maybe_seed_arr,
180+ uint64_t seed_val, Optional<TensorView> maybe_offset_arr ,
181+ uint64_t offset_val) {
180182 CHECK_INPUT (probs);
181183 CHECK_INPUT (output);
182184 CHECK_DEVICE (output, probs);
@@ -201,11 +203,12 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
201203 : nullptr ,
202204 has_top_k_arr ? static_cast <float *>(maybe_top_k_arr.value ().data_ptr ()) : nullptr ,
203205 batch_size, top_k_val, vocab_size, deterministic,
204- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
206+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
207+ : nullptr ,
205208 seed_val,
206- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
207- offset_val ,
208- stream);
209+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
210+ : nullptr ,
211+ offset_val, stream);
209212 TVM_FFI_ICHECK (status == cudaSuccess)
210213 << " TopKSamplingFromProbs failed with error code " << cudaGetErrorString (status);
211214 return true ;
@@ -215,9 +218,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
215218void min_p_sampling_from_probs (TensorView probs, TensorView output,
216219 Optional<TensorView> maybe_indices,
217220 Optional<TensorView> maybe_min_p_arr, double min_p_val,
218- bool deterministic,
219- Optional<TensorView> maybe_seed_arr, uint64_t seed_val ,
220- Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
221+ bool deterministic, Optional<TensorView> maybe_seed_arr,
222+ uint64_t seed_val, Optional<TensorView> maybe_offset_arr ,
223+ uint64_t offset_val) {
221224 CHECK_INPUT (probs);
222225 CHECK_INPUT (output);
223226 CHECK_DEVICE (output, probs);
@@ -243,11 +246,12 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
243246 maybe_indices.has_value () ? static_cast <IdType*>(maybe_indices.value ().data_ptr ())
244247 : nullptr ,
245248 batch_size, min_p_val, vocab_size, deterministic,
246- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
249+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
250+ : nullptr ,
247251 seed_val,
248- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
249- offset_val ,
250- stream);
252+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
253+ : nullptr ,
254+ offset_val, stream);
251255 TVM_FFI_ICHECK (status == cudaSuccess)
252256 << " MinPSamplingFromProb failed with error code " << cudaGetErrorString (status);
253257 return true ;
@@ -258,9 +262,9 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
258262 Optional<TensorView> maybe_indices,
259263 Optional<TensorView> maybe_top_k_arr, double top_k_val,
260264 Optional<TensorView> maybe_top_p_arr, double top_p_val,
261- bool deterministic,
262- Optional<TensorView> maybe_seed_arr, uint64_t seed_val ,
263- Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
265+ bool deterministic, Optional<TensorView> maybe_seed_arr,
266+ uint64_t seed_val, Optional<TensorView> maybe_offset_arr ,
267+ uint64_t offset_val) {
264268 CHECK_INPUT (probs);
265269 CHECK_INPUT (output);
266270 CHECK_DEVICE (output, probs);
@@ -289,11 +293,12 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
289293 maybe_indices.has_value () ? static_cast <IdType*>(maybe_indices.value ().data_ptr ())
290294 : nullptr ,
291295 batch_size, top_k_val, top_p_val, vocab_size, deterministic,
292- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
296+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
297+ : nullptr ,
293298 seed_val,
294- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
295- offset_val ,
296- stream);
299+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
300+ : nullptr ,
301+ offset_val, stream);
297302 TVM_FFI_ICHECK (status == cudaSuccess)
298303 << " TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString (status);
299304 return true ;
@@ -334,11 +339,12 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i
334339 static_cast <int *>(output_accepted_token_num.data_ptr ()),
335340 static_cast <int *>(output_emitted_draft_token_num.data_ptr ()), batch_size,
336341 num_speculate_tokens, vocab_size, deterministic,
337- maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ()) : nullptr ,
342+ maybe_seed_arr.has_value () ? static_cast <uint64_t *>(maybe_seed_arr.value ().data_ptr ())
343+ : nullptr ,
338344 seed_val,
339- maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ()) : nullptr ,
340- offset_val ,
341- stream);
345+ maybe_offset_arr.has_value () ? static_cast <uint64_t *>(maybe_offset_arr.value ().data_ptr ())
346+ : nullptr ,
347+ offset_val, stream);
342348
343349 TVM_FFI_ICHECK (status == cudaSuccess)
344350 << " ChainSpeculativeSampling failed with error code " << cudaGetErrorString (status);
0 commit comments