
Commit c0ad271

[INTEL_HPU] Add FP8 sdpa custom op (PaddlePaddle#1710)
Signed-off-by: Fei Wang <[email protected]>
1 parent cd04840 commit c0ad271

2 files changed: +487 -0 lines changed
Lines changed: 301 additions & 0 deletions
@@ -0,0 +1,301 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "habanalabs/perf_lib_layer_params.h"
#include "habanalabs/synapse_api.h"
#include "habanalabs/synapse_common_types.h"
#include "kernels/funcs.h"
#include "kernels/hpu_operator.h"
#include "paddle/extension.h"
#include "utils/utils.h"

#define SDPA_SET_INPUT_AND_FLAGS(ptr, flag_name)    \
  if (ptr) {                                        \
    flags |= SdpaFlags_t::SDPA_FLAGS_##flag_name;   \
    ct.Add(ptr);                                    \
  }

namespace custom_kernel {

struct SDPAParams {
  bool has_atten_mask;
  ns_Sdpa::ParamsV3 params;
};

class FusedFp8Sdpa : public HpuOperator {
 public:
  FusedFp8Sdpa() : HpuOperator("sdpa_recomp_fwd_hf8") {}
  void AddNode(ConvertTensors& ct, SDPAParams& params) {
    auto inputs = ct.GetTensors();
    auto outputs = ct.GetTensors(false);

    std::vector<synTensor> sync_inputs;
    synStatus status = synFail;
    // q, k, v
    for (size_t i = 0; i < 3; i++) {
      sync_inputs.push_back(createTensor(inputs[i].dims.size(),
                                         inputs[i].type,
                                         inputs[i].dims,
                                         true,
                                         inputs[i].name));
    }

    // atten mask
    if (!params.has_atten_mask) {
      sync_inputs.push_back(nullptr);
    }

    // seed
    sync_inputs.push_back(nullptr);

    // remaining inputs: scale tensors and, if present, the attention mask
    for (size_t i = 3; i < inputs.size(); i++) {
      sync_inputs.push_back(createTensor(inputs[i].dims.size(),
                                         inputs[i].type,
                                         inputs[i].dims,
                                         true,
                                         inputs[i].name));
    }

    std::vector<synTensor> sync_outputs;
    for (size_t i = 0; i < outputs.size(); i++) {
      sync_outputs.push_back(createTensor(outputs[i].dims.size(),
                                          outputs[i].type,
                                          outputs[i].dims,
                                          true,
                                          outputs[i].name));
    }

    status = synNodeCreate(graphHandle_,
                           sync_inputs.data(),
                           sync_outputs.data(),
                           sync_inputs.size(),
                           sync_outputs.size(),
                           &params.params,
                           sizeof(params.params),
                           guid_.c_str(),
                           guid_.c_str(),
                           nullptr,
                           nullptr);
    PD_CHECK(
        status == synSuccess, "[RUNTIME] synNodeCreate() failed = %d", status);
  }
};

template <typename T, typename Context>
void fused_fp8_sdpa(const Context& dev_ctx,
                    const phi::DenseTensor& q,
                    const phi::DenseTensor& k,
                    const phi::DenseTensor& v,
                    const paddle::optional<phi::DenseTensor>& attn_mask,
                    const paddle::optional<phi::DenseTensor>& d_scale_q,
                    const paddle::optional<phi::DenseTensor>& d_scale_k,
                    const paddle::optional<phi::DenseTensor>& d_scale_v,
                    const paddle::optional<phi::DenseTensor>& q_scale_s,
                    const paddle::optional<phi::DenseTensor>& q_scale_o,
                    const paddle::optional<phi::DenseTensor>& d_scale_s,
                    float scale,
                    bool causal,
                    phi::DenseTensor* out) {
  // allocate memory on device.
  dev_ctx.template Alloc<T>(out);
  if (out->numel() == 0) {
    return;
  }

  ConvertTensors ct;
  ct.Add(q);
  ct.Add(k);
  ct.Add(v);

  unsigned int flags = 0;

  SDPA_SET_INPUT_AND_FLAGS(d_scale_q.get_ptr(), D_SCALE_Q)
  SDPA_SET_INPUT_AND_FLAGS(d_scale_k.get_ptr(), D_SCALE_K)
  SDPA_SET_INPUT_AND_FLAGS(d_scale_v.get_ptr(), D_SCALE_V)
  SDPA_SET_INPUT_AND_FLAGS(q_scale_s.get_ptr(), Q_SCALE_S)
  SDPA_SET_INPUT_AND_FLAGS(q_scale_o.get_ptr(), Q_SCALE_O)
  SDPA_SET_INPUT_AND_FLAGS(d_scale_s.get_ptr(), D_SCALE_S)

  SDPAParams params{};

  if (attn_mask.get_ptr()) {
    ct.Add(attn_mask.get_ptr());
    params.has_atten_mask = true;
  }

  params.params.scale = scale;
  params.params.is_causal = causal;
  params.params.dropout.ratio = 0;
  params.params.is_inference = true;
  params.params.softmax_mode = SDPA_DEFAULT_SOFTMAX;
  params.params.flags = flags;

  ct.Add(*out, false);
  std::vector<DIMS> inputs_dims = ct.GetDims();

  OpCacheOperator op_info;
  op_info.prepareOpInfo<T, SDPAParams>(
      "FusedFp8SdpaKernel", inputs_dims, &params);
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    FusedFp8Sdpa op;
    op.AddNode(ct, params);
    op.Compile();
    op_info.setOp(op);
    recipe = op_info.GetRecipe();
  }

  auto tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

}  // namespace custom_kernel

std::vector<paddle::Tensor> FusedFp8SdpaForward(
    const paddle::Tensor& q,
    const paddle::Tensor& k,
    const paddle::Tensor& v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& d_scale_q,
    const paddle::optional<paddle::Tensor>& d_scale_k,
    const paddle::optional<paddle::Tensor>& d_scale_v,
    const paddle::optional<paddle::Tensor>& q_scale_s,
    const paddle::optional<paddle::Tensor>& q_scale_o,
    const paddle::optional<paddle::Tensor>& d_scale_s,
    bool causal,
    float scale) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(q.place()));

  auto q_tensor = static_cast<const phi::DenseTensor*>(q.impl().get());
  auto k_tensor = static_cast<const phi::DenseTensor*>(k.impl().get());
  auto v_tensor = static_cast<const phi::DenseTensor*>(v.impl().get());

  // attn_mask
  phi::DenseTensor* attn_mask_tensor = nullptr;
  if (attn_mask) {
    auto attn_mask_ptr = *(attn_mask.get_ptr());
    attn_mask_tensor =
        static_cast<phi::DenseTensor*>(attn_mask_ptr.impl().get());
  }

  // d_scale_q
  phi::DenseTensor* d_scale_q_tensor = nullptr;
  if (d_scale_q) {
    auto d_scale_q_ptr = *(d_scale_q.get_ptr());
    d_scale_q_tensor =
        static_cast<phi::DenseTensor*>(d_scale_q_ptr.impl().get());
  }

  // d_scale_k
  phi::DenseTensor* d_scale_k_tensor = nullptr;
  if (d_scale_k) {
    auto d_scale_k_ptr = *(d_scale_k.get_ptr());
    d_scale_k_tensor =
        static_cast<phi::DenseTensor*>(d_scale_k_ptr.impl().get());
  }

  // d_scale_v
  phi::DenseTensor* d_scale_v_tensor = nullptr;
  if (d_scale_v) {
    auto d_scale_v_ptr = *(d_scale_v.get_ptr());
    d_scale_v_tensor =
        static_cast<phi::DenseTensor*>(d_scale_v_ptr.impl().get());
  }

  // q_scale_s
  phi::DenseTensor* q_scale_s_tensor = nullptr;
  if (q_scale_s) {
    auto q_scale_s_ptr = *(q_scale_s.get_ptr());
    q_scale_s_tensor =
        static_cast<phi::DenseTensor*>(q_scale_s_ptr.impl().get());
  }

  // q_scale_o
  phi::DenseTensor* q_scale_o_tensor = nullptr;
  if (q_scale_o) {
    auto q_scale_o_ptr = *(q_scale_o.get_ptr());
    q_scale_o_tensor =
        static_cast<phi::DenseTensor*>(q_scale_o_ptr.impl().get());
  }

  // d_scale_s
  phi::DenseTensor* d_scale_s_tensor = nullptr;
  if (d_scale_s) {
    auto d_scale_s_ptr = *(d_scale_s.get_ptr());
    d_scale_s_tensor =
        static_cast<phi::DenseTensor*>(d_scale_s_ptr.impl().get());
  }

  auto out_tensor = std::make_shared<phi::DenseTensor>();
  out_tensor->Resize(q_tensor->dims());

  custom_kernel::fused_fp8_sdpa<phi::dtype::bfloat16>(
      *dev_ctx,
      *q_tensor,
      *k_tensor,
      *v_tensor,
      attn_mask ? *attn_mask_tensor : paddle::optional<phi::DenseTensor>(),
      d_scale_q ? *d_scale_q_tensor : paddle::optional<phi::DenseTensor>(),
      d_scale_k ? *d_scale_k_tensor : paddle::optional<phi::DenseTensor>(),
      d_scale_v ? *d_scale_v_tensor : paddle::optional<phi::DenseTensor>(),
      q_scale_s ? *q_scale_s_tensor : paddle::optional<phi::DenseTensor>(),
      q_scale_o ? *q_scale_o_tensor : paddle::optional<phi::DenseTensor>(),
      d_scale_s ? *d_scale_s_tensor : paddle::optional<phi::DenseTensor>(),
      scale,
      causal,
      out_tensor.get());

  paddle::Tensor out(out_tensor);

  return {out};
}

std::vector<std::vector<int64_t>> FusedFp8SdpaForwardShape(
    const std::vector<int64_t>& query_states_shape,
    const std::vector<int64_t>& key_states_shape,
    const std::vector<int64_t>& value_states_shape) {
  int64_t bsz = query_states_shape[0];
  int64_t num_heads = query_states_shape[1];
  int64_t seq_len = query_states_shape[2];
  int64_t head_dim = query_states_shape[3];
  return {{bsz, num_heads, seq_len, head_dim}};
}

std::vector<paddle::DataType> FusedFp8SdpaForwardDtype(
    const paddle::DataType& query_states_dtype,
    const paddle::DataType& key_states_dtype,
    const paddle::DataType& value_states_dtype) {
  return {paddle::DataType::BFLOAT16};
}

PD_BUILD_OP(fused_fp8_sdpa)
    .Inputs({
        "q",
        "k",
        "v",
        paddle::Optional("attn_mask"),
        paddle::Optional("d_scale_q"),
        paddle::Optional("d_scale_k"),
        paddle::Optional("d_scale_v"),
        paddle::Optional("q_scale_s"),
        paddle::Optional("q_scale_o"),
        paddle::Optional("d_scale_s"),
    })
    .Attrs({"causal: bool", "scaling_factor: float"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(FusedFp8SdpaForward))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8SdpaForwardShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8SdpaForwardDtype));
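
For reference, a minimal Python-side sketch of how the registered op might be invoked (not part of the commit). It assumes the custom-op library has been built and is importable from Python; the module name `paddlenlp_ops` is only a placeholder for whatever the build actually produces, and the FP8 quantization of q/k/v plus the scale tensors are assumed to come from the caller's own quantization flow.

# Usage sketch under the assumptions stated above; `paddlenlp_ops` is a
# hypothetical import path, and q/k/v are [bsz, num_heads, seq_len, head_dim]
# tensors already quantized to FP8 by the caller.
from paddlenlp_ops import fused_fp8_sdpa  # placeholder module name


def run_fp8_sdpa(q_fp8, k_fp8, v_fp8,
                 d_scale_q, d_scale_k, d_scale_v,
                 q_scale_s, q_scale_o, d_scale_s,
                 causal=True):
    head_dim = q_fp8.shape[-1]
    # Arguments follow the PD_BUILD_OP order: q, k, v, optional attn_mask,
    # then the six optional scale tensors; attrs are causal and scaling_factor.
    # The result is always bfloat16, per FusedFp8SdpaForwardDtype.
    return fused_fp8_sdpa(
        q_fp8, k_fp8, v_fp8,
        None,  # attn_mask is optional; the kernel wires a null input when absent
        d_scale_q, d_scale_k, d_scale_v,
        q_scale_s, q_scale_o, d_scale_s,
        causal,
        1.0 / head_dim ** 0.5,  # scaling_factor
    )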
