[CPU][IBM Z] Fix BF16 support and vectorize math operations for s390x #28926
Conversation
Code Review
This pull request introduces significant performance improvements for the s390x architecture by fixing a BF16 byte-ordering bug and adding vectorized implementations for several math operations. The changes are generally good, but there are several opportunities for further performance optimization in the new vectorized functions. Specifically, some functions contain inefficient patterns like calling expensive operations inside loops or unnecessarily unpacking vectors to scalars. I've provided specific suggestions to improve these areas.
csrc/cpu/cpu_types_vxe.hpp
Outdated
FP32Vec8 tanh() const {
  // TODO: Vectorize this
  AliasReg ar;
  ar.reg = reg;
  f32x4x4_t ret;
  ret.val[0][0] = std::tanh(ar.values[0]);
  ret.val[0][1] = std::tanh(ar.values[1]);
  ret.val[0][2] = std::tanh(ar.values[2]);
  ret.val[0][3] = std::tanh(ar.values[3]);
  ret.val[1][0] = std::tanh(ar.values[4]);
  ret.val[1][1] = std::tanh(ar.values[5]);
  ret.val[1][2] = std::tanh(ar.values[6]);
  ret.val[1][3] = std::tanh(ar.values[7]);
  return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}

FP32Vec8 er() const {
  // TODO: Vectorize this
  AliasReg ar;
  ar.reg = reg;
  f32x4x4_t ret;
  ret.val[0][0] = std::erf(ar.values[0]);
  ret.val[0][1] = std::erf(ar.values[1]);
  ret.val[0][2] = std::erf(ar.values[2]);
  ret.val[0][3] = std::erf(ar.values[3]);
  ret.val[1][0] = std::erf(ar.values[4]);
  ret.val[1][1] = std::erf(ar.values[5]);
  ret.val[1][2] = std::erf(ar.values[6]);
  ret.val[1][3] = std::erf(ar.values[7]);
  return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));

  // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
  const __vector float one = vec_splats(1.0f);
  const __vector float two = vec_splats(2.0f);
  const __vector float zero = vec_splats(0.0f);
  const __vector float sat = vec_splats(9.0f);  // beyond this, tanh(x) ~ sign(x)

  f32x4x2_t out;

  for (int i = 0; i < 2; i++) {
    __vector float x = reg.val[i];
    __vector float ax = vec_abs(x);

    // sign(x): +1 or -1
    __vector float sign = vec_sel(vec_splats(-1.0f), one,
                                  vec_cmpgt(x, zero));

    // saturation mask: |x| > sat
    __vector __bool int saturated = vec_cmpgt(ax, sat);

    // 2x
    __vector float two_x = vec_mul(x, two);

    // Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
    f32x4x2_t tmp;
    tmp.val[0] = two_x;
    tmp.val[1] = two_x;
    FP32Vec8 exp_2x_vec(tmp);

    FP32Vec8 e2x = exp_2x_vec.exp();
    __vector float e = e2x.reg.val[i];

    // tanh(x) = (e - 1) / (e + 1)
    __vector float num = vec_sub(e, one);
    __vector float den = vec_add(e, one);

    __vector float t = vec_div(num, den);

    // For large |x|, clamp to sign(x)
    out.val[i] = vec_sel(t, sign, saturated);
  }

  return FP32Vec8(out);
}
The exp() call is inside a loop, which is inefficient. It's called for each of the two lanes of FP32Vec8, and each call computes the exponential for a full temporary vector. This can be optimized by calling exp() once on the 2*x vector outside the loop.
FP32Vec8 tanh() const {
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
const __vector float one = vec_splats(1.0f);
const __vector float two = vec_splats(2.0f);
const __vector float zero = vec_splats(0.0f);
const __vector float sat = vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
f32x4x2_t two_x_lanes;
for (int i = 0; i < 2; i++) {
two_x_lanes.val[i] = vec_mul(reg.val[i], two);
}
FP32Vec8 e2x = FP32Vec8(two_x_lanes).exp();
f32x4x2_t out;
for (int i = 0; i < 2; i++) {
__vector float x = reg.val[i];
__vector float ax = vec_abs(x);
// sign(x): +1 or -1
__vector float sign = vec_sel(vec_splats(-1.0f), one,
vec_cmpgt(x, zero));
// saturation mask: |x| > sat
__vector __bool int saturated = vec_cmpgt(ax, sat);
// tanh(x) = (e - 1) / (e + 1)
__vector float num = vec_sub(e2x.reg.val[i], one);
__vector float den = vec_add(e2x.reg.val[i], one);
__vector float t = vec_div(num, den);
// For large |x|, clamp to sign(x)
out.val[i] = vec_sel(t, sign, saturated);
}
return FP32Vec8(out);
}
csrc/cpu/cpu_types_vxe.hpp
Outdated
FP32Vec8 er() const {
  // A&S 7.1.26 approximation:
  // erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t * exp(-x^2))
  // t = 1 / (1 + p*|x|), p = 0.3275911

  const __vector float one = vec_splats(1.0f);
  const __vector float zero = vec_splats(0.0f);
  const __vector float p = vec_splats(0.3275911f);

  // Polynomial coeffs
  const __vector float a1 = vec_splats(0.254829592f);
  const __vector float a2 = vec_splats(-0.284496736f);
  const __vector float a3 = vec_splats(1.421413741f);
  const __vector float a4 = vec_splats(-1.453152027f);
  const __vector float a5 = vec_splats(1.061405429f);

  // Threshold where erf(x) ~ sign(x)
  const __vector float sat = vec_splats(6.0f);

  f32x4x2_t out;

  for (int lane = 0; lane < 2; lane++) {
    __vector float x = reg.val[lane];
    __vector float ax = vec_abs(x);

    // sign(x)
    __vector float sign = vec_sel(vec_splats(-1.0f), one,
                                  vec_cmpgt(x, zero));

    // |x| > 6 → erf(x) = ±1
    __vector __bool int saturated = vec_cmpgt(ax, sat);

    // t = 1 / (1 + p * |x|)
    __vector float t = vec_madd(p, ax, one);
    t = vec_div(one, t);

    // poly = a5
    __vector float poly = a5;
    poly = vec_madd(poly, t, a4);
    poly = vec_madd(poly, t, a3);
    poly = vec_madd(poly, t, a2);
    poly = vec_madd(poly, t, a1);

    // full polynomial: poly = poly * t
    poly = vec_mul(poly, t);

    // Compute exp(-x^2)
    __vector float x2 = vec_mul(x, x);
    __vector float neg_x2 = vec_neg(x2);

    f32x4x2_t tmp;
    tmp.val[0] = neg_x2;
    tmp.val[1] = neg_x2;
    FP32Vec8 exp_neg_x2(tmp);

    FP32Vec8 e = exp_neg_x2.exp();
    __vector float ex = e.reg.val[lane];

    // erf(x) = sign * (1 - poly * exp(-x^2))
    __vector float term = vec_mul(poly, ex);
    __vector float y = vec_sub(one, term);
    y = vec_mul(y, sign);

    // saturated → ±1
    __vector float sat_val = vec_mul(sign, one);
    out.val[lane] = vec_sel(y, sat_val, saturated);
  }

  return FP32Vec8(out);
}
The exp() call is inside a loop, which is inefficient. It's called for each of the two lanes of FP32Vec8, and each call computes the exponential for a full temporary vector. This can be optimized by calling exp() once on the -x^2 vector outside the loop.
FP32Vec8 er() const {
// A&S 7.1.26 approximation:
// erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t * exp(-x^2))
// t = 1 / (1 + p*|x|), p = 0.3275911
const __vector float one = vec_splats(1.0f);
const __vector float zero = vec_splats(0.0f);
const __vector float p = vec_splats(0.3275911f);
// Polynomial coeffs
const __vector float a1 = vec_splats(0.254829592f);
const __vector float a2 = vec_splats(-0.284496736f);
const __vector float a3 = vec_splats(1.421413741f);
const __vector float a4 = vec_splats(-1.453152027f);
const __vector float a5 = vec_splats(1.061405429f);
// Threshold where erf(x) ~ sign(x)
const __vector float sat = vec_splats(6.0f);
f32x4x2_t neg_x2_lanes;
for (int lane = 0; lane < 2; lane++) {
__vector float x2 = vec_mul(reg.val[lane], reg.val[lane]);
neg_x2_lanes.val[lane] = vec_neg(x2);
}
FP32Vec8 e = FP32Vec8(neg_x2_lanes).exp();
f32x4x2_t out;
for (int lane = 0; lane < 2; lane++) {
__vector float x = reg.val[lane];
__vector float ax = vec_abs(x);
// sign(x)
__vector float sign = vec_sel(vec_splats(-1.0f), one,
vec_cmpgt(x, zero));
// |x| > 6 → erf(x) = ±1
__vector __bool int saturated = vec_cmpgt(ax, sat);
// t = 1 / (1 + p * |x|)
__vector float t = vec_madd(p, ax, one);
t = vec_div(one, t);
// poly = a5
__vector float poly = a5;
poly = vec_madd(poly, t, a4);
poly = vec_madd(poly, t, a3);
poly = vec_madd(poly, t, a2);
poly = vec_madd(poly, t, a1);
// full polynomial: poly = poly * t
poly = vec_mul(poly, t);
__vector float ex = e.reg.val[lane];
// erf(x) = sign * (1 - poly * exp(-x^2))
__vector float term = vec_mul(poly, ex);
__vector float y = vec_sub(one, term);
y = vec_mul(y, sign);
// saturated → ±1
__vector float sat_val = vec_mul(sign, one);
out.val[lane] = vec_sel(y, sat_val, saturated);
}
return FP32Vec8(out);
}

FP32Vec8 rcp() const {
  AliasReg in, out;
  in.reg = reg;

  for (int i = 0; i < VEC_ELEM_NUM; ++i) {
    out.values[i] = 1.0f / in.values[i];
  }
  return FP32Vec8(out.reg);
}
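(Illustrative sketch only, not a comment from the review: the rcp() above unpacks to scalars. A fully vectorized variant could reuse vec_div on both lanes, assuming the f32x4x2_t layout and FP32Vec8 constructor used elsewhere in this file.)

FP32Vec8 rcp() const {
  const __vector float one = vec_splats(1.0f);
  f32x4x2_t out;
  // Divide per lane instead of per element.
  out.val[0] = vec_div(one, reg.val[0]);
  out.val[1] = vec_div(one, reg.val[1]);
  return FP32Vec8(out);
}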
float reduce_max() const {
  AliasReg ar;
  ar.reg = reg;
  float result = ar.values[0];
  unroll_loop<int, VEC_ELEM_NUM>(
      [&result, &ar](int i) {
        if (ar.values[i] > result) result = ar.values[i];
      });
  return result;
}
The reduce_max function is implemented with a scalar loop. This can be optimized by using vector instructions to perform a parallel reduction.
float reduce_max() const {
__vector float v01 = vec_max(reg.val[0], reg.val[1]);
__vector float v23 = vec_max(reg.val[2], reg.val[3]);
__vector float v = vec_max(v01, v23);
v = vec_max(v, vec_sld(v, v, 8));
v = vec_max(v, vec_sld(v, v, 4));
return vec_extract(v, 0);
}

for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
  FP32Vec8 v(input + i);
  FP32Vec8::AliasReg ar;
  ar.reg = v.reg;
  for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
    if (ar.values[j] > max_val) max_val = ar.values[j];
  }
}
The loop to find the maximum value is inefficient as it unpacks the vector and finds the maximum in a scalar loop. This pattern is repeated in other parts of softmax_fp32vec8 and rmsnorm_fp32vec8. To improve performance, these loops should be fully vectorized. For finding the max, you can use vector max operations and a final horizontal max reduction. This would require adding max() and reduce_max() methods to FP32Vec8, similar to what you've done for FP32Vec16.
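For illustration, such FP32Vec8 members could mirror the FP32Vec16 suggestion above (a sketch only, not part of the reviewed patch; the f32x4x2_t constructor and reg.val layout are assumed from the diffs in this file):

FP32Vec8 max(const FP32Vec8& b) const {
  f32x4x2_t out;
  out.val[0] = vec_max(reg.val[0], b.reg.val[0]);
  out.val[1] = vec_max(reg.val[1], b.reg.val[1]);
  return FP32Vec8(out);
}

float reduce_max() const {
  // Horizontal max over both 4-float lanes.
  __vector float v = vec_max(reg.val[0], reg.val[1]);
  v = vec_max(v, vec_sld(v, v, 8));
  v = vec_max(v, vec_sld(v, v, 4));
  return vec_extract(v, 0);
}

The scalar loop in the next hunk could then accumulate a vector maximum and call reduce_max() once after the loop.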
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
  FP32Vec8 x_vec(input + i);

  FP32Vec8 sq = x_vec * x_vec;

  FP32Vec8::AliasReg ar;
  ar.reg = sq.reg;
  for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
    sum_sq += ar.values[j];
  }
}
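The squared-sum accumulation above could be vectorized the same way. A rough sketch with a hypothetical free-standing helper (fp32vec8_reduce_sum is not a name from the PR; it reuses the intrinsics shown in the reduce_max suggestion):

static inline float fp32vec8_reduce_sum(const FP32Vec8& v) {
  // Add the two 4-float lanes, then fold horizontally.
  __vector float s = vec_add(v.reg.val[0], v.reg.val[1]);
  s = vec_add(s, vec_sld(s, s, 8));
  s = vec_add(s, vec_sld(s, s, 4));
  return vec_extract(s, 0);
}

// The loop body would then reduce to:
//   FP32Vec8 x_vec(input + i);
//   sum_sq += fp32vec8_reduce_sum(x_vec * x_vec);

Better still would be keeping the accumulator in a vector across iterations and doing a single horizontal sum after the loop.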
💡 Codex Review
Here are some automated review suggestions for this pull request.
#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait<c10::Half> {
  using vec_t = vec_op::FP16Vec16;
Restore VecTypeTrait<c10::Half> for s390x
The new guard in VecTypeTrait<c10::Half> (#if !defined(__powerpc__) && !defined(__s390x__)) now completely removes the specialization whenever __s390x__ is defined. The CPU attention kernels (cpu_attn_vec.hpp/cpu_attn_vec16.hpp) still instantiate VecTypeTrait<c10::Half>::vec_t when building float16 variants, so on s390 builds this alias now resolves to void and those translation units fail to compile, effectively dropping FP16 CPU support even though FP16Vec16 is still defined (aliased to FP32Vec16). The specialization should remain and use the FP32 alias instead of being disabled for s390.
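A minimal sketch of the shape this comment points at (assumptions only; the actual fix in cpu_attn_impl.hpp and the exact fallback type may differ):

#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait<c10::Half> {
  using vec_t = vec_op::FP16Vec16;  // native FP16 vector path
};
#else
template <>
struct VecTypeTrait<c10::Half> {
  // Keep the specialization on ppc/s390x so vec_t never resolves to void;
  // FP16 data is handled through the FP32-backed vector type instead.
  using vec_t = vec_op::FP32Vec16;
};
#endif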
csrc/cpu/cpu_types_vxe.hpp
Outdated
__vector float y = vec_mul(x, log2e);

__vector float kf = vec_floor(y);
__vector float r = vec_sub(y, kf);

__vector signed int k = vec_signed(kf);
const __vector signed int min_k = vec_splats((signed int)-126);
vec_signed intrinsic is undefined
FP32Vec8::exp() now calls vec_signed(kf) (lines 214‑220) to turn the floored exponent into an integer, but neither this file nor the toolchain headers provide a definition for vec_signed (it does not appear anywhere else in the repo). As written the file will not compile because the compiler cannot resolve this intrinsic. The conversion needs to use an existing helper (e.g. vec_cts or an explicit cast) instead of referencing an undefined function.
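One way to sidestep the intrinsic, sketched here purely as an illustration (the comment also mentions vec_cts, the AltiVec-style conversion; whether that or vec_signed exists depends on the s390x toolchain headers, whereas GCC/Clang always allow subscripting vector types):

// Convert the floored exponent to integers without vec_signed.
__vector signed int k;
for (int j = 0; j < 4; ++j) {
  k[j] = (signed int)kf[j];
}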
… (VXE)
- Fix BF16 byte ordering for big-endian architecture
- Vectorize exp(), tanh(), erf() functions with polynomial approximations
- Add FMA intrinsics (fma, fms, nfma, nfms) using vec_madd/vec_msub
- Improve BF16 rounding with round-to-nearest-even
- Fix prefetch implementation
- Add sigmoid, gelu_tanh, gelu_erf, rcp, rsqrt operations
- Implement softmax_fp32vec8 and rmsnorm_fp32vec8 kernels
- Fix FP16 support by aliasing to FP32Vec16
- Exclude s390x from FP16 vector trait in cpu_attn_impl.hpp

Signed-off-by: Rehan Khan <[email protected]>
Purpose
This PR fixes a BF16 (bfloat16) support issue caused by incorrect byte ordering and adds comprehensive vectorized mathematical operations for the IBM Z s390x architecture using VXE.
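For background (an illustrative snippet, not code from this PR; the helper name is hypothetical): BF16 keeps the upper 16 bits of an FP32 bit pattern, so selecting the value by half-word index gives different results on big-endian s390x than on little-endian x86, whereas a shift-based truncation is endian-independent.

#include <cstdint>
#include <cstring>

// Truncate FP32 to BF16 by keeping the 16 most significant bits;
// behaves the same on big- and little-endian hosts.
static inline uint16_t fp32_to_bf16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}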
Key Issues Addressed:
- Replace scalar math functions (std::exp, std::tanh, std::erf) with optimized vector implementations

Test Plan
Deploy the vLLM service and check the inference output.
Test Result