Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/bgu/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ $(GENERATOR_BIN)/bgu.generator: bgu_generator.cpp $(GENERATOR_DEPS)

$(BIN)/%/bgu.a: $(GENERATOR_BIN)/bgu.generator
@mkdir -p $(@D)
$< -g bgu -f bgu -o $(BIN)/$* target=$*-no_runtime
$< -g bgu -f bgu -o $(BIN)/$* -e $(GENERATOR_OUTPUTS) target=$*-no_runtime

$(BIN)/%/bgu_auto_schedule.a: $(GENERATOR_BIN)/bgu.generator
@mkdir -p $(@D)
$< -g bgu -f bgu_auto_schedule -o $(BIN)/$* target=$*-no_runtime autoscheduler=Mullapudi2016
$< -g bgu -f bgu_auto_schedule -o $(BIN)/$* -e $(GENERATOR_OUTPUTS) target=$*-no_runtime autoscheduler=Mullapudi2016

$(BIN)/%/runtime.a: $(GENERATOR_BIN)/bgu.generator
@mkdir -p $(@D)
Expand Down
2 changes: 1 addition & 1 deletion src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class FindBuffers : public IRGraphVisitor {

void visit(const For *op) override {
op->min.accept(this);
op->extent.accept(this);
op->max.accept(this);
bool old = in_device_loop;
if (op->device_api != DeviceAPI::None &&
op->device_api != DeviceAPI::Host) {
Expand Down
19 changes: 8 additions & 11 deletions src/ApplySplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
Expr inner = Variable::make(Int(32), prefix + split.inner);
Expr old_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max");
Expr old_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min");
Expr old_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent");
Expr old_extent = (old_max - old_min) + 1;

dim_extent_alignment[split.inner] = split.factor;

Expand Down Expand Up @@ -135,10 +135,10 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
// Define the inner and outer in terms of the fused var
Expr fused = Variable::make(Int(32), prefix + split.old_var);
Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min");
Expr inner_max = Variable::make(Int(32), prefix + split.inner + ".loop_max");
Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min");
Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent");

const Expr &factor = inner_extent;
const Expr &factor = (inner_max - inner_min) + 1;
Expr inner = fused % factor + inner_min;
Expr outer = fused / factor + outer_min;

Expand Down Expand Up @@ -169,7 +169,6 @@ vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &spl
// Define the bounds on the split dimensions using the bounds on the function args.
vector<std::pair<string, Expr>> let_stmts;

Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent");
Expr old_var_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max");
Expr old_var_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min");
switch (split.split_type) {
Expand All @@ -178,24 +177,22 @@ vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &spl
Expr outer_extent = (old_var_max - old_var_min + split.factor) / split.factor;
let_stmts.emplace_back(prefix + split.inner + ".loop_min", 0);
let_stmts.emplace_back(prefix + split.inner + ".loop_max", inner_extent - 1);
let_stmts.emplace_back(prefix + split.inner + ".loop_extent", inner_extent);
let_stmts.emplace_back(prefix + split.outer + ".loop_min", 0);
let_stmts.emplace_back(prefix + split.outer + ".loop_max", outer_extent - 1);
let_stmts.emplace_back(prefix + split.outer + ".loop_extent", outer_extent);
} break;
case Split::FuseVars: {
// Define bounds on the fused var using the bounds on the inner and outer
Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent");
Expr outer_extent = Variable::make(Int(32), prefix + split.outer + ".loop_extent");
Expr fused_extent = inner_extent * outer_extent;
Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min");
Expr inner_max = Variable::make(Int(32), prefix + split.inner + ".loop_max");
Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min");
Expr outer_max = Variable::make(Int(32), prefix + split.outer + ".loop_max");
Expr fused_extent = (inner_max - inner_min + 1) * (outer_max - outer_min + 1);
let_stmts.emplace_back(prefix + split.old_var + ".loop_min", 0);
let_stmts.emplace_back(prefix + split.old_var + ".loop_max", fused_extent - 1);
let_stmts.emplace_back(prefix + split.old_var + ".loop_extent", fused_extent);
} break;
case Split::RenameVar:
let_stmts.emplace_back(prefix + split.outer + ".loop_min", old_var_min);
let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max);
let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent);
break;
}

Expand Down
12 changes: 5 additions & 7 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NoOpCollapsingMutator : public IRMutator {
if (is_no_op(body)) {
return body;
} else {
return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body);
return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body);
}
}

