Skip to content

Commit 086ee66

Browse files
Ensure that Arm64 correctly handles multiplication of simd by a 64-bit scalar (#106839)
1 parent c88c18c commit 086ee66

File tree

3 files changed

+48
-22
lines changed

3 files changed

+48
-22
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20843,21 +20843,14 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2084320843
{
2084420844
GenTree** broadcastOp = nullptr;
2084520845

20846-
#if defined(TARGET_ARM64)
20847-
if (varTypeIsLong(simdBaseType))
20848-
{
20849-
break;
20850-
}
20851-
#endif // TARGET_ARM64
20852-
2085320846
if (varTypeIsArithmetic(op1))
2085420847
{
2085520848
broadcastOp = &op1;
2085620849

2085720850
#if defined(TARGET_ARM64)
2085820851
if (!varTypeIsByte(simdBaseType))
2085920852
{
20860-
// MultiplyByScalar requires the scalar op to be op2fGetHWIntrinsicIdForBinOp
20853+
// MultiplyByScalar requires the scalar op to be op2 for GetHWIntrinsicIdForBinOp
2086120854
needsReverseOps = true;
2086220855
}
2086320856
#endif // TARGET_ARM64
@@ -20870,7 +20863,12 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2087020863
if (broadcastOp != nullptr)
2087120864
{
2087220865
#if defined(TARGET_ARM64)
20873-
if (!varTypeIsByte(simdBaseType))
20866+
if (varTypeIsLong(simdBaseType))
20867+
{
20868+
// This is handled via emulation and the scalar is consumed directly
20869+
break;
20870+
}
20871+
else if (!varTypeIsByte(simdBaseType))
2087420872
{
2087520873
op2ForLookup = *broadcastOp;
2087620874
*broadcastOp = gtNewSimdCreateScalarUnsafeNode(TYP_SIMD8, *broadcastOp, simdBaseJitType, 8);
@@ -21274,24 +21272,26 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2127421272
#elif defined(TARGET_ARM64)
2127521273
if (varTypeIsLong(simdBaseType))
2127621274
{
21277-
GenTree** op1ToDup = &op1;
21278-
GenTree** op2ToDup = &op2;
21275+
GenTree** op2ToDup = nullptr;
2127921276

21280-
if (!varTypeIsArithmetic(op1))
21281-
{
21282-
op1 = gtNewSimdToScalarNode(TYP_LONG, op1, simdBaseJitType, simdSize);
21283-
op1ToDup = &op1->AsHWIntrinsic()->Op(1);
21284-
}
21277+
assert(varTypeIsSIMD(op1));
21278+
op1 = gtNewSimdToScalarNode(TYP_LONG, op1, simdBaseJitType, simdSize);
21279+
GenTree** op1ToDup = &op1->AsHWIntrinsic()->Op(1);
2128521280

21286-
if (!varTypeIsArithmetic(op2))
21281+
if (varTypeIsSIMD(op2))
2128721282
{
2128821283
op2 = gtNewSimdToScalarNode(TYP_LONG, op2, simdBaseJitType, simdSize);
2128921284
op2ToDup = &op2->AsHWIntrinsic()->Op(1);
2129021285
}
2129121286

2129221287
// lower = op1.GetElement(0) * op2.GetElement(0)
2129321288
GenTree* lower = gtNewOperNode(GT_MUL, TYP_LONG, op1, op2);
21294-
lower = gtNewSimdCreateScalarUnsafeNode(type, lower, simdBaseJitType, simdSize);
21289+
21290+
if (op2ToDup == nullptr)
21291+
{
21292+
op2ToDup = &lower->AsOp()->gtOp2;
21293+
}
21294+
lower = gtNewSimdCreateScalarUnsafeNode(type, lower, simdBaseJitType, simdSize);
2129521295

2129621296
if (simdSize == 8)
2129721297
{
@@ -21303,10 +21303,8 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2130321303
GenTree* op1Dup = fgMakeMultiUse(op1ToDup);
2130421304
GenTree* op2Dup = fgMakeMultiUse(op2ToDup);
2130521305

21306-
if (!varTypeIsArithmetic(op1Dup))
21307-
{
21308-
op1Dup = gtNewSimdGetElementNode(TYP_LONG, op1Dup, gtNewIconNode(1), simdBaseJitType, simdSize);
21309-
}
21306+
assert(!varTypeIsArithmetic(op1Dup));
21307+
op1Dup = gtNewSimdGetElementNode(TYP_LONG, op1Dup, gtNewIconNode(1), simdBaseJitType, simdSize);
2131021308

2131121309
if (!varTypeIsArithmetic(op2Dup))
2131221310
{
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Runtime.CompilerServices;
6+
using System.Runtime.Intrinsics;
7+
using Xunit;
8+
9+
public class Runtime_106838
10+
{
11+
[MethodImpl(MethodImplOptions.NoInlining)]
12+
private static Vector128<ulong> Problem(Vector128<ulong> vector) => vector * 5UL;
13+
14+
[Fact]
15+
public static void TestEntryPoint()
16+
{
17+
Vector128<ulong> result = Problem(Vector128.Create<ulong>(5));
18+
Assert.Equal(Vector128.Create<ulong>(25), result);
19+
}
20+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
<PropertyGroup>
3+
<Optimize>True</Optimize>
4+
</PropertyGroup>
5+
<ItemGroup>
6+
<Compile Include="$(MSBuildProjectName).cs" />
7+
</ItemGroup>
8+
</Project>

0 commit comments

Comments
 (0)