Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 95 additions & 88 deletions csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,40 +102,21 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cum_offsets,
const int group_size,
const int block_size,
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
paddle::Tensor encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_x_cpu,
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, decoder_batch_ids,
decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu;
auto stream = seq_lens_this_time.stream();
int bsz = cum_offsets.shape()[0];
const int encoder_block_shape_q = get_encoder_block_shape_q();
const int decoder_block_shape_q = get_decoder_block_shape_q();

// decoder
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
auto decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
auto decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);

// max_len
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
Expand All @@ -147,77 +128,100 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
auto max_len_kv_cpu =
max_len_kv.copy_to(paddle::CPUPlace(), false);

// decoder
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
if (max_dec_len_this_time_data > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_tile_ids_per_batch =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_num_blocks_x_cpu =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace());
}

// encoder
int max_enc_len_this_time_data = max_enc_len_this_time.data<int>()[0];
if (max_enc_len_this_time_data <= 0) {
auto encoder_batch_ids =
if (max_enc_len_this_time_data > 0) {
const uint32_t encoder_max_tile_size_per_bs_q = div_up(
(max_enc_len_this_time_data * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(),
bsz,
encoder_block_shape_q,
group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);

// kv
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_len_this_time_data, block_size);
kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_kv_block<<<1, 32, 0, stream>>>(seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size);
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
auto encoder_tile_ids_per_batch =
encoder_tile_ids_per_batch =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
auto encoder_num_blocks_x_cpu =
encoder_num_blocks_x_cpu =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace());
auto kv_batch_ids =
kv_batch_ids =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
auto kv_tile_ids_per_batch =
kv_tile_ids_per_batch =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
auto kv_num_blocks_x_cpu =
kv_num_blocks_x_cpu =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace());

return {encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu /*cpu*/};
}

// encoder
const uint32_t encoder_max_tile_size_per_bs_q = div_up(
(max_enc_len_this_time_data * group_size), encoder_block_shape_q);
auto encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(),
bsz,
encoder_block_shape_q,
group_size);
auto encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);

// kv
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_len_this_time_data, block_size);
auto kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_kv_block<<<1, 32, 0, stream>>>(seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size);
auto kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
return {encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
Expand All @@ -234,6 +238,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cum_offsets_dtype) {
return {paddle::DataType::INT32,
Expand All @@ -252,6 +257,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cum_offsets_shape) {
std::vector<int64_t> dynamic_shape = {-1};
Expand All @@ -272,6 +278,7 @@ PD_BUILD_OP(get_block_shape_and_split_kv_block)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"max_enc_len_this_time",
"max_dec_len_this_time",
"seq_lens_this_time",
"cum_offsets"})
.Outputs({"encoder_batch_ids",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ def forward(
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
max_enc_len_this_time,
max_dec_len_this_time,
kwargs.get("seq_lens_this_time", None),
kwargs.get("cum_offsets", None),
self.num_heads // self.kv_num_heads,
Expand Down