File tree Expand file tree Collapse file tree 1 file changed +12
-4
lines changed
paddle/phi/kernels/fusion/gpu Expand file tree Collapse file tree 1 file changed +12
-4
lines changed Original file line number Diff line number Diff line change
@@ -125,10 +125,18 @@ __global__ void VectorizedFusedRopeWithRotateEveryTwoKernel(
125125 MPType p0 = static_cast <MPType>(input[pr_index]);
126126 MPType p1 = static_cast <MPType>(input[ls_index]);
127127
128- result[pr_index] =
129- cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1;
130- result[ls_index] =
131- cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0;
128+ if (sign == 1 ) {
129+ result[pr_index] = cos_value[pr_index] * p0;
130+ result[pr_index] -= sin_value[pr_index] * p1;
131+
132+ result[ls_index] = sin_value[ls_index] * p0;
133+ result[ls_index] += cos_value[ls_index] * p1;
134+ } else if (sign == -1 ) {
135+ result[pr_index] =
136+ cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
137+ result[ls_index] =
138+ cos_value[ls_index] * p1 - sin_value[pr_index] * p0;
139+ }
132140
133141 store[pr_index] = static_cast <T>(result[pr_index]);
134142 store[ls_index] = static_cast <T>(result[ls_index]);
You can’t perform that action at this time.
0 commit comments