Skip to content

Commit 4ae57e6

Browse files
yzh119IzzyPutterman
authored andcommitted
upd
1 parent 0b546c7 commit 4ae57e6

File tree

4 files changed

+231
-133
lines changed

4 files changed

+231
-133
lines changed

csrc/flashinfer_sampling_binding.cu

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,42 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
2121
Optional<TensorView> maybe_temperature_arr, double temperature_val, bool enable_pdl);
2222

2323
void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
24-
bool deterministic,
25-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
24+
bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
2625
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);
2726

2827
void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
29-
bool deterministic,
30-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
31-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);
28+
bool deterministic, Optional<TensorView> maybe_seed_arr,
29+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
30+
uint64_t offset_val);
3231

3332
void top_p_sampling_from_probs(TensorView probs, TensorView output,
3433
Optional<TensorView> maybe_indices,
3534
Optional<TensorView> maybe_top_p_arr, double top_p_val,
36-
bool deterministic,
37-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
38-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);
35+
bool deterministic, Optional<TensorView> maybe_seed_arr,
36+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
37+
uint64_t offset_val);
3938

4039
void top_k_sampling_from_probs(TensorView probs, TensorView output,
4140
Optional<TensorView> maybe_indices,
4241
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
43-
bool deterministic,
44-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
45-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);
42+
bool deterministic, Optional<TensorView> maybe_seed_arr,
43+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
44+
uint64_t offset_val);
4645

4746
void min_p_sampling_from_probs(TensorView probs, TensorView output,
4847
Optional<TensorView> maybe_indices,
4948
Optional<TensorView> maybe_min_p_arr, double min_p_val,
50-
bool deterministic,
51-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
52-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);
49+
bool deterministic, Optional<TensorView> maybe_seed_arr,
50+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
51+
uint64_t offset_val);
5352

5453
void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
5554
Optional<TensorView> maybe_indices,
5655
Optional<TensorView> maybe_top_k_arr, double top_k_val,
5756
Optional<TensorView> maybe_top_p_arr, double top_p_val,
58-
bool deterministic,
59-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
60-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);
57+
bool deterministic, Optional<TensorView> maybe_seed_arr,
58+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
59+
uint64_t offset_val);
6160

6261
void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
6362
Optional<TensorView> maybe_top_p_arr, double top_p_val);

csrc/sampling.cu

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
6868
}
6969

7070
void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
71-
bool deterministic,
72-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
73-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
71+
bool deterministic, Optional<TensorView> maybe_seed_arr,
72+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
73+
uint64_t offset_val) {
7474
CHECK_INPUT(logits);
7575
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
7676
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
@@ -89,20 +89,20 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
8989
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
9090
: nullptr,
9191
batch_size, vocab_size, deterministic,
92-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
92+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
93+
: nullptr,
9394
seed_val,
94-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
95-
offset_val,
96-
stream);
95+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
96+
: nullptr,
97+
offset_val, stream);
9798
TVM_FFI_ICHECK(status == cudaSuccess)
9899
<< "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
99100
return true;
100101
});
101102
}
102103

