llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+138, -0)

@@ -104,6 +104,8 @@ class AArch64AsmPrinter : public AsmPrinter {

void LowerJumpTableDest(MCStreamer &OutStreamer, const MachineInstr &MI);

void LowerHardenedBRJumpTable(const MachineInstr &MI);

void LowerMOPS(MCStreamer &OutStreamer, const MachineInstr &MI);

void LowerSTACKMAP(MCStreamer &OutStreamer, StackMaps &SM,
@@ -1310,6 +1312,138 @@ void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer,
.addImm(Size == 4 ? 0 : 2));
}

void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
unsigned InstsEmitted = 0;

const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
assert(MJTI && "Can't lower jump-table dispatch without JTI");

const std::vector<MachineJumpTableEntry> &JTs = MJTI->getJumpTables();
assert(!JTs.empty() && "Invalid JT index for jump-table dispatch");

// Emit:
// mov x17, #<size of table> ; depending on table size, with MOVKs
// cmp x16, x17 ; or #imm if table size fits in 12-bit
// csel x16, x16, xzr, ls ; check for index overflow
//
// adrp x17, Ltable@PAGE ; materialize table address
// add x17, x17, Ltable@PAGEOFF
// ldrsw x16, [x17, x16, lsl #2] ; load table entry
//
// Lanchor:
// adr x17, Lanchor ; compute target address
// add x16, x17, x16
// br x16 ; branch to target

MachineOperand JTOp = MI.getOperand(0);

unsigned JTI = JTOp.getIndex();
assert(!AArch64FI->getJumpTableEntryPCRelSymbol(JTI) &&
"unsupported compressed jump table");

const uint64_t NumTableEntries = JTs[JTI].MBBs.size();

// cmp only supports a 12-bit immediate. If we need more, materialize the
// immediate, using x17 as a scratch register.
uint64_t MaxTableEntry = NumTableEntries - 1;
if (isUInt<12>(MaxTableEntry)) {
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXri)
.addReg(AArch64::XZR)
.addReg(AArch64::X16)
.addImm(MaxTableEntry)
.addImm(0));
++InstsEmitted;
Contributor:
Nit: you probably want to create a helper lambda that executes both EmitToStreamer and ++InstsEmitted; the two are always used together in AArch64AsmPrinter::LowerHardenedBRJumpTable. You can also use the lambda to get rid of *OutStreamer in each EmitToStreamer invocation (just set that argument inside the lambda, since it's always the same).
Feel free to ignore.

Member Author:
Yes, let's do that for the whole file someplace else.
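
A minimal sketch of the suggested helper, as an editor's illustration (hypothetical, not part of the patch); it assumes only the EmitToStreamer member and the local InstsEmitted counter already in scope above:

// Bundles the streamer call and the instruction count so no call site
// can forget one of the two.
auto EmitAndCount = [&](const MCInst &Inst) {
  EmitToStreamer(*OutStreamer, Inst);
  ++InstsEmitted;
};
// e.g. EmitAndCount(MCInstBuilder(AArch64::BR).addReg(AArch64::X16));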

} else {
EmitToStreamer(*OutStreamer,
MCInstBuilder(AArch64::MOVZXi)
.addReg(AArch64::X17)
.addImm(static_cast<uint16_t>(MaxTableEntry))
.addImm(0));
++InstsEmitted;
// It's sad that we have to manually materialize instructions, but we can't
// trivially reuse the main pseudo expansion logic.
// A MOVK sequence is easy enough to generate and handles the general case.
for (int Offset = 16; Offset < 64; Offset += 16) {
if ((MaxTableEntry >> Offset) == 0)
break;
EmitToStreamer(*OutStreamer,
MCInstBuilder(AArch64::MOVKXi)
.addReg(AArch64::X17)
.addReg(AArch64::X17)
.addImm(static_cast<uint16_t>(MaxTableEntry >> Offset))
.addImm(Offset));
++InstsEmitted;
}
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addImm(0));
++InstsEmitted;
}

// This picks entry #0 on failure.
// We might want to trap instead.
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::CSELXr)
.addReg(AArch64::X16)
.addReg(AArch64::X16)
.addReg(AArch64::XZR)
.addImm(AArch64CC::LS));
++InstsEmitted;

// Prepare the @PAGE/@PAGEOFF high/low address operands.
MachineOperand JTMOHi(JTOp), JTMOLo(JTOp);
MCOperand JTMCHi, JTMCLo;

JTMOHi.setTargetFlags(AArch64II::MO_PAGE);
JTMOLo.setTargetFlags(AArch64II::MO_PAGEOFF | AArch64II::MO_NC);

MCInstLowering.lowerOperand(JTMOHi, JTMCHi);
MCInstLowering.lowerOperand(JTMOLo, JTMCLo);

EmitToStreamer(
*OutStreamer,
MCInstBuilder(AArch64::ADRP).addReg(AArch64::X17).addOperand(JTMCHi));
++InstsEmitted;

EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
.addReg(AArch64::X17)
.addReg(AArch64::X17)
.addOperand(JTMCLo)
.addImm(0));
++InstsEmitted;

EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addReg(AArch64::X16)
.addImm(0)
.addImm(1));
++InstsEmitted;

MCSymbol *AdrLabel = MF->getContext().createTempSymbol();
auto *AdrLabelE = MCSymbolRefExpr::create(AdrLabel, MF->getContext());
AArch64FI->setJumpTableEntryInfo(JTI, 4, AdrLabel);

OutStreamer->emitLabel(AdrLabel);
EmitToStreamer(
*OutStreamer,
MCInstBuilder(AArch64::ADR).addReg(AArch64::X17).addExpr(AdrLabelE));
++InstsEmitted;

EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addReg(AArch64::X16)
.addImm(0));
++InstsEmitted;

EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::BR).addReg(AArch64::X16));
++InstsEmitted;

assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
}

void AArch64AsmPrinter::LowerMOPS(llvm::MCStreamer &OutStreamer,
const llvm::MachineInstr &MI) {
unsigned Opcode = MI.getOpcode();
@@ -2177,6 +2311,10 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
LowerJumpTableDest(*OutStreamer, *MI);
return;

case AArch64::BR_JumpTable:
LowerHardenedBRJumpTable(*MI);
return;

case AArch64::FMOVH0:
case AArch64::FMOVS0:
case AArch64::FMOVD0:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+15, -0)

@@ -10678,6 +10678,21 @@ SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
AFI->setJumpTableEntryInfo(JTI, 4, nullptr);

// With jump-table-hardening, we only expand the full jump table dispatch
// sequence later, to guarantee the integrity of the intermediate values.
if (DAG.getMachineFunction().getFunction().hasFnAttribute(
"jump-table-hardening") ||
Subtarget->getTargetTriple().isArm64e()) {
Contributor:
Isn't the check against arm64e duplicating the check against the function attribute? I'm not aware of the full arm64e spec, but I suppose the frontend will set the "jump-table-hardening" attribute when compiling for arm64e, so there is probably no need for this check here. The same applies to GlobalISel.
Feel free to ignore; I'm OK with such Apple-specific stuff if it's considered essential.

assert(Subtarget->isTargetMachO() &&
       "hardened jump-table not yet supported on non-macho");

Contributor:
Could you please explain what is MachO-specific here? It looks like at least Linux+ELF could also be supported without changing the code logic, just by deleting this assertion. ELF tests will require a slight change since the assembly syntax and label names are a bit different, so I'm happy to add them myself later if that's out of scope for the patch.
If there is definitely something MachO-specific, shouldn't it also be checked in GlobalISel? Alternatively, the check could be moved to the pseudo expansion.
SDValue X16Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::X16,
Entry, SDValue());
SDNode *B = DAG.getMachineNode(AArch64::BR_JumpTable, DL, MVT::Other,
DAG.getTargetJumpTable(JTI, MVT::i32),
X16Copy.getValue(0), X16Copy.getValue(1));
return SDValue(B, 0);
}

SDNode *Dest =
DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
Entry, DAG.getTargetJumpTable(JTI, MVT::i32));
llvm/lib/Target/AArch64/AArch64InstrInfo.td (+26, -0)

@@ -1143,6 +1143,32 @@ def JumpTableDest8 : Pseudo<(outs GPR64:$dst, GPR64sp:$scratch),
Sched<[]>;
}

// A hardened but more expensive version of jump-table dispatch.
// This combines the target address computation (otherwise done using the
// JumpTableDest pseudos above) with the branch itself (otherwise done using
// a plain BR) in a single non-attackable sequence.
//
// We take the final entry index as an operand to allow isel freedom. This does
// mean that the index can be attacker-controlled. To address that, we also do
// limited checking of the offset, mainly ensuring it still points within the
// jump-table array. When it doesn't, this branches to the first entry.
//
// This is intended for use in conjunction with ptrauth for other code pointers,
// to avoid signing jump-table entries and turning them into pointers.
//
// Entry index is passed in x16. Clobbers x16/x17/nzcv.
let isNotDuplicable = 1 in
def BR_JumpTable : Pseudo<(outs), (ins i32imm:$jti), []>, Sched<[]> {
let isBranch = 1;
let isTerminator = 1;
let isIndirectBranch = 1;
let isBarrier = 1;
let isNotDuplicable = 1;
let Defs = [X16,X17,NZCV];
let Uses = [X16];
let Size = 44; // 28 fixed + 16 variable, for table size materialization
}
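
Editor's note on the Size accounting: by a straightforward count of the expansion in AArch64AsmPrinter.cpp above, the seven unconditionally emitted instructions (csel, adrp, add, ldrsw, adr, add, br) make up the 28 fixed bytes, while the cmp plus the movz/movk materialization of the table size need at most four more instructions (16 bytes) for any table with fewer than 2^48 entries.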

// Space-consuming pseudo to aid testing of placement and reachability
// algorithms. Immediate operand is the number of bytes this "instruction"
// occupies; register operands can be used to enforce dependency and constrain
llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+13, -1)

@@ -3597,10 +3597,22 @@ bool AArch64InstructionSelector::selectBrJT(MachineInstr &I,
unsigned JTI = I.getOperand(1).getIndex();
Register Index = I.getOperand(2).getReg();

MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr);
if (MF->getFunction().hasFnAttribute("jump-table-hardening") ||
STI.getTargetTriple().isArm64e()) {
if (TM.getCodeModel() != CodeModel::Small)
report_fatal_error("Unsupported code-model for hardened jump-table");

Contributor:
I might be missing something, but why don't we support the large code model here when we do support it in AArch64TargetLowering::LowerBR_JT? The offset computation is done in the pseudo expansion and is common to both ISels, and full 64-bit offsets are supported via movk, so we could probably support the large code model here as well.
If this behavior is expected, the error message above looks untested. I'm OK with adding tests in a follow-up patch later if that's more convenient.

Member Author:
Yeah, SDAG and GISel should both have the same support. The situation with the large code model is more complicated.
Currently, the AsmPrinter expansion materializes the jump-table address using adrp/add. On MachO, that's okay(-ish) for the large code model as well, but that's not what's done on ELF (per AArch64TargetLowering::LowerJumpTable).
I don't see an explicit rationale in the code or history, but I could justify it by ELF always using a separate section for the jump table itself, vs. MachO almost always having it inline with the function. In the latter case, I think we can usually get away without the 4x movk large-code-model materialization sequence, so we end up with the same desired adrp/add codegen for the large and small code models on MachO, but not on ELF.
Of course we can teach the expansion about non-adrp/add materialization variants, but let's add that separately.

Contributor:
Alternatively, I think it would be OK to explicitly not support the large code model at all (in both SDAG and GISel) in the context of this PR, and add such support separately, so the main logic for the most common use case gets merged first.
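
For context on the materialization variants mentioned above, here is a rough editor's sketch (hypothetical, not from the patch) of what a movz/movk table-address materialization in the AsmPrinter expansion might look like, reusing JTOp and MCInstLowering from LowerHardenedBRJumpTable and assuming the AArch64II::MO_G0..MO_G3 flags lower to the usual :abs_g*: relocations, as they do for global addresses:

// Hypothetical ELF large-code-model variant: build x17 from the four
// 16-bit quarters of the table address instead of the adrp/add pair.
MachineOperand JTMO_G3(JTOp), JTMO_G2(JTOp);
JTMO_G3.setTargetFlags(AArch64II::MO_G3);
JTMO_G2.setTargetFlags(AArch64II::MO_G2 | AArch64II::MO_NC);
MCOperand G3, G2;
MCInstLowering.lowerOperand(JTMO_G3, G3);
MCInstLowering.lowerOperand(JTMO_G2, G2);
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
                                 .addReg(AArch64::X17)
                                 .addOperand(G3)
                                 .addImm(48));
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
                                 .addReg(AArch64::X17)
                                 .addReg(AArch64::X17)
                                 .addOperand(G2)
                                 .addImm(32));
// ...and likewise for MO_G1 (shift 16) and MO_G0 (shift 0).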

MIB.buildCopy({AArch64::X16}, I.getOperand(2).getReg());
MIB.buildInstr(AArch64::BR_JumpTable)
.addJumpTableIndex(I.getOperand(1).getIndex());
I.eraseFromParent();
return true;
}

Register TargetReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
Register ScratchReg = MRI.createVirtualRegister(&AArch64::GPR64spRegClass);

MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr);
auto JumpTableInst = MIB.buildInstr(AArch64::JumpTableDest32,
{TargetReg, ScratchReg}, {JTAddr, Index})
.addJumpTableIndex(JTI);
llvm/test/CodeGen/AArch64/hardened-jump-table-br.ll (+53, -0)

@@ -0,0 +1,53 @@
; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 | FileCheck %s
; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 -code-model=large | FileCheck %s
Contributor:
Nit: it would be nice to test large offsets as well, but I'm not sure how we can force such offsets artificially. So I'm OK with the current test; large-offset tests are a "nice to have" if you know a good way to implement them.

; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 -global-isel -global-isel-abort=1 | FileCheck %s

; CHECK-LABEL: test_jumptable:
; CHECK: mov w[[INDEX:[0-9]+]], w0
; CHECK: cmp x[[INDEX]], #5
; CHECK: csel [[INDEX2:x[0-9]+]], x[[INDEX]], xzr, ls
; CHECK-NEXT: adrp [[JTPAGE:x[0-9]+]], LJTI0_0@PAGE
; CHECK-NEXT: add x[[JT:[0-9]+]], [[JTPAGE]], LJTI0_0@PAGEOFF
; CHECK-NEXT: ldrsw [[OFFSET:x[0-9]+]], [x[[JT]], [[INDEX2]], lsl #2]
; CHECK-NEXT: Ltmp0:
; CHECK-NEXT: adr [[TABLE:x[0-9]+]], Ltmp0
; CHECK-NEXT: add [[DEST:x[0-9]+]], [[TABLE]], [[OFFSET]]
; CHECK-NEXT: br [[DEST]]

define i32 @test_jumptable(i32 %in) "jump-table-hardening" {

switch i32 %in, label %def [
i32 0, label %lbl1
i32 1, label %lbl2
i32 2, label %lbl3
i32 4, label %lbl4
i32 5, label %lbl5
]

def:
ret i32 0

lbl1:
ret i32 1

lbl2:
ret i32 2

lbl3:
ret i32 4

lbl4:
ret i32 8

lbl5:
ret i32 10

}

; CHECK: LJTI0_0:
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0