Skip to content

Commit 26a0f9b

Browse files
committed
update for windows
1 parent 3262a2f commit 26a0f9b

File tree

3 files changed: +14 additions, −9 deletions

paddle/phi/kernels/fusion/cpu/fused_layer_norm_avx_kernel.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,12 @@ void LayerNormFunc(const T* x_data,
155155
if (norm_bias_data) {
156156
vbeta = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
157157
}
158-
__m512 vy = (vx - vmean) * vgamma * vvar + vbeta;
158+
// (vx - vmean) * vgamma * vvar + vbeta
159+
__m512 vy;
160+
vx = _mm512_mask_sub_ps(vx, mask, vx, vmean);
161+
vx = _mm512_mask_mul_ps(vx, mask, vx, vgamma);
162+
vx = _mm512_mask_mul_ps(vx, mask, vx, vvar);
163+
vy = _mm512_mask_add_ps(vy, mask, vx, vbeta);
159164
_mm512_mask_storeu_ps(py + col, mask, vy);
160165
}
161166
}

paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ void softmax_sum_max(float* AB,
257257
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
258258

259259
__m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
260-
vx = vexp(vx * vrefac - vmax);
260+
vx = _mm512_mask_mul_ps(vx, mask, vx, vrefac);
261+
vx = _mm512_mask_sub_ps(vx, mask, vx, vmax);
262+
vx = vexp(vx);
261263

262264
_mm512_mask_storeu_ps(buf + off, mask, vx);
263265

@@ -275,8 +277,7 @@ void softmax_sum_max(float* AB,
275277
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
276278

277279
__m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
278-
vx = vx * vrsum;
279-
280+
vx = _mm512_mask_mul_ps(vx, mask, vx, vrsum);
280281
_mm512_mask_storeu_ps(buf + off, mask, vx);
281282
}
282283
}
@@ -301,7 +302,10 @@ void update_out_blk(float* output,
301302
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
302303
__m512 vout = _mm512_maskz_loadu_ps(mask, outbuf + off);
303304
__m512 vabc = _mm512_maskz_loadu_ps(mask, buf + off);
304-
__m512 vupt = vout * merr * vfac + vabc;
305+
vout = _mm512_mask_mul_ps(vout, mask, vout, merr);
306+
vout = _mm512_mask_mul_ps(vout, mask, vout, vfac);
307+
__m512 vupt;
308+
vupt = _mm512_mask_add_ps(vupt, mask, vout, vabc);
305309
_mm512_mask_storeu_ps(outbuf + off, mask, vupt);
306310
}
307311
pre_sum[i] = sum[i];

test/legacy_test/test_fused_layernorm_op.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -961,10 +961,6 @@ def setUp(self):
961961
self.epsilon = 1e-5
962962
self.residual_alpha = np.random.uniform(low=0.1, high=1.1, size=[1])
963963

964-
self.quant_scale = 0.15
965-
self.quant_round_type = 1
966-
self.quant_max_bound = 127
967-
self.quant_min_bound = -127
968964
self.place = paddle.CPUPlace()
969965

970966
def check_layernorm(self, x_np, gamma_np, beta_np, dtype):

0 commit comments