
Commit 61c5fb1

Remove const_cast
1 parent (8eae0ac) · commit 61c5fb1

File tree: 5 files changed (+98 -99 lines)
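
This commit is a const-correctness cleanup around the dynloaded flash-attn C entry points. At the call sites in the two GPU kernels, every const_cast wrapped around an input tensor's data pointer is dropped, and in the public kernel headers the top-level const on by-value parameters (int64_t, float, bool) is removed. For the casts to be removable, the flash_attn_* declarations reached through phi::dynload must now accept const pointers for their read-only buffers; that declaration change is presumably the fifth file in this commit, which is not shown on this page. A minimal sketch of the pattern, using hypothetical names (backend_bwd, Tensor) rather than the real flash-attn signatures:

// Hypothetical backend entry point, for illustration only (not the real flash-attn API).
// Before a change like this one, the read-only buffers would be declared as plain `void*`,
// forcing callers that only hold const data to const_cast.
bool backend_bwd(const void* dout, const void* q, void* dq);

// Minimal stand-in for a tensor type with const-overloaded accessors.
struct Tensor {
  void* data() { return buf; }              // non-const overload: writable outputs
  const void* data() const { return buf; }  // const overload: read-only inputs
  void* buf = nullptr;
};

void caller(const Tensor& dout, const Tensor& q, Tensor* dq) {
  // Old style: backend_bwd(const_cast<void*>(dout.data()), ...);
  // Once the callee is const-correct, the pointers pass through unchanged.
  backend_bwd(dout.data(), q.data(), dq->data());
}

bool backend_bwd(const void* dout, const void* q, void* dq) {
  return dout != nullptr && q != nullptr && dq != nullptr;  // placeholder body
}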

paddle/phi/kernels/flash_attn_grad_kernel.h

Lines changed: 7 additions & 7 deletions
@@ -30,11 +30,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                  const DenseTensor& softmax_lse,
                                  const DenseTensor& seed_offset,
                                  const DenseTensor& dout,
-                                 const int64_t max_seqlen_q,
-                                 const int64_t max_seqlen_k,
-                                 const float scale,
-                                 const float dropout,
-                                 const bool causal,
+                                 int64_t max_seqlen_q,
+                                 int64_t max_seqlen_k,
+                                 float scale,
+                                 float dropout,
+                                 bool causal,
                                  DenseTensor* dq,
                                  DenseTensor* dk,
                                  DenseTensor* dv);

@@ -48,8 +48,8 @@ void FlashAttnGradKernel(const Context& ctx,
                          const DenseTensor& softmax_lse,
                          const DenseTensor& seed_offset,
                          const DenseTensor& dout,
-                         const float dropout,
-                         const bool causal,
+                         float dropout,
+                         bool causal,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv);
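
Dropping const from these by-value parameters does not change the declared interface: in C++, a top-level cv-qualifier on a parameter is ignored when determining the function's type, so the const only matters inside a definition's body. For a hypothetical free function:

void scale_rows(const float factor);  // declares exactly the same function as...
void scale_rows(float factor);        // ...this redeclaration; only the definition's body sees the const.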

paddle/phi/kernels/flash_attn_kernel.h

Lines changed: 9 additions & 9 deletions
@@ -28,13 +28,13 @@ void FlashAttnUnpaddedKernel(
     const DenseTensor& cu_seqlens_q,
     const DenseTensor& cu_seqlens_k,
     const paddle::optional<DenseTensor>& fixed_seed_offset,
-    const int64_t max_seqlen_q,
-    const int64_t max_seqlen_k,
-    const float scale,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
     float dropout,
-    const bool causal,
-    const bool return_softmax,
-    const bool is_test,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
     const std::string& rng_name,
     DenseTensor* out,
     DenseTensor* softmax,

@@ -48,9 +48,9 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& v,
                      const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
-                     const bool causal,
-                     const bool return_softmax,
-                     const bool is_test,
+                     bool causal,
+                     bool return_softmax,
+                     bool is_test,
                      const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 64 additions & 65 deletions
@@ -42,11 +42,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                  const DenseTensor& softmax_lse,
                                  const DenseTensor& seed_offset,
                                  const DenseTensor& dout,
-                                 const int64_t max_seqlen_q,
-                                 const int64_t max_seqlen_k,
-                                 const float scale,
-                                 const float dropout,
-                                 const bool causal,
+                                 int64_t max_seqlen_q,
+                                 int64_t max_seqlen_k,
+                                 float scale,
+                                 float dropout,
+                                 bool causal,
                                  DenseTensor* dq,
                                  DenseTensor* dk,
                                  DenseTensor* dv) {
@@ -94,36 +94,36 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   DenseTensor dq_accum = Empty<float>(
       ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});

-  const bool succ = phi::dynload::flash_attn_varlen_bwd(
-      const_cast<void*>(dout.data()),
-      const_cast<void*>(q.data()),
-      const_cast<void*>(k.data()),
-      const_cast<void*>(v.data()),
-      const_cast<void*>(out.data()),
-      softmax_d.data(),
-      const_cast<void*>(softmax_lse.data()),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      dq_accum.data(),
-      const_cast<int32_t*>(cu_seqlens_q.data<int32_t>()),
-      const_cast<int32_t*>(cu_seqlens_k.data<int32_t>()),
-      batch_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      seqlen_q_rounded,
-      seqlen_k_rounded,
-      num_heads,
-      num_heads_k,
-      head_size,
-      head_size_rounded,
-      dropout,
-      scale,
-      causal,
-      is_bf16,
-      stream,
-      seed,
-      offset);
+  const bool succ =
+      phi::dynload::flash_attn_varlen_bwd(dout.data(),
+                                          q.data(),
+                                          k.data(),
+                                          v.data(),
+                                          out.data(),
+                                          softmax_d.data(),
+                                          softmax_lse.data(),
+                                          dq->data(),
+                                          dk->data(),
+                                          dv->data(),
+                                          dq_accum.data(),
+                                          cu_seqlens_q.data<int32_t>(),
+                                          cu_seqlens_k.data<int32_t>(),
+                                          batch_size,
+                                          max_seqlen_q,
+                                          max_seqlen_k,
+                                          seqlen_q_rounded,
+                                          seqlen_k_rounded,
+                                          num_heads,
+                                          num_heads_k,
+                                          head_size,
+                                          head_size_rounded,
+                                          dropout,
+                                          scale,
+                                          causal,
+                                          is_bf16,
+                                          stream,
+                                          seed,
+                                          offset);

   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
@@ -141,8 +141,8 @@ void FlashAttnGradKernel(const Context& ctx,
                          const DenseTensor& softmax_lse,
                          const DenseTensor& seed_offset,
                          const DenseTensor& dout,
-                         const float dropout,
-                         const bool causal,
+                         float dropout,
+                         bool causal,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
@@ -193,34 +193,33 @@ void FlashAttnGradKernel(const Context& ctx,

   VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset;

-  const bool succ =
-      phi::dynload::flash_attn_bwd(const_cast<void*>(dout.data()),
-                                   const_cast<void*>(q.data()),
-                                   const_cast<void*>(k.data()),
-                                   const_cast<void*>(v.data()),
-                                   const_cast<void*>(out.data()),
-                                   softmax_d.data(),
-                                   const_cast<void*>(softmax_lse.data()),
-                                   dq->data(),
-                                   dk->data(),
-                                   dv->data(),
-                                   dq_accum.data(),
-                                   batch_size,
-                                   seqlen_q,
-                                   seqlen_k,
-                                   seqlen_q_rounded,
-                                   seqlen_k_rounded,
-                                   num_heads,
-                                   num_heads_k,
-                                   head_size,
-                                   head_size_rounded,
-                                   dropout,
-                                   scale,
-                                   causal,
-                                   is_bf16,
-                                   stream,
-                                   seed,
-                                   offset);
+  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+                                                 q.data(),
+                                                 k.data(),
+                                                 v.data(),
+                                                 out.data(),
+                                                 softmax_d.data(),
+                                                 softmax_lse.data(),
+                                                 dq->data(),
+                                                 dk->data(),
+                                                 dv->data(),
+                                                 dq_accum.data(),
+                                                 batch_size,
+                                                 seqlen_q,
+                                                 seqlen_k,
+                                                 seqlen_q_rounded,
+                                                 seqlen_k_rounded,
+                                                 num_heads,
+                                                 num_heads_k,
+                                                 head_size,
+                                                 head_size_rounded,
+                                                 dropout,
+                                                 scale,
+                                                 causal,
+                                                 is_bf16,
+                                                 stream,
+                                                 seed,
+                                                 offset);

   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
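
For context on why the casts were there at all: q, k, v, out, softmax_lse and dout arrive as const DenseTensor&, and phi's data accessors are const-overloaded, so calling .data() on them yields a pointer to const. A simplified sketch of the accessors involved (illustrative only, not the real class definition):

// Rough shape of the phi::DenseTensor accessors this kernel relies on.
struct DenseTensorSketch {
  void* data();               // non-const overload: outputs such as dq, dk, dv, dq_accum
  const void* data() const;   // const overload: what q.data(), dout.data(), ... return here
  template <typename T>
  const T* data() const;      // typed const overload, e.g. data<int32_t>() for cu_seqlens
};

Once flash_attn_bwd and flash_attn_varlen_bwd are declared with const pointers for their read-only arguments (again, presumably part of the dynload header change in this commit), these values can be passed through without const_cast.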

paddle/phi/kernels/gpu/flash_attn_kernel.cu

Lines changed: 17 additions & 17 deletions
@@ -42,13 +42,13 @@ void FlashAttnUnpaddedKernel(
     const DenseTensor& cu_seqlens_q,
     const DenseTensor& cu_seqlens_k,
     const paddle::optional<DenseTensor>& fixed_seed_offset,
-    const int64_t max_seqlen_q,
-    const int64_t max_seqlen_k,
-    const float scale,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
     float dropout,
-    const bool causal,
-    const bool return_softmax,
-    const bool is_test,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
     const std::string& rng_name,
     DenseTensor* out,
     DenseTensor* softmax,
@@ -129,12 +129,12 @@ void FlashAttnUnpaddedKernel(
   }

   const bool succ = phi::dynload::flash_attn_varlen_fwd(
-      const_cast<void*>(q.data()),
-      const_cast<void*>(k.data()),
-      const_cast<void*>(v.data()),
+      q.data(),
+      k.data(),
+      v.data(),
       out->data(),
-      const_cast<void*>(cu_seqlens_q.data()),
-      const_cast<void*>(cu_seqlens_k.data()),
+      cu_seqlens_q.data<int32_t>(),
+      cu_seqlens_k.data<int32_t>(),
       return_softmax ? softmax->data() : nullptr,
       softmax_lse->data(),
       batch_size,
@@ -169,9 +169,9 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& v,
                      const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
-                     const bool causal,
-                     const bool return_softmax,
-                     const bool is_test,
+                     bool causal,
+                     bool return_softmax,
+                     bool is_test,
                      const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
@@ -253,9 +253,9 @@ void FlashAttnKernel(const Context& ctx,
   }

   bool succ =
-      phi::dynload::flash_attn_fwd(const_cast<void*>(q.data()),
-                                   const_cast<void*>(k.data()),
-                                   const_cast<void*>(v.data()),
+      phi::dynload::flash_attn_fwd(q.data(),
+                                   k.data(),
+                                   v.data(),
                                    out->data(),
                                    return_softmax ? softmax->data() : nullptr,
                                    softmax_lse->data(),
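
Beyond the q/k/v pointers, note that the varlen forward path also switches the sequence-length arguments from the untyped accessor to the typed one: cu_seqlens_q.data<int32_t>() returns a const int32_t* that matches the expected parameter type directly, whereas the old code discarded both type and constness with const_cast<void*>(cu_seqlens_q.data()). The typed accessor is also checked against the tensor's stored dtype in phi, so a wrongly typed sequence-length tensor should now be reported at the call site instead of being silently reinterpreted.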
