diff --git a/sway-ir/src/block.rs b/sway-ir/src/block.rs index f71bad17445..d69b2759e99 100644 --- a/sway-ir/src/block.rs +++ b/sway-ir/src/block.rs @@ -455,6 +455,12 @@ impl Block { context.blocks[self.0].instructions = insts; } + pub fn insert_instructions_after(&self, context: &mut Context, value: Value, insts: impl IntoIterator) { + let block_ins = &mut context.blocks[self.0].instructions; + let pos = block_ins.iter().position(|x| x == &value).unwrap(); + block_ins.splice(pos+1..pos+1, insts); + } + /// Replace an instruction in this block with another. Will return a ValueNotFound on error. /// Any use of the old instruction value will also be replaced by the new value throughout the /// owning function if `replace_uses` is set. diff --git a/sway-ir/src/optimize.rs b/sway-ir/src/optimize.rs index 299c51c0332..5db7ca21fb1 100644 --- a/sway-ir/src/optimize.rs +++ b/sway-ir/src/optimize.rs @@ -41,6 +41,8 @@ pub mod sroa; pub use sroa::*; pub mod fn_dedup; pub use fn_dedup::*; +pub mod branchless; +pub use branchless::*; mod target_fuel; diff --git a/sway-ir/src/optimize/branchless.rs b/sway-ir/src/optimize/branchless.rs new file mode 100644 index 00000000000..84402f3eefc --- /dev/null +++ b/sway-ir/src/optimize/branchless.rs @@ -0,0 +1,150 @@ +use sway_features::ExperimentalFeatures; + +use crate::{AnalysisResults, Block, BranchToWithArgs, Constant, ConstantContent, Context, Function, InstOp, IrError, Pass, PassMutability, ScopedPass, Value, ValueDatum}; + + +pub const BRANCHLESS_NAME: &str = "branchless"; + +pub fn create_branchless() -> Pass { + Pass { + name: BRANCHLESS_NAME, + descr: "Branchless", + deps: vec![], + runner: ScopedPass::FunctionPass(PassMutability::Transform(branchless)), + } +} + +// check if a block simple calls another block with a u64 +fn is_block_simple_integer<'a>(context: &'a Context, branch: &BranchToWithArgs) -> Option<(&'a Constant, &'a Block)> { + let b = &context.blocks[branch.block.0]; + + if b.instructions.len() > 1 { + return None; + } + + let v = &b.instructions[0]; + let v = &context.values[v.0]; + match &v.value { + crate::ValueDatum::Instruction(i) => match &i.op { + InstOp::Branch(branch) => { + if branch.args.len() != 1 { + return None; + } + + let arg0 = &context.values[branch.args[0].0]; + match &arg0.value { + crate::ValueDatum::Constant(constant) => Some((constant, &branch.block)), + _ => None, + } + }, + _ => None, + }, + _ => None, + } +} + +fn find_cbr(context: &mut Context, function: Function) -> Option<(Block, Value, Block, Value, Constant, Constant)> { + for (block, value) in function.instruction_iter(context) { + match &context.values[value.0].value { + ValueDatum::Argument(_) => {}, + ValueDatum::Constant(_) => {}, + ValueDatum::Instruction(instruction) => { + match &instruction.op { + InstOp::ConditionalBranch { cond_value, true_block, false_block } => { + let target_block_true = is_block_simple_integer(context, &true_block); + let target_block_false = is_block_simple_integer(context, &false_block); + + // both branches call the same block + match (target_block_true, target_block_false) { + (Some((constant_true, target_block_true)), Some((constant_false, target_block_false))) if target_block_true == target_block_false => { + return Some((block, value, *target_block_true, *cond_value, *constant_true, *constant_false)); + }, + _ => {}, + } + }, + _ => {}, + } + }, + }; + } + + None +} + +pub fn branchless( + context: &mut Context, + _: &AnalysisResults, + function: Function, +) -> Result { + let mut modified = false; + return Ok(false); + + loop { + if let Some((block, instr_val, target_block, cond_value, constant_true, constant_false)) = find_cbr(context, function) { + block.remove_instruction(context, instr_val); + + let one = ConstantContent::new_uint(context, 64, 1); + let one = Constant::unique(context, one); + let one = Value::new_constant(context, one); + let a = Value::new_constant(context, constant_true); + let b = Value::new_constant(context, constant_false); + + // c is a boolean (1 or 0) + // Can we use predication? + // x = c * a + (1 − c) * b + let c_times_a = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Mul, arg1: cond_value, arg2: a }); + let one_minus_c = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Sub, arg1: one, arg2: cond_value }); + let one_minus_c_times_b = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Mul, arg1: one_minus_c, arg2: b }); + let x = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Add, arg1: c_times_a, arg2: one_minus_c_times_b }); + + block.insert_instructions_after(context, cond_value, [c_times_a, one_minus_c, one_minus_c_times_b, x]); + + let call_target_block = Value::new_instruction(context, block, InstOp::Branch(BranchToWithArgs { + block: target_block, + args: vec![x] + })); + + let block = &mut context.blocks[block.0]; + block.instructions.push(call_target_block); + + modified = true; + } else { + break; + } + } + + eprintln!("{}", context.to_string()); + + Ok(modified) +} + +#[cfg(test)] +mod tests { + use crate::tests::assert_optimization; + use super::BRANCHLESS_NAME; + + #[test] + fn branchless_optimized() { + let before_optimization = format!( + " + fn main(baba !68: u64) -> u64, !71 {{ + entry(baba: u64): + v0 = const u64 0, !72 + cbr v0, block0(), block1(), !73 + + block0(): + v2 = const u64 1, !76 + br block2(v2) + + block1(): + v3 = const u64 2, !77 + br block2(v3) + + block2(v4: u64): + ret u64 v4 + }} +", + ); + assert_optimization(&[BRANCHLESS_NAME], &before_optimization, Some(["const u64 1, !76"])); + } +} \ No newline at end of file diff --git a/sway-ir/src/pass_manager.rs b/sway-ir/src/pass_manager.rs index 0bcdd218f58..0029710f602 100644 --- a/sway-ir/src/pass_manager.rs +++ b/sway-ir/src/pass_manager.rs @@ -1,15 +1,5 @@ use crate::{ - create_arg_demotion_pass, create_ccp_pass, create_const_demotion_pass, - create_const_folding_pass, create_cse_pass, create_dce_pass, create_dom_fronts_pass, - create_dominators_pass, create_escaped_symbols_pass, create_fn_dedup_debug_profile_pass, - create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_globals_dce_pass, - create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass, - create_module_printer_pass, create_module_verifier_pass, create_postorder_pass, - create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function, - IrError, Module, ARG_DEMOTION_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, - CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, - GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, - SIMPLIFY_CFG_NAME, SROA_NAME, + create_arg_demotion_pass, create_branchless, create_ccp_pass, create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass, create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass, create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass, create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function, IrError, Module, ARG_DEMOTION_NAME, BRANCHLESS_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME }; use downcast_rs::{impl_downcast, Downcast}; use rustc_hash::FxHashMap; @@ -164,7 +154,8 @@ pub struct PassManager { } impl PassManager { - pub const OPTIMIZATION_PASSES: [&'static str; 14] = [ + pub const OPTIMIZATION_PASSES: [&'static str; 15] = [ + BRANCHLESS_NAME, FN_INLINE_NAME, SIMPLIFY_CFG_NAME, SROA_NAME, @@ -395,6 +386,7 @@ pub fn register_known_passes(pm: &mut PassManager) { pm.register(create_fn_dedup_debug_profile_pass()); pm.register(create_mem2reg_pass()); pm.register(create_sroa_pass()); + pm.register(create_branchless()); pm.register(create_fn_inline_pass()); pm.register(create_const_folding_pass()); pm.register(create_ccp_pass()); @@ -415,6 +407,7 @@ pub fn create_o1_pass_group() -> PassGroup { // Configure to run our passes. o1.append_pass(MEM2REG_NAME); o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME); + o1.append_pass(BRANCHLESS_NAME); o1.append_pass(FN_INLINE_NAME); o1.append_pass(SIMPLIFY_CFG_NAME); o1.append_pass(GLOBALS_DCE_NAME); diff --git a/test/src/e2e_vm_tests/mod.rs b/test/src/e2e_vm_tests/mod.rs index b9797513678..e7235fa1de2 100644 --- a/test/src/e2e_vm_tests/mod.rs +++ b/test/src/e2e_vm_tests/mod.rs @@ -278,6 +278,11 @@ fn print_receipts(output: &mut String, receipts: &[Receipt]) { } } +struct RunResult { + size: Option, + gas: Option, +} + impl TestContext { async fn deploy_contract( &self, @@ -306,7 +311,7 @@ impl TestContext { }) } - async fn run(&self, test: TestDescription, output: &mut String, verbose: bool) -> Result<()> { + async fn run(&self, test: TestDescription, output: &mut String, verbose: bool) -> Result { let TestDescription { name, suffix, @@ -342,6 +347,11 @@ impl TestContext { expected_result }; + let mut r = RunResult { + size: None, + gas: None, + }; + match category { TestCategory::Runs => { let expected_result = expected_result.expect("No expected result found. This is likely because test.toml is missing either an \"expected_result_new_encoding\" or \"expected_result\" entry"); @@ -358,6 +368,7 @@ impl TestContext { for p in packages { let bytecode_len = p.bytecode.bytes.len(); + r.size = Some(bytecode_len as u64); let configurables = match &p.program_abi { sway_core::asm_generation::ProgramABI::Fuel(abi) => { @@ -408,6 +419,13 @@ impl TestContext { harness::VMExecutionResult::Fuel(state, receipts, ecal) => { print_receipts(output, &receipts); + if let Some(gas_used) = receipts.iter().filter_map(|x| match x { + Receipt::ScriptResult { gas_used, .. } => Some(*gas_used), + _ => None + }).last() { + r.gas = Some(gas_used); + } + use std::fmt::Write; let _ = writeln!(output, " {}", "Captured Output".green().bold()); for captured in ecal.captured.iter() { @@ -483,7 +501,7 @@ impl TestContext { output.push_str(&out); result?; } - Ok(()) + Ok(r) } } @@ -543,7 +561,7 @@ impl TestContext { output.push_str(&out); } } - Ok(()) + Ok(r) } TestCategory::FailsToCompile => { @@ -560,7 +578,7 @@ impl TestContext { Err(anyhow::Error::msg("Test compiles but is expected to fail")) } else { check_file_checker(checker, &name, output)?; - Ok(()) + Ok(r) } } @@ -654,7 +672,7 @@ impl TestContext { _ => {} }; - Ok(()) + Ok(r) } TestCategory::UnitTestsPass => { @@ -729,6 +747,8 @@ impl TestContext { decoded_logs, expected_decoded_test_logs ); } + + r }) } @@ -833,18 +853,29 @@ pub async fn run(filter_config: &FilterConfig, run_config: &RunConfig) -> Result context.run(test, &mut output, run_config.verbose).await }; - if let Err(err) = result { - println!(" {}", "failed".red().bold()); - println!("{}", textwrap::indent(err.to_string().as_str(), " ")); - println!("{}", textwrap::indent(&output, " ")); - number_of_tests_failed += 1; - failed_tests.push(name); - } else { - println!(" {}", "ok".green().bold()); + match result { + Err(err) => { + println!(" {}", "failed".red().bold()); + println!("{}", textwrap::indent(err.to_string().as_str(), " ")); + println!("{}", textwrap::indent(&output, " ")); + number_of_tests_failed += 1; + failed_tests.push(name); + } + Ok(r) => { + if let Some(size) = r.size { + print!(" {} bytes ", size); + } - // If verbosity is requested then print it out. - if run_config.verbose && !output.is_empty() { - println!("{}", textwrap::indent(&output, " ")); + if let Some(gas) = r.gas { + print!(" {} gas used ", gas); + } + + println!(" {}", "ok".green().bold()); + + // If verbosity is requested then print it out. + if run_config.verbose && !output.is_empty() { + println!("{}", textwrap::indent(&output, " ")); + } } } diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/language/main_args/main_args_one_u64/src/main.sw b/test/src/e2e_vm_tests/test_programs/should_pass/language/main_args/main_args_one_u64/src/main.sw index 01929765217..4074e6f99c9 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/language/main_args/main_args_one_u64/src/main.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/language/main_args/main_args_one_u64/src/main.sw @@ -1,5 +1,14 @@ script; +#[inline(never)] +fn f(baba:u64) -> u64 { + if baba == 0 { + 1 + } else { + 2 + } +} + fn main(baba: u64) -> u64 { - baba + 1 + f(baba) }