92 changes: 84 additions & 8 deletions src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp
@@ -12,7 +12,11 @@
#pragma GCC diagnostic pop
#endif

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
#include "src/enzyme_ad/jax/Utils.h"

using namespace mlir;
@@ -224,11 +228,6 @@ bool WhileLoopInfo::isConstantValue(Value v, llvm::APInt &constVal) {
void WhileLoopInfo::propagateAffineIndexInfo() {
auto inductionVar = getInductionVariable();

SmallVector<Value> worklist;
DenseSet<Value> visited;

worklist.push_back(inductionVar);

auto inductionType = inductionVar.getType();
unsigned bitWidth = 64;
if (auto tensorType = dyn_cast<RankedTensorType>(inductionType)) {
@@ -239,13 +238,29 @@ void WhileLoopInfo::propagateAffineIndexInfo() {

llvm::APInt baseScaling(bitWidth, 1, true);
llvm::APInt baseOffset(bitWidth, 0, true);
affineIndexInfo[inductionVar] = AffineIndexInfo{baseScaling, baseOffset};
SmallVector<Value> newPropagated;

propagateAffineIndexInfo(
inductionVar, AffineIndexInfo{baseScaling, baseOffset}, newPropagated);
return;
}

void WhileLoopInfo::propagateAffineIndexInfo(
Value v, AffineIndexInfo vInfo, SmallVectorImpl<Value> &newPropagated) {
SmallVector<Value> worklist;
worklist.push_back(v);
affineIndexInfo[v] = vInfo;

auto bitWidth = vInfo.scale.getBitWidth();
APInt baseScaling(vInfo.scale.getBitWidth(), 1, true);
APInt baseOffset(vInfo.offset.getBitWidth(), 0, true);

while (!worklist.empty()) {
auto cur = worklist.pop_back_val();
if (visited.contains(cur))
if (affineIndexPropagationVisited.contains(cur)) {
continue;
visited.insert(cur);
}
affineIndexPropagationVisited.insert(cur);

AffineIndexInfo curInfo = affineIndexInfo[cur];

Expand Down Expand Up @@ -284,14 +299,75 @@ void WhileLoopInfo::propagateAffineIndexInfo() {
} else if (auto negOp = dyn_cast<stablehlo::NegOp>(user)) {
newInfo = updateAffineIndexInfo(curInfo, -baseScaling, baseOffset);
result = negOp.getResult();
} else if (auto reshapeOp = dyn_cast<stablehlo::ReshapeOp>(user)) {
if (cast<ShapedType>(reshapeOp.getType()).getNumElements() == 1) {
newInfo = updateAffineIndexInfo(curInfo, baseScaling, baseOffset);
result = reshapeOp.getResult();
}
}

if (result && !affineIndexInfo.contains(result)) {
affineIndexInfo[result] = newInfo;
newPropagated.push_back(result);
worklist.push_back(result);
}
}
}

int64_t totalIterations = 0;
bool anyNewPropagated;
do {
anyNewPropagated = false;
// If a dynamic slice's operand is an iota-like tensor, we can try to infer
// the offset and scale of the slice result
op.getBody().front().walk([&](stablehlo::DynamicSliceOp sliceOp) {
// Skip if we've already processed this slice's result
if (affineIndexInfo.contains(sliceOp.getResult())) {
return WalkResult::advance();
}

if (cast<ShapedType>(sliceOp.getType()).getNumElements() != 1) {
return WalkResult::advance();
}

int64_t sliceDim = -1;
for (int64_t i = 0; i < sliceOp.getSliceSizes().size(); i++) {
if (matchPattern(sliceOp.getStartIndices()[i], m_Zero())) {
continue;
}
if (sliceDim != -1) {
return WalkResult::advance(); // can't do anything here
}
sliceDim = i;
}

auto iotaDetection = detectIotaLikeTensor(sliceOp.getOperand());

if (iotaDetection && sliceDim == iotaDetection.value().dimension) {
anyNewPropagated = true;
auto indexInfo = affineIndexInfo[sliceOp.getStartIndices()[sliceDim]];
auto offset = indexInfo.offset.getSExtValue();
auto iotaStart = iotaDetection.value().start;
auto iotaScale = iotaDetection.value().scale;
// The slice result is: iotaScale * (indexInfo.scale * i +
// indexInfo.offset) + iotaStart
// = (iotaScale * indexInfo.scale) * i + (iotaScale *
// indexInfo.offset + iotaStart)
auto newScale = indexInfo.scale * iotaScale;
auto newOffset = iotaScale * offset + iotaStart;

propagateAffineIndexInfo(
sliceOp.getResult(),
WhileLoopInfo::AffineIndexInfo{
newScale,
llvm::APInt(indexInfo.offset.getBitWidth(), newOffset)},
newPropagated);
}

return WalkResult::advance();
});
totalIterations++;
} while (anyNewPropagated && totalIterations < 4);
}

bool WhileLoopInfo::isConstantAcrossIterations(Value v, bool checkOperands) {
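For reference, the propagation above tracks each value derived from the induction variable `i` as `AffineIndexInfo{scale, offset}`, meaning `value = scale * i + offset`. A minimal standalone model of the composition rules the pass applies (plain `int64_t` standing in for its `APInt`; all names hypothetical):

```cpp
#include <cstdint>
#include <iostream>

// Toy model of the invariant tracked above: every mapped value equals
// scale * i + offset, where i is the loop induction variable.
struct Affine {
  int64_t scale;
  int64_t offset;
};

// v + c  =>  scale * i + (offset + c)
Affine addConst(Affine a, int64_t c) { return {a.scale, a.offset + c}; }

// v * c  =>  (scale * c) * i + (offset * c)
Affine mulConst(Affine a, int64_t c) { return {a.scale * c, a.offset * c}; }

// -v  =>  (-scale) * i + (-offset)
Affine negate(Affine a) { return {-a.scale, -a.offset}; }

// dynamic_slice(iotaLike, v) where iotaLike[k] = iotaScale * k + iotaStart:
// result = (iotaScale * scale) * i + (iotaScale * offset + iotaStart)
Affine sliceOfIota(Affine a, int64_t iotaStart, int64_t iotaScale) {
  return {a.scale * iotaScale, a.offset * iotaScale + iotaStart};
}

int main() {
  Affine iv{1, 0};                           // the induction variable itself
  Affine idx = addConst(mulConst(iv, 2), 3); // idx = 2*i + 3
  Affine val = sliceOfIota(idx, 5, 4);       // iotaLike[k] = 4*k + 5
  std::cout << val.scale << "*i + " << val.offset << "\n"; // 8*i + 17
}
```

The dynamic-slice rule is what the new fixed-point loop exploits: once a start index is known to be affine in `i`, a single-element slice of an iota-like tensor is affine in `i` as well, which can seed further propagation.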
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/Implementations/WhileLoopInfo.h
@@ -59,6 +59,9 @@ struct WhileLoopInfo {
Value getNumIters(OpBuilder &builder);

void propagateAffineIndexInfo();
void propagateAffineIndexInfo(Value v, AffineIndexInfo curInfo,
SmallVectorImpl<Value> &newPropagated);

llvm::MapVector<Value, AffineIndexInfo> getAffineIndexInfo() {
return affineIndexInfo;
}
@@ -103,6 +106,7 @@ struct WhileLoopInfo {
std::optional<int64_t> constStep;

llvm::MapVector<Value, AffineIndexInfo> affineIndexInfo;
DenseSet<Value> affineIndexPropagationVisited;

void computeConstantValues();

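The new `affineIndexPropagationVisited` member (replacing the old function-local set) is what lets the seeded overload be called repeatedly from the fixed-point loop without re-walking users an earlier call already handled. A toy sketch of that interplay, not the MLIR implementation:

```cpp
#include <cstdint>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Toy model: the visited set persists across propagation calls, so seeding
// from a new value never re-walks previously handled users.
struct Propagator {
  std::map<int, std::pair<int64_t, int64_t>> info; // value -> {scale, offset}
  std::set<int> visited;                           // persists across calls
  std::map<int, std::vector<int>> users;           // toy use-def edges

  void propagate(int v, int64_t scale, int64_t offset) {
    std::vector<int> worklist{v};
    info[v] = {scale, offset};
    while (!worklist.empty()) {
      int cur = worklist.back();
      worklist.pop_back();
      if (!visited.insert(cur).second)
        continue; // handled by this or a previous call
      for (int user : users[cur]) {
        if (!info.count(user)) {
          info[user] = info[cur]; // stand-in for the per-op update rules
          worklist.push_back(user);
        }
      }
    }
  }
};

int main() {
  Propagator p;
  p.users = {{1, {2}}, {2, {3}}};
  p.propagate(1, 1, 0);  // seeds the induction variable; walks 1 -> 2 -> 3
  p.propagate(3, 8, 17); // later seed: 3 already visited, nothing re-walked
}
```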
82 changes: 5 additions & 77 deletions src/enzyme_ad/jax/Passes/AutoBatching.cpp
@@ -815,8 +815,9 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(

for (auto [value, affineIndexInfo] : affineIndexInfoMap) {
for (auto user : value.getUsers()) {
if (user->getBlock() != &whileBody || seenOps.contains(user))
if (user->getBlock() != &whileBody || seenOps.contains(user)) {
continue;
}

seenOps.insert(user);

@@ -832,81 +833,6 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
}
}

bool anyOpRewritten = false;

// iota [idx] where iota starts at 0 and iter var also starts at 0
// replace this with idx
// If we do a successful rewrite here, we remove the DynamicSliceInfo from
// the candidateSlices vector (a later invocation will handle the rest)
SmallVector<DynamicSliceInfo> retainedSlices;
for (auto [i, slice] : llvm::enumerate(candidateSlices)) {
if (slice.dimensions.size() != 1) {
retainedSlices.push_back(slice);
continue;
}

auto iotaDetection = detectIotaLikeTensor(slice.sliceOp.getOperand());

if (iotaDetection &&
slice.dimensions[0] == iotaDetection.value().dimension) {
anyOpRewritten = true;

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(slice.sliceOp);
Value newOperand = info.getInductionVariable();
auto sliceType =
cast<RankedTensorType>(slice.sliceOp.getResult().getType());
auto outElemType = sliceType.getElementType();
auto opElemType = cast<TensorType>(newOperand.getType()).getElementType();

// iotaTensor[i] = iotaStart + i
// indexing with `scale * indVar + offset`
// result = scale * indVar + (iotaStart + offset)

auto affineIndexInfo =
affineIndexInfoMap[slice.sliceOp
.getStartIndices()[slice.dimensions[0]]];

auto scalarType = RankedTensorType::get({}, opElemType);

if (!affineIndexInfo.scale.isOne()) {
newOperand = stablehlo::MulOp::create(
rewriter, slice.sliceOp.getLoc(),
stablehlo::ConstantOp::create(
rewriter, slice.sliceOp.getLoc(), scalarType,
cast<ElementsAttr>(makeAttr(
scalarType, affineIndexInfo.scale.getSExtValue()))),
newOperand);
}

auto indexOffset = affineIndexInfo.offset.getSExtValue();
auto iotaStart = iotaDetection.value().start;
auto offset = indexOffset + iotaStart;

if (offset != 0) {
newOperand = stablehlo::AddOp::create(
rewriter, slice.sliceOp.getLoc(), newOperand,
stablehlo::ConstantOp::create(
rewriter, slice.sliceOp.getLoc(), scalarType,
cast<ElementsAttr>(makeAttr(scalarType, offset))));
}

if (opElemType != outElemType) {
newOperand = stablehlo::ConvertOp::create(
rewriter, slice.sliceOp.getLoc(),
RankedTensorType::get({}, outElemType), newOperand)
.getResult();
}

rewriter.replaceOpWithNewOp<stablehlo::BroadcastInDimOp>(
slice.sliceOp, sliceType, newOperand,
rewriter.getDenseI64ArrayAttr({}));
} else {
retainedSlices.push_back(slice);
}
}
candidateSlices = std::move(retainedSlices);

// Create a map of user operations to their corresponding dynamic slices
llvm::MapVector<Operation *, SmallVector<DynamicSliceInfo>> userOpToSlicesMap;
for (auto ds : candidateSlices) {
@@ -945,7 +871,9 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
}

if (userOpToSlicesMap.empty())
return anyOpRewritten ? success() : failure();
return failure();

bool anyOpRewritten = false;

for (auto &[op, slices] : userOpToSlicesMap) {
bool avoidBatching =
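The deleted block above performed the iota-slice rewrite inside the fission pattern itself; with this PR the same information falls out of `WhileLoopInfo::propagateAffineIndexInfo`. A standalone check, with assumed concrete parameters, that replacing `iota[scale * i + offset]` by arithmetic on the induction variable is value-preserving:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t iotaStart = 5, iotaScale = 4; // iota-like: iota[k] = 4*k + 5
  std::vector<int64_t> iota(64);
  for (int64_t k = 0; k < 64; ++k)
    iota[k] = iotaScale * k + iotaStart;

  for (int64_t i = 0; i < 30; ++i) {
    int64_t idx = 2 * i + 3; // affine index: scale = 2, offset = 3
    // Direct form: materialize the iota tensor and slice it.
    int64_t direct = iota[idx];
    // Rewritten form: pure arithmetic on the induction variable, as the
    // (removed) MulOp/AddOp sequence emitted.
    int64_t rewritten = (iotaScale * 2) * i + (iotaScale * 3 + iotaStart);
    assert(direct == rewritten);
  }
  return 0;
}
```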
50 changes: 38 additions & 12 deletions src/enzyme_ad/jax/Passes/StructuredTensors.cpp
@@ -80,7 +80,8 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
auto isIotaLikeTensor = detectIotaLikeTensor(indices);
if (isIotaLikeTensor) {
auto iotaLikeTensor = isIotaLikeTensor.value();
if (iotaLikeTensor.dimension == 0 && iotaLikeTensor.start == 0) {
if (iotaLikeTensor.dimension == 0 && iotaLikeTensor.start == 0 &&
iotaLikeTensor.scale == 1) {
*outUpdates = updates;
return absl::OkStatus();
}
@@ -101,6 +102,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
struct ChainItem {
mlir::Operation *op;
int64_t offset; // only populated for AddOp/SubtractOp
int64_t scale; // only populated for MulOp
};

// Build a chain of operations from startOp to the base case
@@ -114,7 +116,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {

// check if we found a base case
if (isa<stablehlo::IotaOp, stablehlo::ConstantOp>(currentOp)) {
chain.push_back({currentOp, 0});
chain.push_back({currentOp, 0, 1});
break;
}

@@ -125,35 +127,46 @@
// TODO: we might want to support broadcast_in_dim / insert_dims / drop_dims
// as well
if (isa<stablehlo::TransposeOp>(currentOp)) {
chain.push_back({currentOp, 0});
chain.push_back({currentOp, 0, 1});
nextOp = currentOp->getOperand(0).getDefiningOp();
} else if (auto convertOp = dyn_cast<stablehlo::ConvertOp>(currentOp)) {
// If the operand of convertOp is not an integer, return std::nullopt
if (!isa<mlir::IntegerType>(
cast<TensorType>(convertOp.getOperand().getType())
.getElementType()))
return std::nullopt;
chain.push_back({currentOp, 0});
chain.push_back({currentOp, 0, 1});
nextOp = convertOp.getOperand().getDefiningOp();
} else if (auto addOp = dyn_cast<stablehlo::AddOp>(currentOp)) {
APInt offsetVal;
if (matchPattern(addOp.getRhs(), m_ConstantInt(&offsetVal))) {
chain.push_back({currentOp, offsetVal.getSExtValue()});
chain.push_back({currentOp, offsetVal.getSExtValue(), 1});
nextOp = addOp.getLhs().getDefiningOp();
} else if (matchPattern(addOp.getLhs(), m_ConstantInt(&offsetVal))) {
chain.push_back({currentOp, offsetVal.getSExtValue()});
chain.push_back({currentOp, offsetVal.getSExtValue(), 1});
nextOp = addOp.getRhs().getDefiningOp();
} else {
return std::nullopt;
}
} else if (auto subOp = dyn_cast<stablehlo::SubtractOp>(currentOp)) {
APInt offsetVal;
if (matchPattern(subOp.getRhs(), m_ConstantInt(&offsetVal))) {
chain.push_back({currentOp, -offsetVal.getSExtValue()});
chain.push_back({currentOp, -offsetVal.getSExtValue(), 1});
nextOp = subOp.getLhs().getDefiningOp();
} else {
return std::nullopt;
}
} else if (auto mulOp = dyn_cast<stablehlo::MulOp>(currentOp)) {
APInt scaleVal;
if (matchPattern(mulOp.getRhs(), m_ConstantInt(&scaleVal))) {
chain.push_back({currentOp, 0, scaleVal.getSExtValue()});
nextOp = mulOp.getLhs().getDefiningOp();
} else if (matchPattern(mulOp.getLhs(), m_ConstantInt(&scaleVal))) {
chain.push_back({currentOp, 0, scaleVal.getSExtValue()});
nextOp = mulOp.getRhs().getDefiningOp();
} else {
return std::nullopt;
}
} else { // unsupported op
return std::nullopt;
}
@@ -169,7 +182,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
if (auto iotaOp = dyn_cast<stablehlo::IotaOp>(chain.back().op)) {
auto iotaType = cast<RankedTensorType>(iotaOp.getResult().getType());
auto iotaDim = static_cast<int64_t>(iotaOp.getIotaDimension());
result = IotaLikeTensor{0, iotaType.getShape()[iotaDim], iotaDim, iotaType};
result = IotaLikeTensor{0, iotaDim, 1, iotaType};
} else if (auto constantOp =
dyn_cast<stablehlo::ConstantOp>(chain.back().op)) {
auto denseAttr = cast<DenseElementsAttr>(constantOp.getValue());
@@ -191,6 +204,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
for (int64_t dim = 0; dim < constType.getRank(); dim++) {
bool isIotaAlongDim = true;
std::optional<int64_t> detectedStart;
std::optional<int64_t> detectedScale;

SmallVector<int64_t> indices(constType.getRank(), 0);
int64_t numElements = constType.getNumElements();
@@ -207,9 +221,18 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {

if (!detectedStart) {
detectedStart = actualValue;
} else if (!detectedScale && indices[dim] == 1) {
// Detect scale from the second element along this dimension
detectedScale = actualValue - detectedStart.value();
if (detectedScale.value() == 0) {
// Scale of 0 means all values are the same, not an iota
isIotaAlongDim = false;
break;
}
}

int64_t expectedValue = detectedStart.value() + indices[dim];
int64_t expectedValue =
detectedStart.value() + indices[dim] * detectedScale.value_or(1);
if (actualValue != expectedValue) {
isIotaAlongDim = false;
break;
@@ -218,9 +241,8 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {

if (isIotaAlongDim && detectedStart) {
isIotaLike = true;
result =
IotaLikeTensor{detectedStart.value(),
detectedStart.value() + shape[dim], dim, constType};
int64_t scale = detectedScale.value_or(1);
result = IotaLikeTensor{detectedStart.value(), dim, scale, constType};
break;
}
}
Expand Down Expand Up @@ -249,6 +271,10 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
} else if (isa<stablehlo::AddOp, stablehlo::SubtractOp>(item.op)) {
result.start += item.offset;
continue;
} else if (isa<stablehlo::MulOp>(item.op)) {
result.start *= item.scale;
result.scale *= item.scale;
continue;
}

assert(false && "reached unreachable case...");
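With the new `scale` field, the dense-constant branch above accepts any affine sequence along a dimension: `element[k] = start + scale * k` with `scale != 0`, where the scale is inferred from the second element and verified against the rest. A 1-D sketch of that check (hypothetical standalone version, not the MLIR code):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct IotaLike {
  int64_t start;
  int64_t scale;
};

// Hypothetical 1-D analogue of the dense-constant branch: accept the vector
// iff values[k] == start + scale * k for some nonzero scale.
std::optional<IotaLike> detectIotaLike1D(const std::vector<int64_t> &values) {
  if (values.size() < 2)
    return std::nullopt;
  int64_t start = values[0];
  int64_t scale = values[1] - start; // inferred from the second element
  if (scale == 0)
    return std::nullopt; // all-equal values are constant, not iota-like
  for (size_t k = 0; k < values.size(); ++k)
    if (values[k] != start + static_cast<int64_t>(k) * scale)
      return std::nullopt;
  return IotaLike{start, scale};
}

int main() {
  if (auto r = detectIotaLike1D({3, 5, 7, 9})) // start = 3, scale = 2
    std::cout << r->start << " + " << r->scale << "*k\n";
}
```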