Expand Down Expand Up @@ -752,11 +752,10 @@ class InjectRingBuffering : public IRMutator {

struct Loop {
std::string name;
Expr min;
Expr extent;

Loop(std::string n, Expr m, Expr e)
: name(std::move(n)), min(std::move(m)), extent(std::move(e)) {
Loop(std::string n, Expr e)
: name(std::move(n)), extent(std::move(e)) {
}
};

Expand All @@ -778,8 +777,7 @@ class InjectRingBuffering : public IRMutator {
int loop_index = hoist_storage_loop_index[op->name] + 1;
Expr current_index = Variable::make(Int(32), loops[loop_index].name);
while (++loop_index < (int)loops.size()) {
current_index = current_index *
(loops[loop_index].extent - loops[loop_index].min) +
current_index = current_index * loops[loop_index].extent +
Variable::make(Int(32), loops[loop_index].name);
}
current_index = current_index % f.schedule().ring_buffer();
Expand Down Expand Up @@ -817,7 +815,7 @@ class InjectRingBuffering : public IRMutator {
}

Stmt visit(const For *op) override {
loops.emplace_back(op->name, op->min, op->extent);
loops.emplace_back(op->name, op->extent());
Stmt mutated = IRMutator::visit(op);
loops.pop_back();
return mutated;
Expand Down
12 changes: 6 additions & 6 deletions src/BoundConstantExtentLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ class BoundLoops : public IRMutator {
}

Stmt visit(const For *op) override {
if (is_const(op->extent)) {
Expr extent = simplify(op->extent());
if (is_const(extent)) {
// Nothing needs to be done
return IRMutator::visit(op);
}

if (op->for_type == ForType::Unrolled ||
op->for_type == ForType::Vectorized) {
// Give it one last chance to simplify to an int
Expr extent = simplify(op->extent);
Stmt body = op->body;
const IntImm *e = extent.as<IntImm>();

Expand Down Expand Up @@ -82,8 +82,8 @@ class BoundLoops : public IRMutator {
if (extent_upper.defined()) {
e = extent_upper.as<IntImm>();
body =
IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) <
op->min + op->extent),
IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) <=
op->max),
body);
}
}
Expand All @@ -93,7 +93,7 @@ class BoundLoops : public IRMutator {
// to a serial loop.
user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n";
body = mutate(body);
return For::make(op->name, op->min, op->extent,
return For::make(op->name, op->min, op->max,
ForType::Serial, op->partition_policy, op->device_api, std::move(body));
}

Expand All @@ -103,7 +103,7 @@ class BoundLoops : public IRMutator {
<< "Loop over " << op->name << " has extent " << extent << ".\n";
body = mutate(body);

return For::make(op->name, op->min, e,
return For::make(op->name, op->min, (op->min + e) - 1,
op->for_type, op->partition_policy, op->device_api, std::move(body));
} else {
return IRMutator::visit(op);
Expand Down
2 changes: 1 addition & 1 deletion src/BoundSmallAllocations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class BoundSmallAllocations : public IRMutator {

Stmt visit(const For *op) override {
Interval min_bounds = find_constant_bounds(op->min, scope);
Interval max_bounds = find_constant_bounds(op->min + op->extent - 1, scope);
Interval max_bounds = find_constant_bounds(op->max, scope);
Interval b = Interval::make_union(min_bounds, max_bounds);
b.min = simplify(b.min);
b.max = simplify(b.max);
Expand Down
20 changes: 4 additions & 16 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2964,23 +2964,11 @@ class BoxesTouched : public IRGraphVisitor {
TRACK_BOXES_TOUCHED_INFO("var:", op->name);
if (consider_calls) {
op->min.accept(this);
op->extent.accept(this);
op->max.accept(this);
}

Expr min_val, max_val;
if (const Interval *in = scope.find(op->name + ".loop_min")) {
min_val = in->min;
} else {
min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min;
}

if (const Interval *in = scope.find(op->name + ".loop_max")) {
max_val = in->max;
} else {
max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max;
max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max;
max_val -= 1;
}
Expr min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min;
Expr max_val = bounds_of_expr_in_scope(op->max, scope, func_bounds).max;

push_var(op->name);
{
Expand Down Expand Up @@ -3819,7 +3807,7 @@ void bounds_test() {
Buffer<int32_t> in(10);
in.set_name("input");

Stmt loop = For::make("x", 3, 10, ForType::Serial, Partition::Auto, DeviceAPI::Host,
Stmt loop = For::make("x", 3, 12, ForType::Serial, Partition::Auto, DeviceAPI::Host,
Provide::make("output",
{Add::make(Call::make(in, input_site_1),
Call::make(in, input_site_2))},
Expand Down
9 changes: 3 additions & 6 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ class BoundsOfInnerVar : public IRVisitor {
}

void visit(const For *op) override {
// At this stage of lowering, loop_min and loop_max
// conveniently exist in scope.
Interval in(Variable::make(Int(32), op->name + ".loop_min"),
Variable::make(Int(32), op->name + ".loop_max"));
Interval in(op->min, op->max);

if (op->name == var) {
result = in;
Expand Down Expand Up @@ -1308,7 +1305,7 @@ class BoundsInference : public IRMutator {
}
}

return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, body);
return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, body);
}

Scope<> let_vars_in_scope;
Expand Down Expand Up @@ -1392,7 +1389,7 @@ Stmt bounds_inference(Stmt s,
s = Block::make(Evaluate::make(marker), s);

// Add a synthetic outermost loop to act as 'root'.
s = For::make("<outermost>", 0, 1, ForType::Serial, Partition::Never, DeviceAPI::None, s);
s = For::make("<outermost>", 0, 0, ForType::Serial, Partition::Never, DeviceAPI::None, s);

s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
outputs, func_bounds, target)
Expand Down
43 changes: 4 additions & 39 deletions src/CanonicalizeGPUVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,10 @@ class CanonicalizeGPUVars : public IRMutator {
return name;
}

std::string canonicalize_let(const std::string &name) {
if (ends_with(name, ".loop_max")) {
return find_replacement(".loop_max", name);
} else if (ends_with(name, ".loop_min")) {
return find_replacement(".loop_min", name);
} else if (ends_with(name, ".loop_extent")) {
return find_replacement(".loop_extent", name);
} else {
return name;
}
}

Stmt visit(const For *op) override {
std::string name = op->name;
Expr min = mutate(op->min);
Expr extent = mutate(op->extent);
Expr max = mutate(op->max);
Stmt body = mutate(op->body);

if ((op->for_type == ForType::GPUBlock) ||
Expand All @@ -130,44 +118,21 @@ class CanonicalizeGPUVars : public IRMutator {
gpu_vars.emplace(op->name, name);
Expr new_var = Variable::make(Int(32), name);
min = substitute(op->name, new_var, min);
extent = substitute(op->name, new_var, extent);
max = substitute(op->name, new_var, max);
body = substitute(op->name, new_var, body);
}
}

if ((name == op->name) &&
min.same_as(op->min) &&
extent.same_as(op->extent) &&
max.same_as(op->max) &&
body.same_as(op->body)) {
return op;
} else {
return For::make(name, min, extent, op->for_type, op->partition_policy, op->device_api, body);
return For::make(name, min, max, op->for_type, op->partition_policy, op->device_api, body);
}
}

Stmt visit(const LetStmt *op) override {
vector<std::pair<std::string, Expr>> lets;
Stmt result;

do {
lets.emplace_back(op->name, mutate(op->value));
result = op->body;
} while ((op = op->body.as<LetStmt>()));

result = mutate(result);

for (const auto &[var, value] : reverse_view(lets)) {
std::string name = canonicalize_let(var);
if (name != var) {
Expr new_var = Variable::make(Int(32), name);
result = substitute(var, new_var, result);
}
result = LetStmt::make(name, value, result);
}

return result;
}

Stmt visit(const IfThenElse *op) override {
Expr condition = mutate(op->condition);

Expand Down
2 changes: 1 addition & 1 deletion src/Closure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void Closure::visit(const LetStmt *op) {
void Closure::visit(const For *op) {
ScopedBinding<> p(ignore, op->name);
op->min.accept(this);
op->extent.accept(this);
op->max.accept(this);
op->body.accept(this);
}

Expand Down
5 changes: 2 additions & 3 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2223,7 +2223,7 @@ void CodeGen_C::visit(const Atomic *op) {

void CodeGen_C::visit(const For *op) {
string id_min = print_expr(op->min);
string id_extent = print_expr(op->extent);
string id_max = print_expr(op->max);

if (op->for_type == ForType::Parallel) {
stream << get_indent() << "#pragma omp parallel for\n";
Expand All @@ -2237,8 +2237,7 @@ void CodeGen_C::visit(const For *op) {
<< " = " << id_min
<< "; "
<< print_name(op->name)
<< " < " << id_min
<< " + " << id_extent
<< " <= " << id_max
<< "; "
<< print_name(op->name)
<< "++)\n";
Expand Down
3 changes: 2 additions & 1 deletion src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,8 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
// generation time, emit code such that it can be patched some point
// later when calling D3DCompile() / halide_d3d12compute_run()
numthreads[index] = 0; // <-- 0 indicates 'undetermined'
const IntImm *int_limit = loop->extent.as<IntImm>();
Expr extent = simplify(loop->extent());
const IntImm *int_limit = extent.as<IntImm>();
if (nullptr != int_limit) {
numthreads[index] = int_limit->value;
user_assert(numthreads[index] > 0) << "For D3D12Compute, 'numthreads[" << index << "]' values must be greater than zero.\n";
Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class InjectHVXLocks : public IRMutator {
if (uses_hvx) {
body = acquire_hvx_context(body, target);
body = substitute("uses_hvx", true, body);
Stmt new_for = For::make(op->name, op->min, op->extent, op->for_type,
Stmt new_for = For::make(op->name, op->min, op->max, op->for_type,
op->partition_policy, op->device_api, body);
Stmt prolog =
IfThenElse::make(uses_hvx_var, call_halide_qurt_hvx_unlock());
Expand Down Expand Up @@ -408,7 +408,7 @@ class InjectHVXLocks : public IRMutator {
// vector code
// halide_qurt_unlock
// }
s = For::make(op->name, op->min, op->extent, op->for_type,
s = For::make(op->name, op->min, op->max, op->for_type,
op->partition_policy, op->device_api, body);
}

Expand Down
Loading
Loading