@@ -399,9 +399,9 @@ __global__ void LayerNormBackwardComputeGradInput(
     const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
     const U *gamma, T *grad_input) {
 #ifdef __HIPCC__
-  for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
+  for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
 #else
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
 #endif
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
@@ -867,9 +867,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       constexpr int BDIMX1 = 32;
       constexpr int BDIMY1 = 4;
       dim3 threads1(BDIMX1, BDIMY1, 1);
-      const dim3 blocks1(1, batch_size, 1);
       LayerNormBackwardComputeGradInput<
-          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
+          T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
           d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
       break;
     }
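
Note on the intent of this change (my reading of the diff; the surrounding Paddle kernel is assumed): the grad-input kernel was previously launched with a 2-D grid, dim3 blocks1(1, batch_size, 1), and strode over rows with blockIdx.y / gridDim.y. The launch now passes the plain integer batch_size, which CUDA and HIP treat as dim3(batch_size, 1, 1), so the row loop must index on blockIdx.x / gridDim.x; with the old .y indexing and a 1-D grid, gridDim.y would be 1 and every block would only ever see row 0. A minimal, hypothetical sketch (not the Paddle code) of the grid-stride pattern after the change:

// Illustration only: an integer launch argument <<<n, threads>>> is
// implicitly dim3(n, 1, 1), so the batch-stride loop must use the x dimension.
__global__ void RowStrideLoop(int n1) {
  // Each block starts at row blockIdx.x and advances by gridDim.x rows.
  for (int i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
    // ... per-row work (e.g. reduction over the feature dimension) ...
  }
}

// old launch:  dim3 blocks1(1, batch_size, 1);
//              Kernel<<<blocks1, threads1, 0, stream>>>(...);    // strode over blockIdx.y
// new launch:  Kernel<<<batch_size, threads1, 0, stream>>>(...); // strides over blockIdx.x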