103104
void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
104-
bool deterministic,
105-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
105+
bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
106106
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
107107
CHECK_INPUT(probs);
108108
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
@@ -122,11 +122,12 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
122122
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
123123
: nullptr,
124124
batch_size, vocab_size, deterministic,
125-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
125+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
126+
: nullptr,
126127
seed_val,
127-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
128-
offset_val,
129-
stream);
128+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
129+
: nullptr,
130+
offset_val, stream);
130131
TVM_FFI_ICHECK(status == cudaSuccess)
131132
<< "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
132133
return true;
@@ -136,9 +137,9 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
136137
void top_p_sampling_from_probs(TensorView probs, TensorView output,
137138
Optional<TensorView> maybe_indices,
138139
Optional<TensorView> maybe_top_p_arr, double top_p_val,
139-
bool deterministic,
140-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
141-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
140+
bool deterministic, Optional<TensorView> maybe_seed_arr,
141+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
142+
uint64_t offset_val) {
142143
CHECK_INPUT(probs);
143144
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
144145
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
@@ -160,11 +161,12 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
160161
: nullptr,
161162
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
162163
batch_size, top_p_val, vocab_size, deterministic,
163-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
164+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
165+
: nullptr,
164166
seed_val,
165-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
166-
offset_val,
167-
stream);
167+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
168+
: nullptr,
169+
offset_val, stream);
168170
TVM_FFI_ICHECK(status == cudaSuccess)
169171
<< "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
170172
return true;
@@ -174,9 +176,9 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
174176
void top_k_sampling_from_probs(TensorView probs, TensorView output,
175177
Optional<TensorView> maybe_indices,
176178
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
177-
bool deterministic,
178-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
179-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
179+
bool deterministic, Optional<TensorView> maybe_seed_arr,
180+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
181+
uint64_t offset_val) {
180182
CHECK_INPUT(probs);
181183
CHECK_INPUT(output);
182184
CHECK_DEVICE(output, probs);
@@ -201,11 +203,12 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
201203
: nullptr,
202204
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
203205
batch_size, top_k_val, vocab_size, deterministic,
204-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
206+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
207+
: nullptr,
205208
seed_val,
206-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
207-
offset_val,
208-
stream);
209+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
210+
: nullptr,
211+
offset_val, stream);
209212
TVM_FFI_ICHECK(status == cudaSuccess)
210213
<< "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
211214
return true;
@@ -215,9 +218,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
215218
void min_p_sampling_from_probs(TensorView probs, TensorView output,
216219
Optional<TensorView> maybe_indices,
217220
Optional<TensorView> maybe_min_p_arr, double min_p_val,
218-
bool deterministic,
219-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
220-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
221+
bool deterministic, Optional<TensorView> maybe_seed_arr,
222+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
223+
uint64_t offset_val) {
221224
CHECK_INPUT(probs);
222225
CHECK_INPUT(output);
223226
CHECK_DEVICE(output, probs);
@@ -243,11 +246,12 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
243246
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
244247
: nullptr,
245248
batch_size, min_p_val, vocab_size, deterministic,
246-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
249+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
250+
: nullptr,
247251
seed_val,
248-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
249-
offset_val,
250-
stream);
252+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
253+
: nullptr,
254+
offset_val, stream);
251255
TVM_FFI_ICHECK(status == cudaSuccess)
252256
<< "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
253257
return true;
@@ -258,9 +262,9 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
258262
Optional<TensorView> maybe_indices,
259263
Optional<TensorView> maybe_top_k_arr, double top_k_val,
260264
Optional<TensorView> maybe_top_p_arr, double top_p_val,
261-
bool deterministic,
262-
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
263-
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
265+
bool deterministic, Optional<TensorView> maybe_seed_arr,
266+
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
267+
uint64_t offset_val) {
264268
CHECK_INPUT(probs);
265269
CHECK_INPUT(output);
266270
CHECK_DEVICE(output, probs);
@@ -289,11 +293,12 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
289293
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
290294
: nullptr,
291295
batch_size, top_k_val, top_p_val, vocab_size, deterministic,
292-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
296+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
297+
: nullptr,
293298
seed_val,
294-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
295-
offset_val,
296-
stream);
299+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
300+
: nullptr,
301+
offset_val, stream);
297302
TVM_FFI_ICHECK(status == cudaSuccess)
298303
<< "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
299304
return true;
@@ -334,11 +339,12 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i
334339
static_cast<int*>(output_accepted_token_num.data_ptr()),
335340
static_cast<int*>(output_emitted_draft_token_num.data_ptr()), batch_size,
336341
num_speculate_tokens, vocab_size, deterministic,
337-
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr()) : nullptr,
342+
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
343+
: nullptr,
338344
seed_val,
339-
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr()) : nullptr,
340-
offset_val,
341-
stream);
345+
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
346+
: nullptr,
347+
offset_val, stream);
342348

343349
TVM_FFI_ICHECK(status == cudaSuccess)
344350
<< "ChainSpeculativeSampling failed with error code " << cudaGetErrorString(status);

0 commit comments

Comments
 (0)