From 1a2b17392101b4c52c58be0c500167e7a543ad0a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 3 Nov 2025 13:44:15 -0800 Subject: [PATCH 1/5] Change for loops from min/extent to min/max --- src/AddImageChecks.cpp | 2 +- src/AsyncProducers.cpp | 12 +++---- src/BoundConstantExtentLoops.cpp | 12 +++---- src/BoundSmallAllocations.cpp | 2 +- src/Bounds.cpp | 8 ++--- src/BoundsInference.cpp | 4 +-- src/CanonicalizeGPUVars.cpp | 8 ++--- src/Closure.cpp | 2 +- src/CodeGen_C.cpp | 5 ++- src/CodeGen_D3D12Compute_Dev.cpp | 3 +- src/CodeGen_Hexagon.cpp | 4 +-- src/CodeGen_LLVM.cpp | 10 +++--- src/CodeGen_Vulkan_Dev.cpp | 17 +++++----- src/CodeGen_WebGPU_Dev.cpp | 4 +-- src/Deserialization.cpp | 4 +-- src/EarlyFree.cpp | 2 +- src/FuseGPUThreadLoops.cpp | 42 ++++++++++++------------ src/HexagonOffload.cpp | 5 +-- src/IR.cpp | 8 ++--- src/IR.h | 30 +++++++++-------- src/IREquality.cpp | 2 +- src/IRMutator.cpp | 6 ++-- src/IRPrinter.cpp | 2 +- src/IRVisitor.cpp | 4 +-- src/LICM.cpp | 6 ++-- src/LoopCarry.cpp | 6 ++-- src/LowerParallelTasks.cpp | 10 +++--- src/LowerWarpShuffles.cpp | 26 ++++++++------- src/OffloadGPULoops.cpp | 4 +-- src/PartitionLoops.cpp | 30 ++++++++--------- src/Prefetch.cpp | 6 ++-- src/PrintLoopNest.cpp | 13 ++++---- src/Profiling.cpp | 2 +- src/RebaseLoopsToZero.cpp | 2 +- src/RemoveUndef.cpp | 8 ++--- src/ScheduleFunctions.cpp | 38 +++++++++++----------- src/SelectGPUAPI.cpp | 2 +- src/Serialization.cpp | 4 +-- src/Simplify_Stmts.cpp | 47 +++++++++++++++------------ src/SkipStages.cpp | 2 +- src/SlidingWindow.cpp | 28 ++++++++-------- src/StageStridedLoads.cpp | 2 +- src/StmtToHTML.cpp | 6 ++-- src/StorageFlattening.cpp | 6 ++-- src/StorageFolding.cpp | 8 ++--- src/Substitute.cpp | 6 ++-- src/TrimNoOps.cpp | 29 +++++++---------- src/UniquifyVariableNames.cpp | 8 ++--- src/UnrollLoops.cpp | 3 +- src/VectorizeLoops.cpp | 31 +++++++++++------- src/halide_ir.fbs | 6 ++-- test/correctness/fuse_gpu_threads.cpp | 4 +-- test/correctness/out_constraint.cpp | 4 
+-- test/correctness/simplify.cpp | 8 ++--- 54 files changed, 279 insertions(+), 274 deletions(-) diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index b24626ad66a1..7114ad360135 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -36,7 +36,7 @@ class FindBuffers : public IRGraphVisitor { void visit(const For *op) override { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); bool old = in_device_loop; if (op->device_api != DeviceAPI::None && op->device_api != DeviceAPI::Host) { diff --git a/src/AsyncProducers.cpp b/src/AsyncProducers.cpp index bb5e4279d367..536327ebba3f 100644 --- a/src/AsyncProducers.cpp +++ b/src/AsyncProducers.cpp @@ -35,7 +35,7 @@ class NoOpCollapsingMutator : public IRMutator { if (is_no_op(body)) { return body; } else { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } @@ -752,11 +752,10 @@ class InjectRingBuffering : public IRMutator { struct Loop { std::string name; - Expr min; Expr extent; - Loop(std::string n, Expr m, Expr e) - : name(std::move(n)), min(std::move(m)), extent(std::move(e)) { + Loop(std::string n, Expr e) + : name(std::move(n)), extent(std::move(e)) { } }; @@ -778,8 +777,7 @@ class InjectRingBuffering : public IRMutator { int loop_index = hoist_storage_loop_index[op->name] + 1; Expr current_index = Variable::make(Int(32), loops[loop_index].name); while (++loop_index < (int)loops.size()) { - current_index = current_index * - (loops[loop_index].extent - loops[loop_index].min) + + current_index = current_index * loops[loop_index].extent + Variable::make(Int(32), loops[loop_index].name); } current_index = current_index % f.schedule().ring_buffer(); @@ -817,7 +815,7 @@ class InjectRingBuffering : public IRMutator { } Stmt visit(const For *op) override { - loops.emplace_back(op->name, op->min, op->extent); + 
loops.emplace_back(op->name, op->extent()); Stmt mutated = IRMutator::visit(op); loops.pop_back(); return mutated; diff --git a/src/BoundConstantExtentLoops.cpp b/src/BoundConstantExtentLoops.cpp index fc2fea5b9d41..1312508f8ddb 100644 --- a/src/BoundConstantExtentLoops.cpp +++ b/src/BoundConstantExtentLoops.cpp @@ -46,7 +46,8 @@ class BoundLoops : public IRMutator { } Stmt visit(const For *op) override { - if (is_const(op->extent)) { + Expr extent = simplify(op->extent()); + if (is_const(extent)) { // Nothing needs to be done return IRMutator::visit(op); } @@ -54,7 +55,6 @@ class BoundLoops : public IRMutator { if (op->for_type == ForType::Unrolled || op->for_type == ForType::Vectorized) { // Give it one last chance to simplify to an int - Expr extent = simplify(op->extent); Stmt body = op->body; const IntImm *e = extent.as(); @@ -82,8 +82,8 @@ class BoundLoops : public IRMutator { if (extent_upper.defined()) { e = extent_upper.as(); body = - IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) < - op->min + op->extent), + IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) <= + op->max), body); } } @@ -93,7 +93,7 @@ class BoundLoops : public IRMutator { // to a serial loop. user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. 
Did you mean to do this?\n"; body = mutate(body); - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, ForType::Serial, op->partition_policy, op->device_api, std::move(body)); } @@ -103,7 +103,7 @@ class BoundLoops : public IRMutator { << "Loop over " << op->name << " has extent " << extent << ".\n"; body = mutate(body); - return For::make(op->name, op->min, e, + return For::make(op->name, op->min, (op->min + e) - 1, op->for_type, op->partition_policy, op->device_api, std::move(body)); } else { return IRMutator::visit(op); diff --git a/src/BoundSmallAllocations.cpp b/src/BoundSmallAllocations.cpp index 80f58889448c..f3347c0f47fd 100644 --- a/src/BoundSmallAllocations.cpp +++ b/src/BoundSmallAllocations.cpp @@ -59,7 +59,7 @@ class BoundSmallAllocations : public IRMutator { Stmt visit(const For *op) override { Interval min_bounds = find_constant_bounds(op->min, scope); - Interval max_bounds = find_constant_bounds(op->min + op->extent - 1, scope); + Interval max_bounds = find_constant_bounds(op->max, scope); Interval b = Interval::make_union(min_bounds, max_bounds); b.min = simplify(b.min); b.max = simplify(b.max); diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 670ccf11d177..8079d7c980e4 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -2964,7 +2964,7 @@ class BoxesTouched : public IRGraphVisitor { TRACK_BOXES_TOUCHED_INFO("var:", op->name); if (consider_calls) { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); } Expr min_val, max_val; @@ -2977,9 +2977,7 @@ class BoxesTouched : public IRGraphVisitor { if (const Interval *in = scope.find(op->name + ".loop_max")) { max_val = in->max; } else { - max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max; - max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max; - max_val -= 1; + max_val = bounds_of_expr_in_scope(op->max, scope, func_bounds).max; } push_var(op->name); @@ -3819,7 +3817,7 @@ void bounds_test() { Buffer in(10); 
in.set_name("input"); - Stmt loop = For::make("x", 3, 10, ForType::Serial, Partition::Auto, DeviceAPI::Host, + Stmt loop = For::make("x", 3, 12, ForType::Serial, Partition::Auto, DeviceAPI::Host, Provide::make("output", {Add::make(Call::make(in, input_site_1), Call::make(in, input_site_2))}, diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 9ba1f1af2019..6a0ffc3f8d19 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -1308,7 +1308,7 @@ class BoundsInference : public IRMutator { } } - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } Scope<> let_vars_in_scope; @@ -1392,7 +1392,7 @@ Stmt bounds_inference(Stmt s, s = Block::make(Evaluate::make(marker), s); // Add a synthetic outermost loop to act as 'root'. - s = For::make("", 0, 1, ForType::Serial, Partition::Never, DeviceAPI::None, s); + s = For::make("", 0, 0, ForType::Serial, Partition::Never, DeviceAPI::None, s); s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups, outputs, func_bounds, target) diff --git a/src/CanonicalizeGPUVars.cpp b/src/CanonicalizeGPUVars.cpp index 4e70af965138..49f1e37d2945 100644 --- a/src/CanonicalizeGPUVars.cpp +++ b/src/CanonicalizeGPUVars.cpp @@ -105,7 +105,7 @@ class CanonicalizeGPUVars : public IRMutator { Stmt visit(const For *op) override { std::string name = op->name; Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt body = mutate(op->body); if ((op->for_type == ForType::GPUBlock) || @@ -130,18 +130,18 @@ class CanonicalizeGPUVars : public IRMutator { gpu_vars.emplace(op->name, name); Expr new_var = Variable::make(Int(32), name); min = substitute(op->name, new_var, min); - extent = substitute(op->name, new_var, extent); + max = substitute(op->name, new_var, max); body = substitute(op->name, new_var, body); } } if 
((name == op->name) && min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body)) { return op; } else { - return For::make(name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(name, min, max, op->for_type, op->partition_policy, op->device_api, body); } } diff --git a/src/Closure.cpp b/src/Closure.cpp index 5c5125a9b291..de1564526462 100644 --- a/src/Closure.cpp +++ b/src/Closure.cpp @@ -38,7 +38,7 @@ void Closure::visit(const LetStmt *op) { void Closure::visit(const For *op) { ScopedBinding<> p(ignore, op->name); op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); op->body.accept(this); } diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 44756204745a..a5dc3298be63 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2223,7 +2223,7 @@ void CodeGen_C::visit(const Atomic *op) { void CodeGen_C::visit(const For *op) { string id_min = print_expr(op->min); - string id_extent = print_expr(op->extent); + string id_max = print_expr(op->max); if (op->for_type == ForType::Parallel) { stream << get_indent() << "#pragma omp parallel for\n"; @@ -2237,8 +2237,7 @@ void CodeGen_C::visit(const For *op) { << " = " << id_min << "; " << print_name(op->name) - << " < " << id_min - << " + " << id_extent + << " <= " << id_max << "; " << print_name(op->name) << "++)\n"; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index ad4f6451f918..4ce641e680ad 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -1179,7 +1179,8 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, // generation time, emit code such that it can be patched some point // later when calling D3DCompile() / halide_d3d12compute_run() numthreads[index] = 0; // <-- 0 indicates 'undetermined' - const IntImm *int_limit = loop->extent.as(); + Expr extent = simplify(loop->extent()); + const IntImm *int_limit = 
extent.as(); if (nullptr != int_limit) { numthreads[index] = int_limit->value; user_assert(numthreads[index] > 0) << "For D3D12Compute, 'numthreads[" << index << "]' values must be greater than zero.\n"; diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 8515b020ea64..352f9c24afdb 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -363,7 +363,7 @@ class InjectHVXLocks : public IRMutator { if (uses_hvx) { body = acquire_hvx_context(body, target); body = substitute("uses_hvx", true, body); - Stmt new_for = For::make(op->name, op->min, op->extent, op->for_type, + Stmt new_for = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); Stmt prolog = IfThenElse::make(uses_hvx_var, call_halide_qurt_hvx_unlock()); @@ -408,7 +408,7 @@ class InjectHVXLocks : public IRMutator { // vector code // halide_qurt_unlock // } - s = For::make(op->name, op->min, op->extent, op->for_type, + s = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 83de58c83a07..4506172d5237 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3685,7 +3685,7 @@ void CodeGen_LLVM::visit(const ProducerConsumer *op) { void CodeGen_LLVM::visit(const For *op) { Value *min = codegen(op->min); - Value *extent = codegen(op->extent); + Value *max = codegen(op->max); const Acquire *acquire = op->body.as(); // TODO(zvookin): remove this after validating it doesn't happen @@ -3696,8 +3696,6 @@ void CodeGen_LLVM::visit(const For *op) { if (op->for_type == ForType::Serial) { - Value *max = builder->CreateNSWAdd(min, extent); - BasicBlock *preheader_bb = builder->GetInsertBlock(); // Make a new basic block for the loop @@ -3708,8 +3706,8 @@ void CodeGen_LLVM::visit(const For *op) { BasicBlock *after_bb = BasicBlock::Create( *context, std::to_string(for_loop_id) + std::string("_end_for_") + op->name, function); - // If min < 
max, fall through to the loop bb - Value *enter_condition = builder->CreateICmpSLT(min, max); + // If min <= max, fall through to the loop bb + Value *enter_condition = builder->CreateICmpSLE(min, max); builder->CreateCondBr(enter_condition, loop_bb, after_bb, very_likely_branch); builder->SetInsertPoint(loop_bb); @@ -3730,7 +3728,7 @@ void CodeGen_LLVM::visit(const For *op) { phi->addIncoming(next_var, builder->GetInsertBlock()); // Maybe exit the loop - Value *end_condition = builder->CreateICmpNE(next_var, max); + Value *end_condition = builder->CreateICmpSLE(next_var, max); builder->CreateCondBr(end_condition, loop_bb, after_bb); builder->SetInsertPoint(after_bb); diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 26aff3cefa13..6bd33d6bc6b8 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -418,7 +418,8 @@ struct FindWorkGroupSize : public IRVisitor { // Save & validate the workgroup size int index = thread_loop_workgroup_index(loop->name); if (index >= 0) { - const IntImm *literal = loop->extent.as(); + Expr extent = simplify(loop->extent()); + const IntImm *literal = extent.as(); if (literal != nullptr) { uint32_t new_wg_size = literal->value; user_assert(workgroup_size[index] == 0 || workgroup_size[index] == new_wg_size) @@ -1683,7 +1684,7 @@ std::pair simt_intrinsic(const std::string &name) { } // anonymous namespace void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { - debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(For): name=" << op->name << " min=" << op->min << " extent=" << op->extent << "\n"; + debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(For): name=" << op->name << " min=" << op->min << " max=" << op->max << "\n"; if (is_gpu(op->for_type)) { // This should always be true at this point in codegen @@ -1710,24 +1711,22 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { } } else { - debug(2) << " (serial for loop): min=" << op->min << " extent=" << op->extent << 
"\n"; + debug(2) << " (serial for loop): min=" << op->min << " max=" << op->max << "\n"; internal_assert(op->for_type == ForType::Serial) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit unhandled For type: " << op->for_type << "\n"; - user_assert(op->min.type() == op->extent.type()); + user_assert(op->min.type() == op->max.type()); user_assert(op->min.type().is_int() || op->min.type().is_uint()); op->min.accept(this); SpvId min_id = builder.current_id(); - op->extent.accept(this); - SpvId extent_id = builder.current_id(); + op->max.accept(this); + SpvId max_id = builder.current_id(); // Compute max. Type index_type = op->min.type(); SpvId index_type_id = builder.declare_type(index_type); SpvStorageClass storage_class = SpvStorageClassFunction; SpvId index_var_type_id = builder.declare_pointer_type(index_type_id, storage_class); - SpvId max_id = builder.reserve_id(SpvResultId); - builder.append(SpvFactory::integer_add(index_type_id, max_id, min_id, extent_id)); // Declare loop var const std::string loop_var_name = unique_name(std::string("k") + std::to_string(kernel_index) + "_loop_idx"); @@ -1757,7 +1756,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { SpvId loop_test_type_id = builder.declare_type(Bool()); SpvId loop_test_id = builder.reserve_id(SpvResultId); builder.append(SpvFactory::load(index_type_id, loop_index_id, loop_var_id)); - builder.append(SpvFactory::integer_less_than(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_uint())); + builder.append(SpvFactory::integer_less_than_equal(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_uint())); builder.append(SpvFactory::conditional_branch(loop_test_id, body_block_id, merge_block_id)); } builder.leave_block(); diff --git a/src/CodeGen_WebGPU_Dev.cpp b/src/CodeGen_WebGPU_Dev.cpp index ea43b3cefbd5..d0a9310856ca 100644 --- a/src/CodeGen_WebGPU_Dev.cpp +++ b/src/CodeGen_WebGPU_Dev.cpp @@ -663,11 +663,11 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const 
For *loop) { << "Can only use serial loops inside WebGPU shaders\n"; string id_min = print_expr(loop->min); - string id_extent = print_expr(loop->extent); + string id_max = print_expr(loop->max); string id_counter = print_name(loop->name); stream << get_indent() << "for (var " << id_counter << " = " << id_min << "; " - << id_counter << " < " << id_min << " + " << id_extent << "; " + << id_counter << " <= " << id_max << "; " // TODO: Use increment statement when supported by Chromium. << id_counter << " = " << id_counter << " + 1)\n"; open_scope(); diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index b6f49cb1bf43..1e3ee74f9f65 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -539,12 +539,12 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt) const auto *for_stmt = (const Serialize::For *)stmt; const auto name = deserialize_string(for_stmt->name()); const auto min = deserialize_expr(for_stmt->min_type(), for_stmt->min()); - const auto extent = deserialize_expr(for_stmt->extent_type(), for_stmt->extent()); + const auto max = deserialize_expr(for_stmt->max_type(), for_stmt->max()); const ForType for_type = deserialize_for_type(for_stmt->for_type()); const Partition partition_policy = deserialize_partition(for_stmt->partition_policy()); const DeviceAPI device_api = deserialize_device_api(for_stmt->device_api()); const auto body = deserialize_stmt(for_stmt->body_type(), for_stmt->body()); - return For::make(name, min, extent, for_type, partition_policy, device_api, body); + return For::make(name, min, max, for_type, partition_policy, device_api, body); } case Serialize::Stmt::Store: { const auto *store_stmt = (const Serialize::Store *)stmt; diff --git a/src/EarlyFree.cpp b/src/EarlyFree.cpp index 35de3c15cbcd..8b664c2bcf8d 100644 --- a/src/EarlyFree.cpp +++ b/src/EarlyFree.cpp @@ -30,7 +30,7 @@ class FindLastUse : public IRVisitor { void visit(const For *loop) override { loop->min.accept(this); - 
loop->extent.accept(this); + loop->max.accept(this); ScopedValue old_in_loop(in_loop, true); loop->body.accept(this); } diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 1b7506d96c9b..88f9a542550f 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -39,7 +39,7 @@ class ExtractBlockSize : public IRVisitor { void found_thread_for(int dim, const string &name, const Expr &extent) { internal_assert(dim >= 0 && dim < 3); if (!block_extent[dim].defined()) { - block_extent[dim] = extent; + block_extent[dim] = simplify(extent); } else { block_extent[dim] = simplify(Max::make(extent, block_extent[dim])); } @@ -55,16 +55,16 @@ class ExtractBlockSize : public IRVisitor { void visit(const For *op) override { for (int i = 0; i < 3; i++) { if (ends_with(op->name, gpu_thread_name(i))) { - found_thread_for(i, op->name, op->extent); + found_thread_for(i, op->name, op->extent()); } else if (ends_with(op->name, gpu_block_name(i))) { - found_block_for(i, op->name, op->extent); + found_block_for(i, op->name, op->extent()); } } IRVisitor::visit(op); Scope scope; - scope.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + scope.push(op->name, Interval(op->min, op->max)); // For non-rectangular thread loops, use a bounding box. We'll inject if statements later. 
for (Expr &e : block_extent) { if (e.defined() && expr_uses_var(e, op->name)) { @@ -141,7 +141,7 @@ class NormalizeDimensionality : public IRMutator { return s; } while (max_depth < block_size.threads_dimensions()) { - s = For::make(gpu_thread_name(max_depth), 0, 1, ForType::GPUThread, + s = For::make(gpu_thread_name(max_depth), 0, 0, ForType::GPUThread, Partition::Never, device_api, s); max_depth++; } @@ -205,10 +205,10 @@ class ReplaceForWithIf : public IRMutator { Expr var = Variable::make(Int(32), gpu_thread_name(dim)); body = substitute(op->name, var + op->min, body); - if (equal(op->extent, block_size.num_threads(dim))) { + if (can_prove(op->extent() == block_size.num_threads(dim))) { return body; } else { - Expr cond = var < op->extent; + Expr cond = var <= op->max; return IfThenElse::make(cond, body, Stmt()); } } else { @@ -340,7 +340,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { // Expand any new shared allocations found in the body using the loop bounds. Scope scope; - scope.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + scope.push(op->name, Interval(op->min, op->max)); for (SharedAllocation &s : allocations) { // If the size depends on the loop variable, take the max // over all loop iterations @@ -366,7 +366,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { precompute_allocation_size(s); break; case Monotonic::Increasing: - s.size = substitute(op->name, simplify(op->min + op->extent - 1), s.size); + s.size = substitute(op->name, op->max, s.size); break; case Monotonic::Constant: // The size expression used the variable, but we @@ -381,7 +381,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { } if (in_threads && op->is_parallel()) { // For parallel inner loops, make a separate slice per loop iteration - s.size *= op->extent; + s.size *= op->extent(); } } @@ -393,13 +393,13 @@ class ExtractSharedAndHeapAllocations : public IRMutator { } Expr new_min = mutate(op->min); - Expr new_extent 
= mutate(op->extent); + Expr new_max = mutate(op->max); if (host_side_preamble.defined()) { string loop_name = unique_name('t'); Expr v = Variable::make(Int(32), loop_name); host_side_preamble = substitute(op->name, v, host_side_preamble); - host_side_preamble = For::make(loop_name, new_min, new_extent, + host_side_preamble = For::make(loop_name, new_min, new_max, ForType::Serial, Partition::Never, DeviceAPI::None, host_side_preamble); if (old_preamble.defined()) { host_side_preamble = Block::make(old_preamble, host_side_preamble); @@ -408,7 +408,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { host_side_preamble = old_preamble; } - return For::make(op->name, new_min, new_extent, + return For::make(op->name, new_min, new_max, op->for_type, op->partition_policy, op->device_api, body); } @@ -1082,7 +1082,7 @@ class ExtractRegisterAllocations : public IRMutator { // Expand any new register allocations found in the body using the loop bounds. Scope scope; - scope.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + scope.push(op->name, Interval(op->min, op->max)); // Expand the inner allocations using the loop bounds. 
for (RegisterAllocation &s : allocations) { @@ -1098,7 +1098,7 @@ class ExtractRegisterAllocations : public IRMutator { allocations.swap(old); } - return For::make(op->name, mutate(op->min), mutate(op->extent), op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, mutate(op->min), mutate(op->max), op->for_type, op->partition_policy, op->device_api, body); } } @@ -1258,7 +1258,7 @@ class InjectThreadBarriers : public IRMutator { // synchronizations within the block body = Block::make(body, make_barrier(0)); } - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } else { return IRMutator::visit(op); @@ -1410,13 +1410,13 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { string thread_id = gpu_thread_name(0); // Add back in any register-level allocations body = register_allocs.rewrap(body, thread_id); - body = For::make(thread_id, 0, block_size_x, innermost_loop_type, op->partition_policy, op->device_api, body); + body = For::make(thread_id, 0, block_size_x - 1, innermost_loop_type, op->partition_policy, op->device_api, body); // Rewrap the whole thing in other loops over threads for (int i = 1; i < block_size.threads_dimensions(); i++) { thread_id = gpu_thread_name(i); body = register_allocs.rewrap(body, thread_id); - body = For::make(thread_id, 0, block_size.num_threads(i), + body = For::make(thread_id, 0, block_size.num_threads(i) - 1, ForType::GPUThread, op->partition_policy, op->device_api, body); } thread_id.clear(); @@ -1433,7 +1433,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } else { return IRMutator::visit(op); @@ -1503,7 +1503,7 @@ 
class ZeroGPULoopMins : public IRMutator { internal_assert(op); Expr adjusted = Variable::make(Int(32), op->name) + op->min; Stmt body = substitute(op->name, adjusted, op->body); - stmt = For::make(op->name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, 0, simplify(op->max - op->min), op->for_type, op->partition_policy, op->device_api, body); } return stmt; } @@ -1547,7 +1547,7 @@ class AddConditionToALoop : public IRMutator { return IRMutator::visit(op); } - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, IfThenElse::make(condition, op->body, Stmt())); } diff --git a/src/HexagonOffload.cpp b/src/HexagonOffload.cpp index a7a305b5902e..e540a9697513 100644 --- a/src/HexagonOffload.cpp +++ b/src/HexagonOffload.cpp @@ -4,6 +4,7 @@ #include "Closure.h" #include "Elf.h" #include "HexagonOffload.h" +#include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" #include "InjectHostDevBufferCopies.h" @@ -745,10 +746,10 @@ class InjectHexagonRpc : public IRMutator { // After moving this to Hexagon, it doesn't need to be marked // Hexagon anymore. 
Stmt body; - if (is_const_one(loop->extent)) { + if (equal(loop->min, loop->max)) { body = LetStmt::make(loop->name, loop->min, loop->body); } else { - body = For::make(loop->name, loop->min, loop->extent, loop->for_type, loop->partition_policy, + body = For::make(loop->name, loop->min, loop->max, loop->for_type, loop->partition_policy, DeviceAPI::None, loop->body); } diff --git a/src/IR.cpp b/src/IR.cpp index e17818a8cc9e..c844c672656a 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -342,20 +342,20 @@ Stmt ProducerConsumer::make_consume(const std::string &name, Stmt body) { } Stmt For::make(const std::string &name, - Expr min, Expr extent, + Expr min, Expr max, ForType for_type, Partition partition_policy, DeviceAPI device_api, Stmt body) { internal_assert(min.defined()) << "For of undefined\n"; - internal_assert(extent.defined()) << "For of undefined\n"; + internal_assert(max.defined()) << "For of undefined\n"; internal_assert(min.type() == Int(32)) << "For with non-integer min\n"; - internal_assert(extent.type() == Int(32)) << "For with non-integer extent\n"; + internal_assert(max.type() == Int(32)) << "For with non-integer max\n"; internal_assert(body.defined()) << "For of undefined\n"; For *node = new For; node->name = name; node->min = std::move(min); - node->extent = std::move(extent); + node->max = std::move(max); node->for_type = for_type; node->partition_policy = partition_policy; node->device_api = device_api; diff --git a/src/IR.h b/src/IR.h index bdf42a75f7b1..53ae316a404b 100644 --- a/src/IR.h +++ b/src/IR.h @@ -833,28 +833,26 @@ struct Variable : public ExprNode { static const IRNodeType _node_type = IRNodeType::Variable; }; -/** A for loop. Execute the 'body' statement for all values of the - * variable 'name' from 'min' to 'min + extent'. There are four - * types of For nodes. A 'Serial' for loop is a conventional - * one. In a 'Parallel' for loop, each iteration of the loop - * happens in parallel or in some unspecified order. 
In a - * 'Vectorized' for loop, each iteration maps to one SIMD lane, - * and the whole loop is executed in one shot. For this case, - * 'extent' must be some small integer constant (probably 4, 8, or - * 16). An 'Unrolled' for loop compiles to a completely unrolled - * version of the loop. Each iteration becomes its own - * statement. Again in this case, 'extent' should be a small - * integer constant. */ +/** A for loop. Execute the 'body' statement for all values of the variable + * 'name' from 'min' to 'max' inclusive. There are four types of For nodes. A + * 'Serial' for loop is a conventional one. In a 'Parallel' for loop, each + * iteration of the loop happens in parallel or in some unspecified order. In a + * 'Vectorized' for loop, each iteration maps to one SIMD lane, and the whole + * loop is executed in one shot. For this case, the extent (max - min + 1) must + * be some small integer constant (probably 4, 8, or 16). An 'Unrolled' for loop + * compiles to a completely unrolled version of the loop. Each iteration becomes + * its own statement. Again in this case, the extent should be a small integer + * constant. 
*/ struct For : public StmtNode { std::string name; - Expr min, extent; + Expr min, max; ForType for_type; DeviceAPI device_api; Stmt body; Partition partition_policy; static Stmt make(const std::string &name, - Expr min, Expr extent, + Expr min, Expr max, ForType for_type, Partition partition_policy, DeviceAPI device_api, Stmt body); @@ -866,6 +864,10 @@ struct For : public StmtNode { return Halide::Internal::is_parallel(for_type); } + Expr extent() const { + return Add::make(Sub::make(max, min), 1); + } + static const IRNodeType _node_type = IRNodeType::For; }; diff --git a/src/IREquality.cpp b/src/IREquality.cpp index a8dd27d7c7df..c85041ff2460 100644 --- a/src/IREquality.cpp +++ b/src/IREquality.cpp @@ -426,7 +426,7 @@ struct Comparer { cmp(&For::device_api); cmp(&For::partition_policy); cmp(&For::min); - cmp(&For::extent); + cmp(&For::max); cmp(&For::body); break; case IRNodeType::Acquire: diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index f0b861f1d5c0..9eecd0579840 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -202,14 +202,14 @@ Stmt IRMutator::visit(const ProducerConsumer *op) { Stmt IRMutator::visit(const For *op) { Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt body = mutate(op->body); if (min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body)) { return op; } - return For::make(op->name, std::move(min), std::move(extent), + return For::make(op->name, std::move(min), std::move(max), op->for_type, op->partition_policy, op->device_api, std::move(body)); } diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 8dde0795181d..86549f3a5d1a 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -1131,7 +1131,7 @@ void IRPrinter::visit(const For *op) { stream << var(op->name) << paren(", "); print_no_parens(op->min); stream << paren(", "); - print_no_parens(op->extent); + print_no_parens(op->max); closef(); stream << " "; diff --git 
a/src/IRVisitor.cpp b/src/IRVisitor.cpp index 9a5e6a8e0537..25fd7e608f27 100644 --- a/src/IRVisitor.cpp +++ b/src/IRVisitor.cpp @@ -167,7 +167,7 @@ void IRVisitor::visit(const ProducerConsumer *op) { void IRVisitor::visit(const For *op) { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); op->body.accept(this); } @@ -443,7 +443,7 @@ void IRGraphVisitor::visit(const ProducerConsumer *op) { void IRGraphVisitor::visit(const For *op) { include(op->min); - include(op->extent); + include(op->max); include(op->body); } diff --git a/src/LICM.cpp b/src/LICM.cpp index c73fcf5424ab..31f079175fad 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -318,7 +318,7 @@ class LICM : public IRMutator { const For *loop = new_stmt.as(); internal_assert(loop); - new_stmt = For::make(loop->name, loop->min, loop->extent, + new_stmt = For::make(loop->name, loop->min, loop->max, loop->for_type, loop->partition_policy, loop->device_api, mutate(loop->body)); // Wrap lets for the lifted invariants @@ -563,7 +563,7 @@ class HoistIfStatements : public IRMutator { if (!i->else_case.defined() && is_pure(i->condition) && !expr_uses_var(i->condition, op->name)) { - Stmt s = For::make(op->name, op->min, op->extent, + Stmt s = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, i->then_case); return IfThenElse::make(i->condition, s); } @@ -571,7 +571,7 @@ class HoistIfStatements : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 5349e9c316f9..fccc049ce5b6 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -550,7 +550,7 @@ class LoopCarry : public IRMutator { } Stmt visit(const For *op) override { - if (op->for_type == ForType::Serial && !is_const_one(op->extent)) { + if (op->for_type == ForType::Serial && 
!equal(op->min, op->max)) { Stmt stmt; Stmt body = mutate(op->body); LoopCarryOverLoop carry(op->name, in_consume, max_carried_values); @@ -558,7 +558,7 @@ class LoopCarry : public IRMutator { if (body.same_as(op->body)) { stmt = op; } else { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } // Inject the scratch buffer allocations. @@ -567,7 +567,7 @@ class LoopCarry : public IRMutator { stmt = Allocate::make(alloc.name, alloc.type, MemoryType::Stack, {alloc.size}, const_true(), stmt); } if (!carry.allocs.empty()) { - stmt = IfThenElse::make(op->extent > 0, stmt); + stmt = IfThenElse::make(op->min <= op->max, stmt); } return stmt; } else { diff --git a/src/LowerParallelTasks.cpp b/src/LowerParallelTasks.cpp index d6ed27ca0905..62e909136841 100644 --- a/src/LowerParallelTasks.cpp +++ b/src/LowerParallelTasks.cpp @@ -270,9 +270,11 @@ struct LowerParallelTasks : public IRMutator { std::string loop_min_name = unique_name('t'); std::string loop_extent_name = unique_name('t'); if (!t.loop_var.empty()) { + Expr min = Variable::make(Int(32), loop_min_name); + Expr extent = Variable::make(Int(32), loop_extent_name); t.body = For::make(t.loop_var, - Variable::make(Int(32), loop_min_name), - Variable::make(Int(32), loop_extent_name), + min, + min + extent - 1, ForType::Serial, t.partition_policy, DeviceAPI::None, @@ -380,7 +382,7 @@ struct LowerParallelTasks : public IRMutator { result.emplace_back(std::move(t)); } else if (loop && loop->for_type == ForType::Parallel) { add_suffix(prefix, ".par_for." 
+ loop->name); - ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_false(), task_debug_name(prefix), loop->partition_policy}; + ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent(), const_false(), task_debug_name(prefix), loop->partition_policy}; result.emplace_back(std::move(t)); } else if (loop && loop->for_type == ForType::Serial && @@ -389,7 +391,7 @@ struct LowerParallelTasks : public IRMutator { const Variable *v = acquire->semaphore.as(); internal_assert(v); add_suffix(prefix, ".for." + v->name); - ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_true(), task_debug_name(prefix), loop->partition_policy}; + ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent(), const_true(), task_debug_name(prefix), loop->partition_policy}; while (acquire) { t.semaphores.push_back({acquire->semaphore, acquire->count}); t.body = acquire->body; diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index 2551fe0bffbb..e45f99f37e14 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -234,11 +234,11 @@ class DetermineAllocStride : public IRVisitor { void visit(const For *op) override { ScopedBinding - bind_bounds_if(is_const(op->min) && is_const(op->extent), - bounds, op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + bind_bounds_if(is_const(op->min) && is_const(op->max), + bounds, op->name, Interval(op->min, op->max)); ScopedBinding bound_dependent_if((expr_uses_vars(op->min, dependent_vars) || - expr_uses_vars(op->extent, dependent_vars)), + expr_uses_vars(op->max, dependent_vars)), dependent_vars, op->name, Expr()); IRVisitor::visit(op); } @@ -372,16 +372,17 @@ class LowerWarpShuffles : public IRMutator { Stmt visit(const For *op) override { ScopedBinding - bind_if(is_const(op->min) && is_const(op->extent), - bounds, op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + bind_if(is_const(op->min) && is_const(op->max), + bounds, 
op->name, Interval(op->min, op->max)); if (!this_lane.defined() && op->for_type == ForType::GPULane) { bool should_mask = false; ScopedValue old_warp_size(warp_size); + Expr extent = simplify(op->extent()); if (op->for_type == ForType::GPULane) { - auto loop_size = as_const_int(op->extent); + auto loop_size = as_const_int(extent); user_assert(loop_size && *loop_size <= 32) - << "CUDA gpu lanes loop must have constant extent of at most 32: " << op->extent << "\n"; + << "CUDA gpu lanes loop must have constant extent of at most 32: " << extent << "\n"; // Select a warp size - the smallest power of two that contains the loop size int64_t ws = 1; @@ -391,7 +392,7 @@ class LowerWarpShuffles : public IRMutator { should_mask = (ws != *loop_size); warp_size = make_const(Int(32), ws); } else { - warp_size = op->extent; + warp_size = extent; } this_lane_name = op->name; this_lane = Variable::make(Int(32), op->name); @@ -408,7 +409,8 @@ class LowerWarpShuffles : public IRMutator { // with storage striped across the warp lanes, so the // size required per-lane is the old size divided by // the number of lanes (rounded up). 
- Expr new_size = (alloc->extents[0] + op->extent - 1) / op->extent; + Expr extent = op->extent(); + Expr new_size = (alloc->extents[0] + extent - 1) / extent; new_size = simplify(new_size, true, bounds); new_size = find_constant_bound(new_size, Direction::Upper, bounds); auto sz = as_const_int(new_size); @@ -423,7 +425,7 @@ class LowerWarpShuffles : public IRMutator { if (should_mask) { // Mask off the excess lanes in the warp - body = IfThenElse::make(this_lane < op->extent, body, Stmt()); + body = IfThenElse::make(this_lane <= op->max, body, Stmt()); } // Wrap the hoisted warp-level allocations, at their new @@ -455,7 +457,7 @@ class LowerWarpShuffles : public IRMutator { } allocations.clear(); - return For::make(op->name, op->min, warp_size, + return For::make(op->name, op->min, op->min + warp_size - 1, op->for_type, op->partition_policy, op->device_api, body); } else { return IRMutator::visit(op); @@ -732,7 +734,7 @@ class HoistWarpShufflesFromSingleIfStmt : public IRMutator { } else { debug(3) << "Successfully hoisted shuffle out of for loop\n"; } - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } Stmt visit(const Store *op) override { diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index 4a33c8f1bc00..aa3e8422d3fd 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -55,10 +55,10 @@ class ExtractBounds : public IRVisitor { for (int i = 0; i < 3; i++) { if (ends_with(op->name, gpu_thread_name(i))) { - num_threads[i] = op->extent; + num_threads[i] = op->extent(); } if (ends_with(op->name, gpu_block_name(i))) { - num_blocks[i] = op->extent; + num_blocks[i] = op->extent(); } } diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index cc9bda77d586..1c77f014eb7b 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -406,7 +406,7 @@ class 
FindSimplifications : public IRVisitor { for (Simplification &s : simplifications) { if (expr_uses_var(s.condition, op->name)) { Scope varying; - varying.push(op->name, Interval(op->min, op->min + op->extent - 1)); + varying.push(op->name, Interval(op->min, op->max)); Expr relaxed = and_condition_over_domain(s.condition, varying); internal_assert(!expr_uses_var(relaxed, op->name)) << "Should not have had used the loop var (" << op->name @@ -721,7 +721,7 @@ class PartitionLoops : public IRMutator { } // Construct variables for the bounds of the simplified middle section - Expr min_steady = op->min, max_steady = op->extent + op->min; + Expr min_steady = op->min, max_steady = op->max + 1; Expr prologue_val, epilogue_val; string prologue_name = unique_name(op->name + ".prologue"); string epilogue_name = unique_name(op->name + ".epilogue"); @@ -735,7 +735,7 @@ class PartitionLoops : public IRMutator { min_vals.push_back(op->min); prologue_val = fold_left(min_vals, Max::make); // Stop the prologue from running past the end of the loop - prologue_val = min(prologue_val, op->extent + op->min); + prologue_val = min(prologue_val, op->max + 1); // prologue_val = print(prologue_val, prologue_name); min_steady = Variable::make(Int(32), prologue_name); @@ -743,7 +743,7 @@ class PartitionLoops : public IRMutator { } if (make_epilogue) { std::sort(max_vals.begin(), max_vals.end(), IRDeepCompare()); - max_vals.push_back(op->min + op->extent - 1); + max_vals.push_back(op->max); epilogue_val = fold_left(max_vals, Min::make) + 1; // Stop the epilogue from running before the start of the loop/prologue if (make_prologue) { @@ -760,17 +760,17 @@ class PartitionLoops : public IRMutator { Stmt stmt; // Bust simple serial for loops up into three. 
if (op->for_type == ForType::Serial && !op->body.as()) { - stmt = For::make(op->name, min_steady, max_steady - min_steady, + stmt = For::make(op->name, min_steady, max_steady - 1, op->for_type, op->partition_policy, op->device_api, simpler_body); if (make_prologue) { - prologue = For::make(op->name, op->min, min_steady - op->min, + prologue = For::make(op->name, op->min, min_steady - 1, op->for_type, op->partition_policy, op->device_api, prologue); stmt = Block::make(prologue, stmt); mutated = true; } if (make_epilogue) { - epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady, + epilogue = For::make(op->name, max_steady, op->max, op->for_type, op->partition_policy, op->device_api, epilogue); stmt = Block::make(stmt, epilogue); mutated = true; @@ -803,19 +803,19 @@ class PartitionLoops : public IRMutator { mutated = true; } } - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, stmt); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, stmt); } if (make_epilogue) { // Uncomment to include code that prints the epilogue value - // epilogue_val = print(epilogue_val, op->name, "epilogue"); + // epilogue_val = print(epilogue_val, op->name, "epilogue", op->min, op->max); stmt = LetStmt::make(epilogue_name, epilogue_val, stmt); } else { - epilogue_val = op->min + op->extent; + epilogue_val = op->max + 1; } if (make_prologue) { // Uncomment to include code that prints the prologue value - // prologue_val = print(prologue_val, op->name, "prologue"); + // prologue_val = print(prologue_val, op->name, "prologue", op->min, op->max); stmt = LetStmt::make(prologue_name, prologue_val, stmt); } else { prologue_val = op->min; @@ -924,9 +924,9 @@ class RenormalizeGPULoops : public IRMutator { // Move lets in-between gpu loop levels inwards. 
if (f && in_gpu_loop && !in_thread_loop) { internal_assert(!expr_uses_var(f->min, op->name) && - !expr_uses_var(f->extent, op->name)); + !expr_uses_var(f->max, op->name)); Stmt inner = LetStmt::make(op->name, op->value, f->body); - inner = For::make(f->name, f->min, f->extent, f->for_type, f->partition_policy, f->device_api, inner); + inner = For::make(f->name, f->min, f->max, f->for_type, f->partition_policy, f->device_api, inner); return mutate(inner); } else if (a && in_gpu_loop && !in_thread_loop) { internal_assert(a->extents.size() == 1); @@ -1002,9 +1002,9 @@ class RenormalizeGPULoops : public IRMutator { } else if (for_a && for_b && for_a->name == for_b->name && for_a->min.same_as(for_b->min) && - for_a->extent.same_as(for_b->extent)) { + for_a->max.same_as(for_b->max)) { Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body); - inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->partition_policy, for_a->device_api, inner); + inner = For::make(for_a->name, for_a->min, for_a->max, for_a->for_type, for_a->partition_policy, for_a->device_api, inner); return mutate(inner); } else { internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n"; diff --git a/src/Prefetch.cpp b/src/Prefetch.cpp index a34e95f3c530..c0eedf50b817 100644 --- a/src/Prefetch.cpp +++ b/src/Prefetch.cpp @@ -246,7 +246,7 @@ class InjectPlaceholderPrefetch : public IRMutator { Stmt stmt; if (!body.same_as(op->body)) { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, std::move(body)); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, std::move(body)); } else { stmt = op; } @@ -300,7 +300,7 @@ class ReducePrefetchDimension : public IRMutator { stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic)); for (size_t i = 0; i < index_names.size(); ++i) { - stmt = For::make(index_names[i], 0, 
prefetch->args[(i + max_dim) * 2 + 2], + stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 2] - 1, ForType::Serial, Partition::Auto, DeviceAPI::None, stmt); } debug(5) << "\nReduce prefetch to " << max_dim << " dim:\n" @@ -371,7 +371,7 @@ class SplitPrefetch : public IRMutator { vector args = {base, std::move(new_offset), std::move(new_extent), std::move(new_stride)}; stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic)); for (size_t i = 0; i < index_names.size(); ++i) { - stmt = For::make(index_names[i], 0, extents[i], + stmt = For::make(index_names[i], 0, extents[i] - 1, ForType::Serial, Partition::Auto, DeviceAPI::None, stmt); } debug(5) << "\nSplit prefetch to max of " << max_byte_size << " bytes:\n" diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index 9d38efaaf80a..1c76965ac2be 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -91,24 +91,23 @@ class PrintLoopNest : public IRVisitor { // If the min or extent are constants, print them. At this // stage they're all variables. 
- Expr min_val = op->min, extent_val = op->extent; + Expr min_val = op->min, max_val = op->max; const Variable *min_var = min_val.as(); - const Variable *extent_var = extent_val.as(); + const Variable *max_var = max_val.as(); if (min_var) { if (const Expr *e = constants.find(min_var->name)) { min_val = *e; } } - if (extent_var) { - if (const Expr *e = constants.find(extent_var->name)) { - extent_val = *e; + if (max_var) { + if (const Expr *e = constants.find(max_var->name)) { + max_val = *e; } } - if (extent_val.defined() && is_const(extent_val) && + if (max_val.defined() && is_const(max_val) && min_val.defined() && is_const(min_val)) { - Expr max_val = simplify(min_val + extent_val - 1); out << " in [" << min_val << ", " << max_val << "]"; } diff --git a/src/Profiling.cpp b/src/Profiling.cpp index 3c7b5d6e0090..c8054e83544e 100644 --- a/src/Profiling.cpp +++ b/src/Profiling.cpp @@ -502,7 +502,7 @@ class InjectProfiling : public IRMutator { most_recently_set_func = -1; } - Stmt stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + Stmt stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); if (update_active_threads) { if (Internal::is_gpu(op->for_type)) { diff --git a/src/RebaseLoopsToZero.cpp b/src/RebaseLoopsToZero.cpp index d20c1e42ce3a..a96030291a44 100644 --- a/src/RebaseLoopsToZero.cpp +++ b/src/RebaseLoopsToZero.cpp @@ -39,7 +39,7 @@ class RebaseLoopsToZero : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(name, 0, op->max - op->min, op->for_type, op->partition_policy, op->device_api, body); } } }; diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index c6036b99ec5a..9667aafe891a 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -355,8 +355,8 @@ class RemoveUndef : public IRMutator { if (!min.defined()) 
{ return Stmt(); } - Expr extent = mutate(op->extent); - if (!extent.defined()) { + Expr max = mutate(op->max); + if (!max.defined()) { return Stmt(); } Stmt body = mutate(op->body); @@ -364,11 +364,11 @@ class RemoveUndef : public IRMutator { return Stmt(); } if (min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body)) { return op; } else { - return For::make(op->name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, min, max, op->for_type, op->partition_policy, op->device_api, body); } } diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 19c3de055001..3df28f91bbb2 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -418,8 +418,8 @@ Stmt build_loop_nest( internal_assert(container.type == Container::For); const Dim &dim = stage_s.dims()[container.dim_idx]; Expr min = Variable::make(Int(32), container.name + ".loop_min"); - Expr extent = Variable::make(Int(32), container.name + ".loop_extent"); - stmt = For::make(container.name, min, extent, dim.for_type, dim.partition_policy, dim.device_api, stmt); + Expr max = Variable::make(Int(32), container.name + ".loop_max"); + stmt = For::make(container.name, min, max, dim.for_type, dim.partition_policy, dim.device_api, stmt); } } @@ -983,7 +983,7 @@ class InjectStmt : public IRMutator { } else { return For::make(for_loop->name, for_loop->min, - for_loop->extent, + for_loop->max, for_loop->for_type, for_loop->partition_policy, for_loop->device_api, @@ -1059,9 +1059,9 @@ Stmt substitute_fused_bounds(Stmt s, const map &replacements) { Stmt visit(const For *op) override { const auto *min_var = op->min.as(); - const auto *extent_var = op->extent.as(); - if (min_var && extent_var) { - Expr min_val, extent_val; + const auto *max_var = op->max.as(); + if (min_var && max_var) { + Expr min_val, max_val; { const auto &it = replacements.find(min_var->name); if (it != replacements.end()) { @@ 
-1069,12 +1069,12 @@ Stmt substitute_fused_bounds(Stmt s, const map &replacements) { } } { - const auto &it = replacements.find(extent_var->name); + const auto &it = replacements.find(max_var->name); if (it != replacements.end()) { - extent_val = it->second; + max_val = it->second; } } - if (!min_val.defined() || !extent_val.defined()) { + if (!min_val.defined() || !max_val.defined()) { return IRMutator::visit(op); } @@ -1084,7 +1084,7 @@ Stmt substitute_fused_bounds(Stmt s, const map &replacements) { ForType for_type = op->for_type; DeviceAPI device_api = op->device_api; - if (is_const_one(extent_val)) { + if (equal(min_val, max_val)) { // This is the child loop of a fused group. The real loop of the // fused group is the loop of the parent function of the fused // group. This child loop is just a scheduling point, and should @@ -1095,13 +1095,13 @@ Stmt substitute_fused_bounds(Stmt s, const map &replacements) { } Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), - Variable::make(Int(32), new_var + ".loop_extent"), + Variable::make(Int(32), new_var + ".loop_max"), for_type, op->partition_policy, device_api, body); // Add let stmts defining the bound of the renamed for-loop. stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); - stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); - stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); + stmt = LetStmt::make(new_var + ".loop_max", max_val, stmt); + stmt = LetStmt::make(new_var + ".loop_extent", simplify((max_val - min_val) + 1), stmt); // Replace any reference to the old loop name with the new one. 
stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); return stmt; @@ -1143,7 +1143,7 @@ Stmt add_loop_var_aliases(Stmt s, const map> &loop_var_alias body = LetStmt::make(alias, var, body); } - return For::make(op->name, op->min, op->extent, op->for_type, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, std::move(body)); } @@ -1171,7 +1171,7 @@ class ShiftLoopNest : public IRMutator { internal_assert(op); Expr adjusted = Variable::make(Int(32), op->name) + iter->second; Stmt body = substitute(op->name, adjusted, op->body); - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } return stmt; } @@ -1347,7 +1347,7 @@ class InjectFunctionRealization : public IRMutator { } else { return For::make(for_loop->name, for_loop->min, - for_loop->extent, + for_loop->max, for_loop->for_type, for_loop->partition_policy, for_loop->device_api, @@ -1954,7 +1954,7 @@ class ComputeLegalSchedules : public IRVisitor { sites.push_back({f->is_parallel(), is_gpu_block, loop_level}); f->min.accept(this); - f->extent.accept(this); + f->max.accept(this); f->body.accept(this); sites.pop_back(); @@ -2556,7 +2556,7 @@ class RemoveLoopsOverOutermost : public IRMutator { Stmt visit(const For *op) override { if (ends_with(op->name, ".__outermost") && - is_const_one(simplify(op->extent)) && + can_prove(op->min == op->max) && op->device_api == DeviceAPI::None) { return mutate(substitute(op->name, op->min, op->body)); } else { @@ -2590,7 +2590,7 @@ Stmt schedule_functions(const vector &outputs, const Target &target, bool &any_memoized) { string root_var = LoopLevel::root().lock().to_string(); - Stmt s = For::make(root_var, 0, 1, ForType::Serial, Partition::Never, DeviceAPI::Host, Evaluate::make(0)); + Stmt s = For::make(root_var, 0, 0, ForType::Serial, Partition::Never, 
DeviceAPI::Host, Evaluate::make(0)); any_memoized = false; diff --git a/src/SelectGPUAPI.cpp b/src/SelectGPUAPI.cpp index 504bda8f3bb8..ec73c883e955 100644 --- a/src/SelectGPUAPI.cpp +++ b/src/SelectGPUAPI.cpp @@ -35,7 +35,7 @@ class SelectGPUAPI : public IRMutator { internal_assert(op); if (op->device_api != selected_api) { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, selected_api, op->body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, selected_api, op->body); } return stmt; } diff --git a/src/Serialization.cpp b/src/Serialization.cpp index d731d9c9d85c..33de404edffe 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -444,7 +444,7 @@ std::pair> Serializer::serialize_stmt(FlatBufferBu const auto *const for_stmt = stmt.as(); const auto name_serialized = serialize_string(builder, for_stmt->name); const auto min_serialized = serialize_expr(builder, for_stmt->min); - const auto extent_serialized = serialize_expr(builder, for_stmt->extent); + const auto max_serialized = serialize_expr(builder, for_stmt->max); const Serialize::ForType for_type = serialize_for_type(for_stmt->for_type); const Serialize::Partition partition_policy = serialize_partition(for_stmt->partition_policy); const Serialize::DeviceAPI device_api = serialize_device_api(for_stmt->device_api); @@ -452,7 +452,7 @@ std::pair> Serializer::serialize_stmt(FlatBufferBu return std::make_pair(Serialize::Stmt::For, Serialize::CreateFor(builder, name_serialized, min_serialized.first, min_serialized.second, - extent_serialized.first, extent_serialized.second, + max_serialized.first, max_serialized.second, for_type, partition_policy, device_api, body_serialized.first, body_serialized.second) .Union()); diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index cd2c440de6ba..eaed86fa9a2e 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -145,7 +145,7 @@ Stmt Simplify::visit(const IfThenElse *op) { 
then_case); } else if (then_for && !else_case.defined() && - equal(unwrapped_condition, 0 < then_for->extent)) { + equal(unwrapped_condition, then_for->min <= then_for->max)) { // This guard is redundant return then_case; } else if (then_if && @@ -203,21 +203,21 @@ Stmt Simplify::visit(const AssertStmt *op) { } Stmt Simplify::visit(const For *op) { - ExprInfo min_info, extent_info; + ExprInfo min_info, max_info; Expr new_min = mutate(op->min, &min_info); if (in_unreachable) { return Evaluate::make(new_min); } - Expr new_extent = mutate(op->extent, &extent_info); + Expr new_max = mutate(op->max, &max_info); if (in_unreachable) { - return Evaluate::make(new_extent); + return Evaluate::make(new_max); } ScopedValue old_in_vector_loop(in_vector_loop, (in_vector_loop || op->for_type == ForType::Vectorized)); - Expr extent_positive = mutate(0 < new_extent, nullptr); + Expr extent_positive = mutate(new_min <= new_max, nullptr); if (is_const_zero(extent_positive)) { // This loop never runs return Evaluate::make(0); @@ -229,8 +229,8 @@ Stmt Simplify::visit(const For *op) { // at least one, so we can throw a max around the extent bounds. loop_var_info.bounds = - ConstantInterval::make_union(min_info.bounds, - min_info.bounds + max(extent_info.bounds, 1) - 1); + ConstantInterval::make_union(min(min_info.bounds, max_info.bounds), + max(min_info.bounds, max_info.bounds)); Stmt new_body; { ScopedBinding bind_if((loop_var_info.bounds.max_defined || @@ -244,10 +244,8 @@ Stmt Simplify::visit(const For *op) { // The loop variable will never exceed the loop bound. 
Expr loop_var = Variable::make(Int(32), op->name); - Expr new_max = mutate(new_min + new_extent, nullptr); - ScopedFact fact_loop_var_less_than_extent = scoped_truth(loop_var < new_max); - - ScopedFact fact_loop_var_ge_than_min = scoped_truth(new_min <= loop_var); + ScopedFact fact_loop_var_le_max = scoped_truth(loop_var <= new_max); + ScopedFact fact_loop_var_ge_min = scoped_truth(new_min <= loop_var); new_body = mutate(op->body); } @@ -258,38 +256,45 @@ Stmt Simplify::visit(const For *op) { // extent is greater than zero, then the code *outside* the loop must be // unreachable too, because if it weren't, it'd run the unreachable body // at least once. - in_unreachable = extent_info.bounds > 0; + in_unreachable = max_info.bounds >= min_info.bounds; return Evaluate::make(0); } if (const Acquire *acquire = new_body.as()) { if (is_no_op(acquire->body)) { // Rewrite iterated no-op acquires as a single acquire. - return Acquire::make(acquire->semaphore, mutate(acquire->count * new_extent, nullptr), acquire->body); + return Acquire::make(acquire->semaphore, mutate(acquire->count * ((new_max - new_min) + 1), nullptr), acquire->body); } } if (is_no_op(new_body)) { return new_body; - } else if (extent_info.bounds <= 0) { + } else if (max_info.bounds < min_info.bounds) { return Evaluate::make(0); - } else if (extent_info.bounds <= 1 && + } else if (equal(new_min, new_max) && + op->device_api == DeviceAPI::None) { + // Loop body runs exactly once + return mutate(LetStmt::make(op->name, new_min, new_body)); + } else if (max_info.bounds <= min_info.bounds && op->device_api == DeviceAPI::None) { // Loop body runs at most once Stmt s = LetStmt::make(op->name, new_min, new_body); - if (extent_info.bounds.contains(0)) { + if (!(max_info.bounds >= min_info.bounds)) { // Loop body might not run at all - s = IfThenElse::make(0 < new_extent, s); + s = IfThenElse::make(new_min <= new_max, s); } return mutate(s); - } else if (!stmt_uses_var(new_body, op->name) && 
!is_const_zero(op->min)) { - return For::make(op->name, make_zero(Int(32)), new_extent, op->for_type, op->partition_policy, op->device_api, new_body); + } else if (Expr shifted_max; + !stmt_uses_var(new_body, op->name) && + !is_const_zero(new_min) && + is_const(shifted_max = mutate((new_max - new_min), nullptr))) { + return For::make(op->name, make_zero(Int(32)), shifted_max, op->for_type, op->partition_policy, op->device_api, new_body); } else if (op->min.same_as(new_min) && - op->extent.same_as(new_extent) && + op->max.same_as(new_max) && op->body.same_as(new_body)) { return op; } else { - return For::make(op->name, new_min, new_extent, op->for_type, op->partition_policy, op->device_api, new_body); + return For::make(op->name, new_min, new_max, op->for_type, op->partition_policy, op->device_api, new_body); } } diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index 63640492de36..4a82456689c0 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -721,7 +721,7 @@ class SkipStages : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, std::move(body)); } } diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 69fa3198ceaf..4e658d7c2b9d 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -198,7 +198,7 @@ class RollFunc : public IRMutator { Stmt body = substitute(op->name, Variable::make(Int(32), new_name) + op->min, op->body); // use op->name *before* the re-assignment of result, which will clobber it loops_to_rebase.erase(op->name); - result = For::make(new_name, 0, op->extent, op->for_type, op->partition_policy, op->device_api, body); + result = For::make(new_name, 0, op->max - op->min, op->for_type, op->partition_policy, op->device_api, body); } return result; } @@ -561,21 +561,21 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { // It's not safe to enter an 
inner loop whose bounds depend on // the var we're sliding over. Expr min = expand_expr(op->min, scope); - Expr extent = expand_expr(op->extent, scope); + Expr max = expand_expr(op->max, scope); ScopedBinding<> bind(enclosing_loops, op->name); - if (is_const_one(extent)) { + if (equal(min, max)) { // Just treat it like a let Stmt s = LetStmt::make(op->name, min, op->body); s = mutate(s); // Unpack it back into the for const LetStmt *l = s.as(); internal_assert(l); - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, l->body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, l->body); } else if (is_monotonic(min, loop_var) != Monotonic::Constant || - is_monotonic(extent, loop_var) != Monotonic::Constant) { + is_monotonic(max, loop_var) != Monotonic::Constant) { debug(3) << "Not entering loop over " << op->name << " because the bounds depend on the var we're sliding over: " - << min << ", " << extent << "\n"; + << min << ", " << max << "\n"; return op; } else { return IRMutator::visit(op); @@ -793,8 +793,8 @@ class SlidingWindow : public IRMutator { string name = op->name; Stmt body = op->body; Expr loop_min = op->min; - Expr loop_extent = op->extent; - Expr loop_max = Variable::make(Int(32), op->name + ".loop_max"); + Expr loop_max = op->max; + Expr loop_extent = Variable::make(Int(32), op->name + ".loop_extent"); list> prev_loop_mins; list> new_lets; @@ -852,7 +852,7 @@ class SlidingWindow : public IRMutator { name = new_name; - // The new loop interval is the new loop min to the loop max. + // The new loop interval is the new loop min to the old loop max. 
new_lets.emplace_front(name + ".loop_min", new_loop_min); new_lets.emplace_front(name + ".loop_min.orig", loop_min); new_lets.emplace_front(name + ".loop_extent", (loop_max - loop_min) + 1); @@ -870,10 +870,10 @@ class SlidingWindow : public IRMutator { body = mutate(body); - if (body.same_as(op->body) && loop_min.same_as(op->min) && loop_extent.same_as(op->extent) && name == op->name) { + if (body.same_as(op->body) && loop_min.same_as(op->min) && loop_max.same_as(op->max) && name == op->name) { return op; } else { - Stmt result = For::make(name, loop_min, loop_extent, op->for_type, op->partition_policy, op->device_api, body); + Stmt result = For::make(name, loop_min, loop_max, op->for_type, op->partition_policy, op->device_api, body); if (!new_lets.empty()) { result = LetStmt::make(name + ".loop_max", loop_max, result); } @@ -913,13 +913,13 @@ class AddLoopMinOrig : public IRMutator { Stmt visit(const For *op) override { Stmt body = mutate(op->body); Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt result; - if (body.same_as(op->body) && min.same_as(op->min) && extent.same_as(op->extent)) { + if (body.same_as(op->body) && min.same_as(op->min) && max.same_as(op->max)) { result = op; } else { - result = For::make(op->name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + result = For::make(op->name, min, max, op->for_type, op->partition_policy, op->device_api, body); } return LetStmt::make(op->name + ".loop_min.orig", Variable::make(Int(32), op->name + ".loop_min"), result); } diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index 5880ea3b4008..85691921bc8d 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -114,7 +114,7 @@ class FindStridedLoads : public IRVisitor { } void visit(const For *op) override { - if (can_prove(op->extent > 0)) { + if (can_prove(op->min <= op->max)) { // The loop body definitely runs IRVisitor::visit(op); } else { diff --git 
a/src/StmtToHTML.cpp b/src/StmtToHTML.cpp index fbb327fc2c68..766ea61973b2 100644 --- a/src/StmtToHTML.cpp +++ b/src/StmtToHTML.cpp @@ -424,8 +424,8 @@ class IRCostModel : public IRVisitor { // The cost of a loop-node essentially depends on its iteration // count. The cost model currently ignores such costs. IRVisitor::visit(op); - set_compute_costs(op, 0, {op->min.get(), op->extent.get(), op->body.get()}, {op->min.get(), op->extent.get()}); - set_data_costs(op, 0, {op->min.get(), op->extent.get(), op->body.get()}, {op->min.get(), op->extent.get()}); + set_compute_costs(op, 0, {op->min.get(), op->max.get(), op->body.get()}, {op->min.get(), op->max.get()}); + set_data_costs(op, 0, {op->min.get(), op->max.get(), op->body.get()}, {op->min.get(), op->max.get()}); } void visit(const Acquire *op) override { @@ -1694,7 +1694,7 @@ class HTMLCodePrinter : public IRVisitor { print_html_element("span", "matched", ", "); print(op->min); print_html_element("span", "matched", ", "); - print(op->extent); + print(op->max); print_html_element("span", "matched", ")"); // Open code block to hold function body diff --git a/src/StorageFlattening.cpp b/src/StorageFlattening.cpp index 11f5e8e20d24..a7c8f5208c22 100644 --- a/src/StorageFlattening.cpp +++ b/src/StorageFlattening.cpp @@ -569,12 +569,12 @@ class HoistStorage : public IRMutator { Stmt visit(const For *op) override { Expr expanded_min = op->min; - Expr expanded_extent = op->extent; + Expr expanded_max = op->max; // Iterate from innermost outwards for (auto &storage : reverse_view(hoisted_storages)) { expanded_min = simplify(expand_expr(expanded_min, storage.scope)); - expanded_extent = expand_expr(expanded_extent, storage.scope); - auto loop_bounds = Interval(expanded_min, simplify(expanded_min + expanded_extent - 1)); + expanded_max = expand_expr(expanded_max, storage.scope); + auto loop_bounds = Interval(expanded_min, expanded_max); storage.loop_vars.emplace_back(op->name, loop_bounds); } diff --git a/src/StorageFolding.cpp 
b/src/StorageFolding.cpp index 04e743e33fbd..19e064ac3781 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -542,10 +542,10 @@ class AttemptStorageFoldingOfFunction : public IRMutator { string dynamic_footprint; Scope bounds; - bounds.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1))); + bounds.push(op->name, Interval(op->min, op->max)); Scope steady_bounds; - steady_bounds.push(op->name, Interval(simplify(op->min + 1), simplify(op->min + op->extent - 1))); + steady_bounds.push(op->name, Interval(simplify(op->min + 1), op->max)); HasExternConsumer has_extern_consumer(func.name()); body.accept(&has_extern_consumer); @@ -881,7 +881,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { // for further folding opportunities // recursively. } else if (!body.same_as(op->body)) { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); break; } else { stmt = op; @@ -900,7 +900,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { if (body.same_as(op->body)) { stmt = op; } else { - stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + stmt = For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } if (func.schedule().async() && !dynamic_footprint.empty()) { diff --git a/src/Substitute.cpp b/src/Substitute.cpp index 6a7cba7fd589..603a7cc8aa5a 100644 --- a/src/Substitute.cpp +++ b/src/Substitute.cpp @@ -83,17 +83,17 @@ class Substitute : public IRMutator { Stmt visit(const For *op) override { Expr new_min = mutate(op->min); - Expr new_extent = mutate(op->extent); + Expr new_max = mutate(op->max); hidden.push(op->name); Stmt new_body = mutate(op->body); hidden.pop(op->name); if (new_min.same_as(op->min) && - new_extent.same_as(op->extent) && + new_max.same_as(op->max) && 
new_body.same_as(op->body)) { return op; } else { - return For::make(op->name, new_min, new_extent, op->for_type, op->partition_policy, op->device_api, new_body); + return For::make(op->name, new_min, new_max, op->for_type, op->partition_policy, op->device_api, new_body); } } }; diff --git a/src/TrimNoOps.cpp b/src/TrimNoOps.cpp index b4ec415072c9..1842a702fab4 100644 --- a/src/TrimNoOps.cpp +++ b/src/TrimNoOps.cpp @@ -125,12 +125,12 @@ class IsNoOp : public IRVisitor { condition = const_true(); op->body.accept(this); Scope varying; - varying.push(op->name, Interval(op->min, op->min + op->extent - 1)); + varying.push(op->name, Interval(op->min, op->max)); condition = simplify(common_subexpression_elimination(condition)); debug(3) << "About to relax over " << op->name << " : " << condition << "\n"; condition = and_condition_over_domain(condition, varying); debug(3) << "Relaxed: " << condition << "\n"; - condition = make_and(old_condition, make_or(condition, simplify(op->extent <= 0))); + condition = make_and(old_condition, make_or(condition, simplify(op->max < op->min))); } void visit(const IfThenElse *op) override { @@ -334,11 +334,11 @@ class SimplifyUsingBounds : public IRMutator { Stmt visit(const For *op) override { // Simplify the loop bounds. 
Expr min = mutate(op->min); - Expr extent = mutate(op->extent); - containing_loops.push_back({op->name, {min, min + extent - 1}}); + Expr max = mutate(op->max); + containing_loops.push_back({op->name, {min, max}}); Stmt body = mutate(op->body); containing_loops.pop_back(); - return For::make(op->name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, min, max, op->for_type, op->partition_policy, op->device_api, body); } public: @@ -380,7 +380,7 @@ class TrimNoOps : public IRMutator { if (body.same_as(op->body)) { return op; } else { - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } } @@ -393,7 +393,7 @@ class TrimNoOps : public IRMutator { if (i.is_everything()) { // Nope. - return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body); } if (i.is_empty()) { @@ -414,29 +414,22 @@ class TrimNoOps : public IRMutator { Expr new_max_var = Variable::make(Int(32), new_max_name); Expr old_max_var = Variable::make(Int(32), old_max_name); - // Convert max to max-plus-one - if (i.has_upper_bound()) { - i.max = i.max + 1; - } - // Truncate the loop bounds to the region over which it's not // a no-op. 
- Expr old_max = op->min + op->extent; + Expr old_max = op->max; Expr new_min, new_max; if (i.has_lower_bound()) { - new_min = clamp(i.min, op->min, old_max_var); + new_min = clamp(i.min, op->min, old_max_var + 1); } else { new_min = op->min; } if (i.has_upper_bound()) { - new_max = clamp(i.max, new_min_var, old_max_var); + new_max = clamp(i.max, new_min_var - 1, old_max_var); } else { new_max = old_max; } - Expr new_extent = new_max_var - new_min_var; - - Stmt stmt = For::make(op->name, new_min_var, new_extent, op->for_type, op->partition_policy, op->device_api, body); + Stmt stmt = For::make(op->name, new_min_var, new_max_var, op->for_type, op->partition_policy, op->device_api, body); stmt = LetStmt::make(new_max_name, new_max, stmt); stmt = LetStmt::make(new_min_name, new_min, stmt); stmt = LetStmt::make(old_max_name, old_max, stmt); diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 9dc92c780b3a..91f0279de04c 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -88,7 +88,7 @@ class UniquifyVariableNames : public IRMutator { Stmt visit(const For *op) override { Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); string new_name = make_new_name(op->name); Stmt body = mutate(op->body); renaming.pop(op->name); @@ -96,10 +96,10 @@ class UniquifyVariableNames : public IRMutator { if (new_name == op->name && body.same_as(op->body) && min.same_as(op->min) && - extent.same_as(op->extent)) { + max.same_as(op->max)) { return op; } else { - return For::make(new_name, min, extent, op->for_type, op->partition_policy, op->device_api, body); + return For::make(new_name, min, max, op->for_type, op->partition_policy, op->device_api, body); } } @@ -153,7 +153,7 @@ class FindFreeVars : public IRVisitor { void visit(const For *op) override { op->min.accept(this); - op->extent.accept(this); + op->max.accept(this); { ScopedBinding<> bind(scope, op->name); op->body.accept(this); diff 
--git a/src/UnrollLoops.cpp b/src/UnrollLoops.cpp index 2823c8b9ac9f..ffcba564966a 100644 --- a/src/UnrollLoops.cpp +++ b/src/UnrollLoops.cpp @@ -16,7 +16,8 @@ class UnrollLoops : public IRMutator { Stmt visit(const For *for_loop) override { if (for_loop->for_type == ForType::Unrolled) { Stmt body = for_loop->body; - const IntImm *e = for_loop->extent.as(); + Expr extent = simplify(for_loop->extent()); + const IntImm *e = extent.as(); internal_assert(e) << "Loop over " << for_loop->name << " should have had a constant extent\n"; diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index d06cc0815300..e8935aca6257 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -353,7 +353,7 @@ class SerializeLoops : public IRMutator { Stmt visit(const For *op) override { if (op->for_type == ForType::Vectorized) { - return For::make(op->name, op->min, op->extent, + return For::make(op->name, op->min, op->max, ForType::Serial, op->partition_policy, op->device_api, mutate(op->body)); } @@ -950,7 +950,7 @@ class VectorSubs : public IRMutator { ForType for_type = op->for_type; Expr min = mutate(op->min); - Expr extent = mutate(op->extent); + Expr max = mutate(op->max); Stmt body = op->body; @@ -958,20 +958,21 @@ class VectorSubs : public IRMutator { // Rebase the loop to zero and try again Expr var = Variable::make(Int(32), op->name); Stmt body = substitute(op->name, var + op->min, op->body); - Stmt transformed = For::make(op->name, 0, op->extent, for_type, op->partition_policy, op->device_api, body); + Stmt transformed = For::make(op->name, 0, simplify(op->max - op->min), for_type, op->partition_policy, op->device_api, body); return mutate(transformed); } - if (extent.type().is_vector()) { + if (max.type().is_vector()) { // We'll iterate up to the max over the lanes, but // inject an if statement inside the loop that stops // each lane from going too far. 
- extent = bounds_of_lanes(extent).max; + max = bounds_of_lanes(max).max; Expr var = Variable::make(Int(32), op->name); - body = IfThenElse::make(likely(var < op->min + op->extent), body); + body = IfThenElse::make(likely(var <= op->max), body); } + Expr extent = simplify((max - min) + 1); if (op->for_type == ForType::Vectorized) { const IntImm *extent_int = extent.as(); internal_assert(extent_int) @@ -980,7 +981,11 @@ class VectorSubs : public IRMutator { user_error << "Loop over " << op->name << " has extent " << extent << ". Can only vectorize loops over a " - << "constant extent > 1\n"; + << "constant extent > 1\n" + << "Original min: " << op->min << "\n" + << "Original max: " << op->max << "\n" + << "Mutated min: " << min << "\n" + << "Mutated max: " << max << "\n"; } vectorized_vars.push_back({op->name, min, (int)extent_int->value}); @@ -1022,12 +1027,12 @@ class VectorSubs : public IRMutator { body = mutate(body); if (min.same_as(op->min) && - extent.same_as(op->extent) && + max.same_as(op->max) && body.same_as(op->body) && for_type == op->for_type) { return op; } else { - return For::make(op->name, min, extent, for_type, op->partition_policy, op->device_api, body); + return For::make(op->name, min, max, for_type, op->partition_policy, op->device_api, body); } } } @@ -1318,7 +1323,8 @@ class VectorSubs : public IRMutator { for (int ix = vectorized_vars.size() - 1; ix >= 0; ix--) { s = For::make(vectorized_vars[ix].name, vectorized_vars[ix].min, - vectorized_vars[ix].lanes, ForType::Serial, Partition::Auto, DeviceAPI::None, s); + vectorized_vars[ix].min + vectorized_vars[ix].lanes - 1, + ForType::Serial, Partition::Auto, DeviceAPI::None, s); } return s; @@ -1581,10 +1587,11 @@ class VectorizeLoops : public IRMutator { Stmt visit(const For *for_loop) override { Stmt stmt; if (for_loop->for_type == ForType::Vectorized) { - const IntImm *extent = for_loop->extent.as(); + Expr loop_extent = simplify(for_loop->extent()); + const IntImm *extent = loop_extent.as(); if 
(!extent || extent->value <= 1) { user_error << "Loop over " << for_loop->name - << " has extent " << for_loop->extent + << " has extent " << loop_extent << ". Can only vectorize loops over a " << "constant extent > 1\n"; } diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index 499488ce8b95..7f3492684743 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -7,7 +7,7 @@ file_identifier "HLDE"; file_extension "hlpipe"; enum SerializationVersionMajor: int { - Value = 18 + Value = 21 } enum SerializationVersionMinor: int { // 0 = Unstable @@ -15,7 +15,7 @@ enum SerializationVersionMinor: int { Value = 0 } enum SerializationVersionPatch: int { - Value = 1 + Value = 0 } // from src/IR.cpp @@ -143,7 +143,7 @@ table ProducerConsumer { table For { name: string; min: Expr; - extent: Expr; + max: Expr; for_type: ForType; partition_policy: Partition; device_api: DeviceAPI; diff --git a/test/correctness/fuse_gpu_threads.cpp b/test/correctness/fuse_gpu_threads.cpp index efd690c4ef4c..5c203846f0d9 100644 --- a/test/correctness/fuse_gpu_threads.cpp +++ b/test/correctness/fuse_gpu_threads.cpp @@ -9,9 +9,9 @@ class CheckThreadExtent : public IRVisitor { if (op->for_type == ForType::GPUThread) { // Assert the min and extent to be 0 and 16 for this particular test case auto min = as_const_int(op->min); - auto extent = as_const_int(op->extent); + auto max = as_const_int(op->max); assert(min && (*min == 0)); - assert(extent && (*extent == 16)); + assert(max && (*max == 15)); } IRVisitor::visit(op); } diff --git a/test/correctness/out_constraint.cpp b/test/correctness/out_constraint.cpp index 87dfa1a70df8..60fbeb8de5a9 100644 --- a/test/correctness/out_constraint.cpp +++ b/test/correctness/out_constraint.cpp @@ -26,9 +26,9 @@ class CheckLoops : public IRVisitor { using IRVisitor::visit; void visit(const For *op) override { - std::cout << "for(" << op->name << ", " << op->min << ", " << op->extent << ")\n"; + std::cout << "for(" << op->name << ", " << op->min << ", " << op->max << 
")\n"; check_int(op->min, 0); - check_int(op->extent, size); + check_int(op->max, size - 1); ++count; IRVisitor::visit(op); } diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 052762184f46..08119357fb4b 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -1736,10 +1736,10 @@ void check_boolean() { // A for loop is also an if statement that the extent is greater than zero Stmt body = AssertStmt::make(y == z, y); Stmt loop = For::make("t", 0, x, ForType::Serial, Partition::Auto, DeviceAPI::None, body); - check(IfThenElse::make(0 < x, loop), loop); + check(IfThenElse::make(0 <= x, loop), loop); - // A for loop where the extent is exactly one is just the body - check(IfThenElse::make(x == 1, loop), IfThenElse::make(x == 1, body)); + // A for loop where the min equals the max is just the body + check(IfThenElse::make(x == 0, loop), IfThenElse::make(x == 0, body)); // Check we can learn from conditions on variables check(IfThenElse::make(x < 5, not_no_op(min(x, 17))), @@ -2419,7 +2419,7 @@ int main(int argc, char **argv) { } { - Stmt body = AssertStmt::make(x > 0, y); + Stmt body = AssertStmt::make(x >= 0, y); check(For::make("t", 0, x, ForType::Serial, Partition::Auto, DeviceAPI::None, body), Evaluate::make(0)); } From 13ba99a033c5fe328b116f6e9393fadc24ccbc4c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 3 Nov 2025 17:03:45 -0800 Subject: [PATCH 2/5] Demote loop_min/loop_max/loop_extent to regular let stmts Now they're just regular old lets placed there for the convenience of computing splits. They don't need to accurately reflect the loop, or be preserved through lowering passes, etc. loop_extent was deleted entirely. They do need to be preserved until computation bounds inference because they provide a way to compute loop bounds as a function of .min .max symbols that don't exist until after that point, so you can't move them around before then. 
This change was possible because allocation bounds inference doesn't need any assistance given loops that are min/max instead of min/extent. --- src/ApplySplit.cpp | 19 ++-- src/Bounds.cpp | 14 +-- src/BoundsInference.cpp | 5 +- src/CanonicalizeGPUVars.cpp | 35 -------- src/Func.cpp | 5 +- src/Lower.cpp | 2 +- src/RebaseLoopsToZero.cpp | 1 - src/ScheduleFunctions.cpp | 170 ++++++++++++++---------------------- src/SlidingWindow.cpp | 7 -- src/StorageFolding.cpp | 8 +- 10 files changed, 81 insertions(+), 185 deletions(-) diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index 22c3425c02a4..ddb9bc1098c5 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -22,7 +22,7 @@ vector apply_split(const Split &split, const string &prefix, Expr inner = Variable::make(Int(32), prefix + split.inner); Expr old_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max"); Expr old_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min"); - Expr old_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); + Expr old_extent = (old_max - old_min) + 1; dim_extent_alignment[split.inner] = split.factor; @@ -135,10 +135,10 @@ vector apply_split(const Split &split, const string &prefix, // Define the inner and outer in terms of the fused var Expr fused = Variable::make(Int(32), prefix + split.old_var); Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min"); + Expr inner_max = Variable::make(Int(32), prefix + split.inner + ".loop_max"); Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min"); - Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent"); - const Expr &factor = inner_extent; + const Expr &factor = (inner_max - inner_min) + 1; Expr inner = fused % factor + inner_min; Expr outer = fused / factor + outer_min; @@ -169,7 +169,6 @@ vector> compute_loop_bounds_after_split(const Split &spl // Define the bounds on the split dimensions using the bounds on the function args. 
vector> let_stmts; - Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); Expr old_var_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max"); Expr old_var_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min"); switch (split.split_type) { @@ -178,24 +177,22 @@ vector> compute_loop_bounds_after_split(const Split &spl Expr outer_extent = (old_var_max - old_var_min + split.factor) / split.factor; let_stmts.emplace_back(prefix + split.inner + ".loop_min", 0); let_stmts.emplace_back(prefix + split.inner + ".loop_max", inner_extent - 1); - let_stmts.emplace_back(prefix + split.inner + ".loop_extent", inner_extent); let_stmts.emplace_back(prefix + split.outer + ".loop_min", 0); let_stmts.emplace_back(prefix + split.outer + ".loop_max", outer_extent - 1); - let_stmts.emplace_back(prefix + split.outer + ".loop_extent", outer_extent); } break; case Split::FuseVars: { // Define bounds on the fused var using the bounds on the inner and outer - Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent"); - Expr outer_extent = Variable::make(Int(32), prefix + split.outer + ".loop_extent"); - Expr fused_extent = inner_extent * outer_extent; + Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min"); + Expr inner_max = Variable::make(Int(32), prefix + split.inner + ".loop_max"); + Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min"); + Expr outer_max = Variable::make(Int(32), prefix + split.outer + ".loop_max"); + Expr fused_extent = (inner_max - inner_min + 1) * (outer_max - outer_min + 1); let_stmts.emplace_back(prefix + split.old_var + ".loop_min", 0); let_stmts.emplace_back(prefix + split.old_var + ".loop_max", fused_extent - 1); - let_stmts.emplace_back(prefix + split.old_var + ".loop_extent", fused_extent); } break; case Split::RenameVar: let_stmts.emplace_back(prefix + split.outer + ".loop_min", old_var_min); let_stmts.emplace_back(prefix + 
split.outer + ".loop_max", old_var_max); - let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent); break; } diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 8079d7c980e4..1218eb50239a 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -2967,18 +2967,8 @@ class BoxesTouched : public IRGraphVisitor { op->max.accept(this); } - Expr min_val, max_val; - if (const Interval *in = scope.find(op->name + ".loop_min")) { - min_val = in->min; - } else { - min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min; - } - - if (const Interval *in = scope.find(op->name + ".loop_max")) { - max_val = in->max; - } else { - max_val = bounds_of_expr_in_scope(op->max, scope, func_bounds).max; - } + Expr min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min; + Expr max_val = bounds_of_expr_in_scope(op->max, scope, func_bounds).max; push_var(op->name); { diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 6a0ffc3f8d19..6e8c2f2a5241 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -114,10 +114,7 @@ class BoundsOfInnerVar : public IRVisitor { } void visit(const For *op) override { - // At this stage of lowering, loop_min and loop_max - // conveniently exist in scope. 
- Interval in(Variable::make(Int(32), op->name + ".loop_min"), - Variable::make(Int(32), op->name + ".loop_max")); + Interval in(op->min, op->max); if (op->name == var) { result = in; diff --git a/src/CanonicalizeGPUVars.cpp b/src/CanonicalizeGPUVars.cpp index 49f1e37d2945..7ca9b7c4fbf5 100644 --- a/src/CanonicalizeGPUVars.cpp +++ b/src/CanonicalizeGPUVars.cpp @@ -90,18 +90,6 @@ class CanonicalizeGPUVars : public IRMutator { return name; } - std::string canonicalize_let(const std::string &name) { - if (ends_with(name, ".loop_max")) { - return find_replacement(".loop_max", name); - } else if (ends_with(name, ".loop_min")) { - return find_replacement(".loop_min", name); - } else if (ends_with(name, ".loop_extent")) { - return find_replacement(".loop_extent", name); - } else { - return name; - } - } - Stmt visit(const For *op) override { std::string name = op->name; Expr min = mutate(op->min); @@ -145,29 +133,6 @@ class CanonicalizeGPUVars : public IRMutator { } } - Stmt visit(const LetStmt *op) override { - vector> lets; - Stmt result; - - do { - lets.emplace_back(op->name, mutate(op->value)); - result = op->body; - } while ((op = op->body.as())); - - result = mutate(result); - - for (const auto &[var, value] : reverse_view(lets)) { - std::string name = canonicalize_let(var); - if (name != var) { - Expr new_var = Variable::make(Int(32), name); - result = substitute(var, new_var, result); - } - result = LetStmt::make(name, value, result); - } - - return result; - } - Stmt visit(const IfThenElse *op) override { Expr condition = mutate(op->condition); diff --git a/src/Func.cpp b/src/Func.cpp index cf8904ec2d31..696b4353dc2b 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -708,15 +708,14 @@ pair project_rdom(const vector &dims, con for (const auto &[var, min, extent] : rdom.domain()) { add_let(bounds_projection, var + ".loop_min", min); add_let(bounds_projection, var + ".loop_max", min + extent - 1); - add_let(bounds_projection, var + ".loop_extent", extent); } // Build 
the new RDom from the bounds_projection. vector new_rvars; for (const Dim &dim : dims) { const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min")); - const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent")); - new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, new_extent}); + const Expr new_max = simplify(bounds_projection.at(dim.var + ".loop_max")); + new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, (new_max - new_min) + 1}); } ReductionDomain new_rdom{new_rvars}; new_rdom.where(rdom.predicate()); diff --git a/src/Lower.cpp b/src/Lower.cpp index 605311113681..36eb9f057228 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -236,7 +236,7 @@ void lower_impl(const vector &output_funcs, log("Lowering after uniquifying variable names:", s); debug(1) << "Simplifying...\n"; - s = simplify(s, false); // Storage folding and allocation bounds inference needs .loop_max symbols + s = simplify(s); log("Lowering after first simplification:", s); debug(1) << "Simplifying correlated differences...\n"; diff --git a/src/RebaseLoopsToZero.cpp b/src/RebaseLoopsToZero.cpp index a96030291a44..49f97126bb93 100644 --- a/src/RebaseLoopsToZero.cpp +++ b/src/RebaseLoopsToZero.cpp @@ -31,7 +31,6 @@ class RebaseLoopsToZero : public IRMutator { Stmt body = mutate(op->body); string name = op->name; if (!is_const_zero(op->min)) { - // Renaming the loop (intentionally) invalidates any .loop_min/.loop_max lets. 
name = op->name + ".rebased"; Expr loop_var = Variable::make(Int(32), name); body = LetStmt::make(op->name, loop_var + op->min, body); diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 3df28f91bbb2..69001feebf92 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -437,17 +437,18 @@ Stmt build_loop_nest( string o = prefix + Var::outermost().name(); stmt = LetStmt::make(o + ".loop_min", 0, stmt); stmt = LetStmt::make(o + ".loop_max", 0, stmt); - stmt = LetStmt::make(o + ".loop_extent", 1, stmt); } - // Define the loop mins and extents in terms of the mins and maxs produced by bounds inference + // Define the loop mins and maxes in terms of the mins and maxs produced + // by bounds inference. These are simple new_var = old_var lets, but we + // can't just substitute because there are shadowed copies of .min/.max and + // the loop_min and loop_max must be in terms of the .min/.max at *this* + // loop level. + for (const std::string &i : dims) { string var = prefix + i; Expr max = Variable::make(Int(32), var + ".max"); Expr min = Variable::make(Int(32), var + ".min"); // Inject instance name here? 
(compute instance names during lowering) - stmt = LetStmt::make(var + ".loop_extent", - (max + 1) - min, - stmt); stmt = LetStmt::make(var + ".loop_min", min, stmt); stmt = LetStmt::make(var + ".loop_max", max, stmt); } @@ -460,7 +461,6 @@ Stmt build_loop_nest( Expr rmax = Variable::make(Int(32), p + ".max"); stmt = LetStmt::make(p + ".loop_min", rmin, stmt); stmt = LetStmt::make(p + ".loop_max", rmax, stmt); - stmt = LetStmt::make(p + ".loop_extent", rmax - rmin + 1, stmt); } return stmt; @@ -1013,23 +1013,19 @@ Stmt inject_stmt(Stmt root, Stmt injected, const LoopLevel &level) { class CollectBounds : public IRVisitor { public: template - static map collect_bounds(const T &node) { + static map collect_bounds(const T &node) { CollectBounds bounds; node.accept(&bounds); return bounds.bounds; } private: - map bounds; + map bounds; using IRVisitor::visit; - void visit(const LetStmt *op) override { - if (ends_with(op->name, ".loop_min") || - ends_with(op->name, ".loop_max") || - ends_with(op->name, ".loop_extent")) { - bounds.emplace(op->name, Variable::make(Int(32), op->name)); - } + void visit(const For *op) override { + bounds.emplace(op->name, Interval{op->min, op->max}); IRVisitor::visit(op); } }; @@ -1047,71 +1043,50 @@ string fused_name(const string &var) { // The bounds of every loop exist in 'replacements' should be replaced. The // loop is also renamed by adding '.fused' in the original name before the // variable name. 
-Stmt substitute_fused_bounds(Stmt s, const map &replacements) { +Stmt substitute_fused_bounds(Stmt s, const map &replacements) { if (!s.defined() || replacements.empty()) { return s; } class SubstituteFusedBounds : public IRMutator { - const map &replacements; + const map &replacements; using IRMutator::visit; Stmt visit(const For *op) override { - const auto *min_var = op->min.as(); - const auto *max_var = op->max.as(); - if (min_var && max_var) { - Expr min_val, max_val; - { - const auto &it = replacements.find(min_var->name); - if (it != replacements.end()) { - min_val = it->second; - } - } - { - const auto &it = replacements.find(max_var->name); - if (it != replacements.end()) { - max_val = it->second; - } - } - if (!min_val.defined() || !max_val.defined()) { - return IRMutator::visit(op); - } + auto it = replacements.find(op->name); + if (it == replacements.end()) { + return IRMutator::visit(op); + } + const Interval &i = it->second; - Stmt body = mutate(op->body); + Stmt body = mutate(op->body); - string new_var = fused_name(op->name); + string new_var = fused_name(op->name); + + ForType for_type = op->for_type; + DeviceAPI device_api = op->device_api; + if (equal(i.min, i.max)) { + // This is the child loop of a fused group. The real loop of the + // fused group is the loop of the parent function of the fused + // group. This child loop is just a scheduling point, and should + // never be a device transition, so we rewrite it to be a simple + // serial loop of extent 1." + for_type = ForType::Serial; + device_api = DeviceAPI::None; + } - ForType for_type = op->for_type; - DeviceAPI device_api = op->device_api; - if (equal(min_val, max_val)) { - // This is the child loop of a fused group. The real loop of the - // fused group is the loop of the parent function of the fused - // group. This child loop is just a scheduling point, and should - // never be a device transition, so we rewrite it to be a simple - // serial loop of extent 1." 
- for_type = ForType::Serial; - device_api = DeviceAPI::None; - } + Stmt stmt = For::make(new_var, i.min, i.max, + for_type, op->partition_policy, + device_api, body); - Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), - Variable::make(Int(32), new_var + ".loop_max"), - for_type, op->partition_policy, device_api, body); - - // Add let stmts defining the bound of the renamed for-loop. - stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); - stmt = LetStmt::make(new_var + ".loop_max", max_val, stmt); - stmt = LetStmt::make(new_var + ".loop_extent", simplify((max_val - min_val) + 1), stmt); - // Replace any reference to the old loop name with the new one. - stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); - return stmt; - } else { - return IRMutator::visit(op); - } + // Replace any reference to the old loop name with the new one. + stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); + return stmt; } public: - explicit SubstituteFusedBounds(const map &r) + explicit SubstituteFusedBounds(const map &r) : replacements(r) { } } subs(replacements); @@ -1443,7 +1418,7 @@ class InjectFunctionRealization : public IRMutator { // Compute the shift factor required to align iteration of // a function stage with its fused parent loop nest. 
void compute_shift_factor(const Function &f, const string &prefix, const Definition &def, - map &bounds, map &shifts) { + map &bounds, map &shifts) { if (!def.defined()) { return; } @@ -1496,29 +1471,28 @@ class InjectFunctionRealization : public IRMutator { internal_assert(parent_var_index >= 0); string parent_var = parent_dims[parent_var_index].var; - auto it_min = bounds.find(prefix + var + ".loop_min"); - auto it_max = bounds.find(prefix + var + ".loop_max"); - internal_assert((it_min != bounds.end()) && (it_max != bounds.end())); + auto it = bounds.find(prefix + var); + internal_assert(it != bounds.end()); if (iter->second == LoopAlignStrategy::AlignStart) { - auto parent_min = bounds.find(parent_prefix + parent_var + ".loop_min"); + auto parent_min = bounds.find(parent_prefix + parent_var); internal_assert(parent_min != bounds.end()); - shift_val = parent_min->second - it_min->second; + shift_val = parent_min->second.min - it->second.min; } else { - auto parent_max = bounds.find(parent_prefix + parent_var + ".loop_max"); + auto parent_max = bounds.find(parent_prefix + parent_var); internal_assert(parent_max != bounds.end()); - shift_val = parent_max->second - it_max->second; + shift_val = parent_max->second.max - it->second.max; } internal_assert(shift_val.defined()); shifts.emplace(prefix + var, simplify(-shift_val)); - it_min->second = simplify(shift_val + it_min->second); - it_max->second = simplify(shift_val + it_max->second); + it->second.min = simplify(shift_val + it->second.min); + it->second.max = simplify(shift_val + it->second.max); } } Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update, - map &replacements, + map &replacements, vector> &add_lets, map> &aliases) { const vector &dims = def.schedule().dims(); // From inner to outer @@ -1556,9 +1530,7 @@ class InjectFunctionRealization : public IRMutator { internal_assert(dim2_idx < (int)dims_2.size()); string var = pair.func_2 + ".s" + 
std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; - replacements.emplace(var + ".loop_extent", make_const(Int(32), 1)); - replacements.emplace(var + ".loop_min", val); - replacements.emplace(var + ".loop_max", val); + replacements.emplace(var, Interval::single_point(val)); string var_fused = fused_name(var_orig); aliases[var_fused].emplace(std::move(var_orig)); @@ -1616,8 +1588,8 @@ class InjectFunctionRealization : public IRMutator { // realized in the group) with union of the bounds of the fused group. Stmt replace_parent_bound_with_union_bound(const string &func, int stage, const Definition &def, Stmt produce, - const map &bounds, - map &replacements) { + const map &bounds, + map &replacements) { if (def.schedule().fused_pairs().empty()) { return produce; @@ -1650,33 +1622,20 @@ class InjectFunctionRealization : public IRMutator { string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; - internal_assert(bounds.count(var_2 + ".loop_min")); - internal_assert(bounds.count(var_2 + ".loop_max")); - internal_assert(bounds.count(var_2 + ".loop_extent")); - Expr min_2 = bounds.find(var_2 + ".loop_min")->second; - Expr max_2 = bounds.find(var_2 + ".loop_max")->second; - Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second; - - internal_assert(bounds.count(var_1 + ".loop_min")); - internal_assert(bounds.count(var_1 + ".loop_max")); - internal_assert(bounds.count(var_1 + ".loop_extent")); - - Expr min_1, max_1; - const auto &it = replacements.find(var_1 + ".loop_min"); + + Interval i_1; + Interval i_2 = bounds.find(var_2)->second; + + const auto &it = replacements.find(var_1); if (it == replacements.end()) { - min_1 = bounds.find(var_1 + ".loop_min")->second; - max_1 = bounds.find(var_1 + ".loop_max")->second; + i_1 = bounds.find(var_1)->second; } else { - min_1 = replacements.find(var_1 + ".loop_min")->second; - max_1 = replacements.find(var_1 + ".loop_max")->second; + i_1 = it->second; } // Extent is computed from 
min/max, so we don't find() it earlier. - replacements[var_1 + ".loop_min"] = simplify(min(min_1, min_2)); - replacements[var_1 + ".loop_max"] = simplify(max(max_1, max_2)); - replacements[var_1 + ".loop_extent"] = - simplify((replacements[var_1 + ".loop_max"] + 1) - - replacements[var_1 + ".loop_min"]); + replacements[var_1] = Interval{simplify(min(i_1.min, i_2.min)), + simplify(max(i_1.max, i_2.max))}; } } @@ -1690,8 +1649,8 @@ class InjectFunctionRealization : public IRMutator { } Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, - const map &bounds) { - map replacements; + const map &bounds) { + map replacements; int stage = 0; produce = replace_parent_bound_with_union_bound(f.name(), stage++, f.definition(), produce, bounds, replacements); @@ -1828,7 +1787,7 @@ class InjectFunctionRealization : public IRMutator { // Build the loops. Stmt producer; - map replacements; + map replacements; vector> add_lets; map> aliases; @@ -2565,8 +2524,7 @@ class RemoveLoopsOverOutermost : public IRMutator { } Stmt visit(const LetStmt *op) override { - if (ends_with(op->name, ".__outermost.loop_extent") || - ends_with(op->name, ".__outermost.loop_min") || + if (ends_with(op->name, ".__outermost.loop_min") || ends_with(op->name, ".__outermost.loop_max")) { return mutate(substitute(op->name, simplify(op->value), op->body)); } else { diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 4e658d7c2b9d..3f779cb0bca2 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -794,7 +794,6 @@ class SlidingWindow : public IRMutator { Stmt body = op->body; Expr loop_min = op->min; Expr loop_max = op->max; - Expr loop_extent = Variable::make(Int(32), op->name + ".loop_extent"); list> prev_loop_mins; list> new_lets; @@ -841,11 +840,9 @@ class SlidingWindow : public IRMutator { // Update the loop body to use the adjusted loop min. 
string new_name = name + ".$n"; loop_min = Variable::make(Int(32), new_name + ".loop_min"); - loop_extent = Variable::make(Int(32), new_name + ".loop_extent"); body = substitute({ {name, Variable::make(Int(32), new_name)}, {name + ".loop_min", loop_min}, - {name + ".loop_extent", loop_extent}, }, body); body = SubstitutePrefetchVar(name, new_name).mutate(body); @@ -855,7 +852,6 @@ class SlidingWindow : public IRMutator { // The new loop interval is the new loop min to the old loop max. new_lets.emplace_front(name + ".loop_min", new_loop_min); new_lets.emplace_front(name + ".loop_min.orig", loop_min); - new_lets.emplace_front(name + ".loop_extent", (loop_max - loop_min) + 1); } if (slid_dims.size() > old_slid_dims_size) { @@ -874,9 +870,6 @@ class SlidingWindow : public IRMutator { return op; } else { Stmt result = For::make(name, loop_min, loop_max, op->for_type, op->partition_policy, op->device_api, body); - if (!new_lets.empty()) { - result = LetStmt::make(name + ".loop_max", loop_max, result); - } for (const auto &i : new_lets) { result = LetStmt::make(i.first, i.second, result); } diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index 19e064ac3781..d7e2db52e17e 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -536,8 +536,6 @@ class AttemptStorageFoldingOfFunction : public IRMutator { Box box = box_union(provided, required); Expr loop_var = Variable::make(Int(32), op->name); - Expr loop_min = Variable::make(Int(32), op->name + ".loop_min"); - Expr loop_max = Variable::make(Int(32), op->name + ".loop_max"); string dynamic_footprint; @@ -735,7 +733,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { } else { // The max of the extent over all values of the loop variable must be a constant Scope scope; - scope.push(op->name, Interval(loop_min, loop_max)); + scope.push(op->name, Interval(op->min, op->max)); Expr max_extent = find_constant_bound(extent, Direction::Upper, scope); scope.pop(op->name); @@ -825,8 +823,8 @@ class 
AttemptStorageFoldingOfFunction : public IRMutator { // On the first iteration, we need to acquire the extent of the region shared // between the producer and consumer, and we need to release it on the last // iteration. - to_acquire = select(loop_var > loop_min, to_acquire, extent); - to_release = select(loop_var < loop_max, to_release, extent); + to_acquire = select(loop_var > op->min, to_acquire, extent); + to_release = select(loop_var < op->max, to_release, extent); // We may need dynamic assertions that a positive // amount of the semaphore is acquired/released, From 4b20ae0b9b5cf29155192279d85c42ba79a40b7a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 7 Nov 2025 11:08:27 -0800 Subject: [PATCH 3/5] Fix sense of comparison in vulkan backend Also, simplify the loop extents when offloading GPU kernels and add extra outputs in the bgu makefile. --- apps/bgu/Makefile | 4 ++-- src/CodeGen_Vulkan_Dev.cpp | 2 +- src/OffloadGPULoops.cpp | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/bgu/Makefile b/apps/bgu/Makefile index 8eb687ec064a..a75b623cfcc1 100644 --- a/apps/bgu/Makefile +++ b/apps/bgu/Makefile @@ -16,11 +16,11 @@ $(GENERATOR_BIN)/bgu.generator: bgu_generator.cpp $(GENERATOR_DEPS) $(BIN)/%/bgu.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) - $< -g bgu -f bgu -o $(BIN)/$* target=$*-no_runtime + $< -g bgu -f bgu -o $(BIN)/$* -e $(GENERATOR_OUTPUTS) target=$*-no_runtime $(BIN)/%/bgu_auto_schedule.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) - $< -g bgu -f bgu_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016 + $< -g bgu -f bgu_auto_schedule -o $(BIN)/$* -e $(GENERATOR_OUTPUTS) target=$*-no_runtime autoscheduler=Mullapudi2016 $(BIN)/%/runtime.a: $(GENERATOR_BIN)/bgu.generator @mkdir -p $(@D) diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 6bd33d6bc6b8..671f923ec183 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -1756,7 +1756,7 @@ void 
CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { SpvId loop_test_type_id = builder.declare_type(Bool()); SpvId loop_test_id = builder.reserve_id(SpvResultId); builder.append(SpvFactory::load(index_type_id, loop_index_id, loop_var_id)); - builder.append(SpvFactory::integer_less_than_equal(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_uint())); + builder.append(SpvFactory::integer_less_than_equal(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_int())); builder.append(SpvFactory::conditional_branch(loop_test_id, body_block_id, merge_block_id)); } builder.leave_block(); diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index aa3e8422d3fd..e93f67e8bfef 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -15,6 +15,7 @@ #include "IRPrinter.h" #include "InjectHostDevBufferCopies.h" #include "OffloadGPULoops.h" +#include "Simplify.h" #include "Util.h" namespace Halide { @@ -55,10 +56,10 @@ class ExtractBounds : public IRVisitor { for (int i = 0; i < 3; i++) { if (ends_with(op->name, gpu_thread_name(i))) { - num_threads[i] = op->extent(); + num_threads[i] = simplify(op->extent()); } if (ends_with(op->name, gpu_block_name(i))) { - num_blocks[i] = op->extent(); + num_blocks[i] = simplify(op->extent()); } } From 2d85ad429608b7c56bfc4c0bc40396528b0a391e Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 8 Dec 2025 10:05:41 -0800 Subject: [PATCH 4/5] Give op->max + 1 a new name for clarity --- src/PartitionLoops.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index 1c77f014eb7b..1a8e001f0e1c 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -707,8 +707,8 @@ class PartitionLoops : public IRMutator { Stmt prologue = MakeSimplifications(prologue_simps).mutate(body); Stmt epilogue = MakeSimplifications(epilogue_simps).mutate(body); - bool make_prologue = !equal(prologue, simpler_body); - bool 
make_epilogue = !equal(epilogue, simpler_body); + const bool make_prologue = !equal(prologue, simpler_body); + const bool make_epilogue = !equal(epilogue, simpler_body); // Recurse on the middle section. simpler_body = mutate(simpler_body); @@ -721,10 +721,11 @@ class PartitionLoops : public IRMutator { } // Construct variables for the bounds of the simplified middle section - Expr min_steady = op->min, max_steady = op->max + 1; + const Expr original_max_plus_one = op->max + 1; + Expr min_steady = op->min, max_steady = original_max_plus_one; Expr prologue_val, epilogue_val; - string prologue_name = unique_name(op->name + ".prologue"); - string epilogue_name = unique_name(op->name + ".epilogue"); + const string prologue_name = unique_name(op->name + ".prologue"); + const string epilogue_name = unique_name(op->name + ".epilogue"); if (make_prologue) { // They'll simplify better if you put them in @@ -735,7 +736,7 @@ class PartitionLoops : public IRMutator { min_vals.push_back(op->min); prologue_val = fold_left(min_vals, Max::make); // Stop the prologue from running past the end of the loop - prologue_val = min(prologue_val, op->max + 1); + prologue_val = min(prologue_val, original_max_plus_one); // prologue_val = print(prologue_val, prologue_name); min_steady = Variable::make(Int(32), prologue_name); @@ -811,7 +812,7 @@ class PartitionLoops : public IRMutator { // epilogue_val = print(epilogue_val, op->name, "epilogue", op->min, op->max); stmt = LetStmt::make(epilogue_name, epilogue_val, stmt); } else { - epilogue_val = op->max + 1; + epilogue_val = original_max_plus_one; } if (make_prologue) { // Uncomment to include code that prints the prologue value From 2e203223a354d115e5d99eb14e3ffc5f83e2b815 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 8 Dec 2025 10:09:12 -0800 Subject: [PATCH 5/5] Update src/VectorizeLoops.cpp Co-authored-by: Alex Reinking --- src/VectorizeLoops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index e8935aca6257..2d149adbaf20 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -972,8 +972,8 @@ class VectorSubs : public IRMutator { body = IfThenElse::make(likely(var <= max), body); } - Expr extent = simplify((max - min) + 1); if (op->for_type == ForType::Vectorized) { + Expr extent = simplify((max - min) + 1); const IntImm *extent_int = extent.as(); internal_assert(extent_int) << "Vectorized for loop extent should have been rewritten to a constant\n";