diff --git a/src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp b/src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp
index e8951ed32b..dff033b3c5 100644
--- a/src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp
+++ b/src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp
@@ -12,7 +12,11 @@
 #pragma GCC diagnostic pop
 #endif
 
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
 #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
+#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
 #include "src/enzyme_ad/jax/Utils.h"
 
 using namespace mlir;
@@ -224,11 +228,6 @@ bool WhileLoopInfo::isConstantValue(Value v, llvm::APInt &constVal) {
 
 void WhileLoopInfo::propagateAffineIndexInfo() {
   auto inductionVar = getInductionVariable();
 
-  SmallVector<Value> worklist;
-  DenseSet<Value> visited;
-
-  worklist.push_back(inductionVar);
-
   auto inductionType = inductionVar.getType();
   unsigned bitWidth = 64;
   if (auto tensorType = dyn_cast<RankedTensorType>(inductionType)) {
@@ -239,13 +238,29 @@ void WhileLoopInfo::propagateAffineIndexInfo() {
   llvm::APInt baseScaling(bitWidth, 1, true);
   llvm::APInt baseOffset(bitWidth, 0, true);
 
-  affineIndexInfo[inductionVar] = AffineIndexInfo{baseScaling, baseOffset};
+  SmallVector<Value> newPropagated;
+
+  propagateAffineIndexInfo(
+      inductionVar, AffineIndexInfo{baseScaling, baseOffset}, newPropagated);
+  return;
+}
+
+void WhileLoopInfo::propagateAffineIndexInfo(
+    Value v, AffineIndexInfo vInfo, SmallVectorImpl<Value> &newPropagated) {
+  SmallVector<Value> worklist;
+  worklist.push_back(v);
+  affineIndexInfo[v] = vInfo;
+
+  auto bitWidth = vInfo.scale.getBitWidth();
+  APInt baseScaling(bitWidth, 1, true);
+  APInt baseOffset(vInfo.offset.getBitWidth(), 0, true);
 
   while (!worklist.empty()) {
     auto cur = worklist.pop_back_val();
-    if (visited.contains(cur))
+    if (affineIndexPropagationVisited.contains(cur)) {
       continue;
-    visited.insert(cur);
+    }
+    affineIndexPropagationVisited.insert(cur);
 
     AffineIndexInfo curInfo = affineIndexInfo[cur];
 
@@ -284,14 +299,75 @@ void WhileLoopInfo::propagateAffineIndexInfo() {
     } else if (auto negOp = dyn_cast<stablehlo::NegOp>(user)) {
       newInfo = updateAffineIndexInfo(curInfo, -baseScaling, baseOffset);
       result = negOp.getResult();
+    } else if (auto reshapeOp = dyn_cast<stablehlo::ReshapeOp>(user)) {
+      if (cast<RankedTensorType>(reshapeOp.getType()).getNumElements() == 1) {
+        newInfo = updateAffineIndexInfo(curInfo, baseScaling, baseOffset);
+        result = reshapeOp.getResult();
+      }
     }
 
     if (result && !affineIndexInfo.contains(result)) {
       affineIndexInfo[result] = newInfo;
+      newPropagated.push_back(result);
       worklist.push_back(result);
     }
   }
 }
+
+  int64_t totalIterations = 0;
+  bool anyNewPropagated;
+  do {
+    anyNewPropagated = false;
+    // if any slice operand is an iota, then we can try to infer the offset
+    // and scale
+    op.getBody().front().walk([&](stablehlo::DynamicSliceOp sliceOp) {
+      // Skip if we've already processed this slice's result
+      if (affineIndexInfo.contains(sliceOp.getResult())) {
+        return WalkResult::advance();
+      }
+
+      if (cast<RankedTensorType>(sliceOp.getType()).getNumElements() != 1) {
+        return WalkResult::advance();
+      }
+
+      int64_t sliceDim = -1;
+      for (int64_t i = 0; i < sliceOp.getSliceSizes().size(); i++) {
+        if (matchPattern(sliceOp.getStartIndices()[i], m_Zero())) {
+          continue;
+        }
+        if (sliceDim != -1) {
+          return WalkResult::advance(); // can't do anything here
+        }
+        sliceDim = i;
+      }
+
+      auto iotaDetection = detectIotaLikeTensor(sliceOp.getOperand());
+
+      if (iotaDetection && sliceDim == iotaDetection.value().dimension) {
+        auto startIndex = sliceOp.getStartIndices()[sliceDim];
+        // The start index may not have affine info yet; skip it for now
+        // (a later fixed-point iteration may discover it).
+        if (!affineIndexInfo.contains(startIndex)) {
+          return WalkResult::advance();
+        }
+
+        anyNewPropagated = true;
+        auto indexInfo = affineIndexInfo[startIndex];
+        auto offset = indexInfo.offset.getSExtValue();
+        auto iotaStart = iotaDetection.value().start;
+        auto iotaScale = iotaDetection.value().scale;
+        // The slice result is: iotaScale * (indexInfo.scale * i +
+        // indexInfo.offset) + iotaStart
+        // = (iotaScale * indexInfo.scale) * i + (iotaScale *
+        // indexInfo.offset + iotaStart)
+        auto newScale = indexInfo.scale * iotaScale;
+        auto newOffset = iotaScale * offset + iotaStart;
+
+        propagateAffineIndexInfo(
+            sliceOp.getResult(),
+            WhileLoopInfo::AffineIndexInfo{
+                newScale,
+                llvm::APInt(indexInfo.offset.getBitWidth(), newOffset)},
+            newPropagated);
+      }
+
+      return WalkResult::advance();
+    });
+    totalIterations++;
+  } while (anyNewPropagated && totalIterations < 4);
 }
 
 bool WhileLoopInfo::isConstantAcrossIterations(Value v, bool checkOperands) {
diff --git a/src/enzyme_ad/jax/Implementations/WhileLoopInfo.h b/src/enzyme_ad/jax/Implementations/WhileLoopInfo.h
index 92e5a31fe7..3d5986f42e 100644
--- a/src/enzyme_ad/jax/Implementations/WhileLoopInfo.h
+++ b/src/enzyme_ad/jax/Implementations/WhileLoopInfo.h
@@ -59,6 +59,9 @@ struct WhileLoopInfo {
   Value getNumIters(OpBuilder &builder);
 
   void propagateAffineIndexInfo();
+  void propagateAffineIndexInfo(Value v, AffineIndexInfo curInfo,
+                                SmallVectorImpl<Value> &newPropagated);
+
   llvm::MapVector<Value, AffineIndexInfo> getAffineIndexInfo() {
     return affineIndexInfo;
   }
@@ -103,6 +106,7 @@ struct WhileLoopInfo {
   std::optional<llvm::APInt> constStep;
 
   llvm::MapVector<Value, AffineIndexInfo> affineIndexInfo;
+  DenseSet<Value> affineIndexPropagationVisited;
 
   void computeConstantValues();
diff --git a/src/enzyme_ad/jax/Passes/AutoBatching.cpp b/src/enzyme_ad/jax/Passes/AutoBatching.cpp
index a6c2c8900d..706bed8689 100644
--- a/src/enzyme_ad/jax/Passes/AutoBatching.cpp
+++ b/src/enzyme_ad/jax/Passes/AutoBatching.cpp
@@ -815,8 +815,9 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
 
   for (auto [value, affineIndexInfo] : affineIndexInfoMap) {
     for (auto user : value.getUsers()) {
-      if (user->getBlock() != &whileBody || seenOps.contains(user))
+      if (user->getBlock() != &whileBody || seenOps.contains(user)) {
         continue;
+      }
 
       seenOps.insert(user);
 
@@ -832,81 +833,6 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
     }
   }
 
-  bool anyOpRewritten = false;
-
-  // iota [idx] where iota starts at 0 and iter var also starts at 0
-  // replace this with idx
-  // If we do a successful rewrite here, we remove the DynamicSliceInfo from
-  // the candidateSlices vector (a later invocation will handle the rest)
-  SmallVector<DynamicSliceInfo> retainedSlices;
-  for (auto [i, slice] : llvm::enumerate(candidateSlices)) {
-    if (slice.dimensions.size() != 1) {
-      retainedSlices.push_back(slice);
-      continue;
-    }
-
-    auto iotaDetection = detectIotaLikeTensor(slice.sliceOp.getOperand());
-
-    if (iotaDetection &&
-        slice.dimensions[0] == iotaDetection.value().dimension) {
-      anyOpRewritten = true;
-
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPoint(slice.sliceOp);
-      Value newOperand = info.getInductionVariable();
-      auto sliceType =
-          cast<RankedTensorType>(slice.sliceOp.getResult().getType());
-      auto outElemType = sliceType.getElementType();
-      auto opElemType =
-          cast<RankedTensorType>(newOperand.getType()).getElementType();
-
-      // iotaTensor[i] = iotaStart + i
-      // indexing with `scale * indVar + offset`
-      // result = scale * indVar + (iotaStart + offset)
-
-      auto affineIndexInfo =
-          affineIndexInfoMap[slice.sliceOp
-                                 .getStartIndices()[slice.dimensions[0]]];
-
-      auto scalarType = RankedTensorType::get({}, opElemType);
-
-      if (!affineIndexInfo.scale.isOne()) {
-        newOperand = stablehlo::MulOp::create(
-            rewriter, slice.sliceOp.getLoc(),
-            stablehlo::ConstantOp::create(
-                rewriter, slice.sliceOp.getLoc(), scalarType,
-                cast<ElementsAttr>(makeAttr(
-                    scalarType, affineIndexInfo.scale.getSExtValue()))),
-            newOperand);
-      }
-
-      auto indexOffset = affineIndexInfo.offset.getSExtValue();
-      auto iotaStart = iotaDetection.value().start;
-      auto offset = indexOffset + iotaStart;
-
-      if (offset != 0) {
-        newOperand = stablehlo::AddOp::create(
-            rewriter, slice.sliceOp.getLoc(), newOperand,
-            stablehlo::ConstantOp::create(
-                rewriter, slice.sliceOp.getLoc(), scalarType,
-                cast<ElementsAttr>(makeAttr(scalarType, offset))));
-      }
-
-      if (opElemType != outElemType) {
-        newOperand = stablehlo::ConvertOp::create(
-                         rewriter, slice.sliceOp.getLoc(),
-                         RankedTensorType::get({}, outElemType), newOperand)
-                         .getResult();
-      }
-
-      rewriter.replaceOpWithNewOp<stablehlo::BroadcastInDimOp>(
-          slice.sliceOp, sliceType, newOperand,
-          rewriter.getDenseI64ArrayAttr({}));
-    } else {
-      retainedSlices.push_back(slice);
-    }
-  }
-  candidateSlices = std::move(retainedSlices);
-
   // Create a map of user operations to their corresponding dynamic slices
   llvm::MapVector<Operation *, SmallVector<DynamicSliceInfo>> userOpToSlicesMap;
   for (auto ds : candidateSlices) {
@@ -945,7 +871,9 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
   }
 
   if (userOpToSlicesMap.empty())
-    return anyOpRewritten ? success() : failure();
+    return failure();
+
+  bool anyOpRewritten = false;
 
   for (auto &[op, slices] : userOpToSlicesMap) {
     bool avoidBatching =
diff --git a/src/enzyme_ad/jax/Passes/StructuredTensors.cpp b/src/enzyme_ad/jax/Passes/StructuredTensors.cpp
index b4c58177bb..49cbd706f4 100644
--- a/src/enzyme_ad/jax/Passes/StructuredTensors.cpp
+++ b/src/enzyme_ad/jax/Passes/StructuredTensors.cpp
@@ -80,7 +80,8 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
   auto isIotaLikeTensor = detectIotaLikeTensor(indices);
   if (isIotaLikeTensor) {
     auto iotaLikeTensor = isIotaLikeTensor.value();
-    if (iotaLikeTensor.dimension == 0 && iotaLikeTensor.start == 0) {
+    if (iotaLikeTensor.dimension == 0 && iotaLikeTensor.start == 0 &&
+        iotaLikeTensor.scale == 1) {
       *outUpdates = updates;
       return absl::OkStatus();
     }
@@ -101,6 +102,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
   struct ChainItem {
     mlir::Operation *op;
     int64_t offset; // only populated for AddOp/SubtractOp
+    int64_t scale;  // only populated for MulOp
   };
 
   // Build a chain of operations from startOp to the base case
@@ -114,7 +116,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
 
     // check if we found a base case
    if (isa<stablehlo::IotaOp, stablehlo::ConstantOp>(currentOp)) {
-      chain.push_back({currentOp, 0});
+      chain.push_back({currentOp, 0, 1});
       break;
     }
 
@@ -125,7 +127,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
     // TODO: we might want to support broadcast_in_dim / insert_dims / drop_dims
     // as well
    if (isa<stablehlo::ReshapeOp>(currentOp)) {
-      chain.push_back({currentOp, 0});
+      chain.push_back({currentOp, 0, 1});
       nextOp = currentOp->getOperand(0).getDefiningOp();
     } else if (auto convertOp = dyn_cast<stablehlo::ConvertOp>(currentOp)) {
       // if the operand of convertOp is not an integer, then return std::nullopt
@@ -133,16 +135,16 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
              cast<RankedTensorType>(convertOp.getOperand().getType())
                  .getElementType()))
         return std::nullopt;
-      chain.push_back({currentOp, 0});
+      chain.push_back({currentOp, 0, 1});
       nextOp = convertOp.getOperand().getDefiningOp();
     } else if (auto addOp = dyn_cast<stablehlo::AddOp>(currentOp)) {
       APInt offsetVal;
       if (matchPattern(addOp.getRhs(), m_ConstantInt(&offsetVal))) {
-        chain.push_back({currentOp, offsetVal.getSExtValue()});
+        chain.push_back({currentOp, offsetVal.getSExtValue(), 1});
         nextOp = addOp.getLhs().getDefiningOp();
       } else if (matchPattern(addOp.getLhs(), m_ConstantInt(&offsetVal))) {
-        chain.push_back({currentOp, offsetVal.getSExtValue()});
+        chain.push_back({currentOp, offsetVal.getSExtValue(), 1});
         nextOp = addOp.getRhs().getDefiningOp();
       } else {
         return std::nullopt;
       }
@@ -149,11 +151,22 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
     } else if (auto subOp = dyn_cast<stablehlo::SubtractOp>(currentOp)) {
       APInt offsetVal;
       if (matchPattern(subOp.getRhs(), m_ConstantInt(&offsetVal))) {
-        chain.push_back({currentOp, -offsetVal.getSExtValue()});
+        chain.push_back({currentOp, -offsetVal.getSExtValue(), 1});
         nextOp = subOp.getLhs().getDefiningOp();
       } else {
         return std::nullopt;
       }
+    } else if (auto mulOp = dyn_cast<stablehlo::MulOp>(currentOp)) {
+      APInt scaleVal;
+      if (matchPattern(mulOp.getRhs(), m_ConstantInt(&scaleVal))) {
+        chain.push_back({currentOp, 0, scaleVal.getSExtValue()});
+        nextOp = mulOp.getLhs().getDefiningOp();
+      } else if (matchPattern(mulOp.getLhs(), m_ConstantInt(&scaleVal))) {
+        chain.push_back({currentOp, 0, scaleVal.getSExtValue()});
+        nextOp = mulOp.getRhs().getDefiningOp();
+      } else {
+        return std::nullopt;
+      }
     } else { // unsupported op
       return std::nullopt;
     }
@@ -169,7 +182,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
   if (auto iotaOp = dyn_cast<stablehlo::IotaOp>(chain.back().op)) {
     auto iotaType = cast<RankedTensorType>(iotaOp.getResult().getType());
     auto iotaDim = static_cast<int64_t>(iotaOp.getIotaDimension());
-    result = IotaLikeTensor{0, iotaType.getShape()[iotaDim], iotaDim, iotaType};
+    result = IotaLikeTensor{0, iotaDim, 1, iotaType};
   } else if (auto constantOp =
                  dyn_cast<stablehlo::ConstantOp>(chain.back().op)) {
     auto denseAttr = cast<DenseIntElementsAttr>(constantOp.getValue());
@@ -191,6 +204,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
   for (int64_t dim = 0; dim < constType.getRank(); dim++) {
     bool isIotaAlongDim = true;
     std::optional<int64_t> detectedStart;
+    std::optional<int64_t> detectedScale;
     SmallVector<int64_t> indices(constType.getRank(), 0);
 
     int64_t numElements = constType.getNumElements();
@@ -207,9 +221,18 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
 
       if (!detectedStart) {
         detectedStart = actualValue;
+      } else if (!detectedScale && indices[dim] == 1) {
+        // Detect scale from the second element along this dimension
+        detectedScale = actualValue - detectedStart.value();
+        if (detectedScale.value() == 0) {
+          // Scale of 0 means all values are the same, not an iota
+          isIotaAlongDim = false;
+          break;
+        }
       }
 
-      int64_t expectedValue = detectedStart.value() + indices[dim];
+      int64_t expectedValue =
+          detectedStart.value() + indices[dim] * detectedScale.value_or(1);
       if (actualValue != expectedValue) {
         isIotaAlongDim = false;
         break;
@@ -218,9 +241,8 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
 
     if (isIotaAlongDim && detectedStart) {
       isIotaLike = true;
-      result =
-          IotaLikeTensor{detectedStart.value(),
-                         detectedStart.value() + shape[dim], dim, constType};
+      int64_t scale = detectedScale.value_or(1);
+      result = IotaLikeTensor{detectedStart.value(), dim, scale, constType};
       break;
     }
   }
@@ -249,6 +271,10 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
     } else if (isa<stablehlo::AddOp, stablehlo::SubtractOp>(item.op)) {
       result.start += item.offset;
       continue;
+    } else if (isa<stablehlo::MulOp>(item.op)) {
+      result.start *= item.scale;
+      result.scale *= item.scale;
+      continue;
     }
 
     assert(false && "reached unreachable case...");
diff --git a/src/enzyme_ad/jax/Passes/StructuredTensors.h b/src/enzyme_ad/jax/Passes/StructuredTensors.h
index 5fff39ae9b..cdde530cf2 100644
--- a/src/enzyme_ad/jax/Passes/StructuredTensors.h
+++ b/src/enzyme_ad/jax/Passes/StructuredTensors.h
@@ -22,8 +22,8 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
 
 struct IotaLikeTensor {
   int64_t start;
-  int64_t limit;
   int64_t dimension;
+  int64_t scale = 1; // multiplicative factor applied to the iota
   mlir::RankedTensorType tensorType;
 };
diff --git a/test/lit_tests/autobatching/loop_bcast_inf_compile.mlir b/test/lit_tests/autobatching/loop_bcast_inf_compile.mlir
new file mode 100644
index 0000000000..c1faab0420
--- /dev/null
+++ b/test/lit_tests/autobatching/loop_bcast_inf_compile.mlir
@@ -0,0 +1,49 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="any(inline,enzyme-hlo-generate-td{patterns=broadcast_in_dim_simplify<16>(1024);iota_simplify<16>(1024);while_simplify<1>(1);while_deadresult;while_op_induction_replacement;greedy_while_loop_batch_fission;add_const_prop;mul_const_prop;div_const_prop;sub_const_prop},transform-interpreter,enzyme-hlo-remove-transform)" %s | FileCheck %s
+
+module {
+  func.func private @"*_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<i64>) -> (tensor<f32>, tensor<f32>, tensor<i64>) {
+    %0 = stablehlo.convert %arg1 : (tensor<i64>) -> tensor<f32>
+    %1 = stablehlo.multiply %arg0, %0 : tensor<f32>
+    return %1, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<i64>
+  }
+  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
+    return %arg0 : tensor<f32>
+  }
+  func.func private @"/_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
+    %0 = stablehlo.divide %arg0, %arg1 : tensor<f32>
+    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
+  }
+  func.func @nnorm(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+    %c = stablehlo.constant dense<0> : tensor<i64>
+    %c_0 = stablehlo.constant dense<10> : tensor<i64>
+    %c_1 = stablehlo.constant dense<1> : tensor<i64>
+    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
+    // CHECK: stablehlo.while
+    %1:2 = stablehlo.while(%iterArg = %c, %iterArg_2 = %0) : tensor<i64>, tensor<10xf32> attributes {enzyme.disable_mincut}
+     cond {
+      %3 = stablehlo.subtract %c_0, %c_1 : tensor<i64>
+      %4 = stablehlo.divide %3, %c_1 : tensor<i64>
+      %5 = stablehlo.add %4, %c_1 : tensor<i64>
+      %6 = stablehlo.compare LT, %iterArg, %5 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %6 : tensor<i1>
+    } do {
+      %3 = stablehlo.multiply %iterArg, %c_1 : tensor<i64>
+      %4 = stablehlo.add %c_1, %3 : tensor<i64>
+      %5 = stablehlo.add %iterArg, %c_1 : tensor<i64>
+      %7 = stablehlo.broadcast_in_dim %iterArg_2, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
+      // %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
+      %8 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<i64>) -> tensor<10xi64>
+      %9:3 = enzyme.batch @"*_broadcast_scalar"(%7, %8) {batch_shape = array<i64: 10>} : (tensor<10xf32>, tensor<10xi64>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xi64>)
+      %10 = enzyme.batch @identity_broadcast_scalar(%iterArg_2) {batch_shape = array<i64: 10>} : (tensor<10xf32>) -> tensor<10xf32>
+      %11 = stablehlo.reduce(%10 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<10xf32>, tensor<f32>) -> tensor<f32>
+      %12 = stablehlo.broadcast_in_dim %9#0, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
+      %13 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
+      %14 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<f32>) -> tensor<10xf32>
+      %15:3 = enzyme.batch @"/_broadcast_scalar"(%13, %14) {batch_shape = array<i64: 10>} : (tensor<10xf32>, tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
+      stablehlo.return %5, %15#0 : tensor<i64>, tensor<10xf32>
+    }
+    %2 = stablehlo.transpose %1#1, dims = [0] : (tensor<10xf32>) -> tensor<10xf32>
+    return %2 : tensor<10xf32>
+  }
+}
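
Note for reviewers (illustration only, not part of the patch): the DynamicSliceOp case added to WhileLoopInfo::propagateAffineIndexInfo rests on a single composition identity. If the sliced tensor is iota-like, t[i] = iotaStart + iotaScale * i, and the start index is affine in the induction variable, idx(iv) = scale * iv + offset, then the slice value is itself affine in iv: t[idx(iv)] = (iotaScale * scale) * iv + (iotaScale * offset + iotaStart). The standalone C++ sketch below checks that identity using plain int64_t stand-ins for the patch's APInt-based AffineIndexInfo and IotaLikeTensor; the names AffineIndex, IotaLike, and composeSliceOfIota are hypothetical and exist only in this sketch.

// compose_check.cpp -- illustrative only; mirrors the comment in the
// DynamicSliceOp case of WhileLoopInfo::propagateAffineIndexInfo with
// scalar stand-ins for the APInt-based structs.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct AffineIndex { int64_t scale, offset; }; // idx(iv) = scale * iv + offset
struct IotaLike    { int64_t start, scale; };  // t[i]    = start + scale * i

// Composition rule: slicing an iota-like tensor at an affine index yields
//   t[idx(iv)] = (iota.scale * idx.scale) * iv
//              + (iota.scale * idx.offset + iota.start)
AffineIndex composeSliceOfIota(AffineIndex idx, IotaLike iota) {
  return {iota.scale * idx.scale, iota.scale * idx.offset + iota.start};
}

int main() {
  IotaLike iota{5, 3};   // t = [5, 8, 11, 14, ...]
  AffineIndex idx{2, 1}; // idx(iv) = 2 * iv + 1
  AffineIndex composed = composeSliceOfIota(idx, iota);
  for (int64_t iv = 0; iv < 100; ++iv) {
    // Evaluate the slice directly and via the composed affine form.
    int64_t direct = iota.start + iota.scale * (idx.scale * iv + idx.offset);
    assert(direct == composed.scale * iv + composed.offset);
  }
  std::printf("t[idx(iv)] == %lld * iv + %lld\n",
              static_cast<long long>(composed.scale),
              static_cast<long long>(composed.offset));
  return 0;
}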