Skip to content

Conversation

@avik-pal
Copy link
Collaborator

@avik-pal avik-pal commented Dec 11, 2025

clearly one of the many introduced passes is busted...

@avik-pal
Copy link
Collaborator Author

somehow it crashes the LU tests... weird, will investigate

@avik-pal
Copy link
Collaborator Author

one of the many reasons why we should set up reverse CI

@avik-pal
Copy link
Collaborator Author

the AD failures stem from us not defining batched AD for dot_general correctly

@avik-pal
Copy link
Collaborator Author

we also have an infinite compile...

@avik-pal
Copy link
Collaborator Author

avik-pal commented Dec 11, 2025

// Pasted MLIR dump: batched linear solve for 32 independent 10x10 systems.
// Pipeline: LU-factorize each matrix via LAPACK sgetrf, convert the pivot
// vectors to explicit permutations, permute the RHS, then do two triangular
// solves. (Module name is truncated in the paste: "@reactant_solve_w...".)
module @reactant_solve_w... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  // External LAPACK single-precision LU entry point; reached indirectly
  // through the @enzymexla_compile_cpu custom call below.
  llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
  // Factorizes the 32 batch matrices one at a time inside a stablehlo.while
  // loop. Returns (factored matrices, per-matrix pivot vectors, per-matrix
  // scalar — presumably the LAPACK `info` status; confirm against the runtime).
  func.func private @batched_enzymexla_lapack_sgetrf_0(%arg0: tensor<32x10x10xf32> {enzymexla.memory_effects = []}) -> (tensor<32x10x10xf32>, tensor<32x10xi64>, tensor<32xi64>) attributes {enzymexla.memory_effects = []} {
    // Loop-invariant constants: step (1), trip count (32), zero-initialized
    // result accumulators, and -1 sentinels used as pivot/info inputs.
    %c = stablehlo.constant dense<1> : tensor<i64>
    %c_0 = stablehlo.constant dense<32> : tensor<i64>
    %cst = arith.constant dense<0> : tensor<32xi64>
    %cst_1 = arith.constant dense<0> : tensor<32x10xi64>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x10x10xf32>
    %c_3 = stablehlo.constant dense<0> : tensor<i64>
    %c_4 = stablehlo.constant dense<-1> : tensor<32xi64>
    %c_5 = stablehlo.constant dense<-1> : tensor<32x10xi64>
    // Matrix dimensions (m = dim 1, n = dim 2), broadcast to one value per
    // batch element so they can be dynamically sliced inside the loop.
    %0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<32x10x10xf32>) -> tensor<i32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<i32>) -> tensor<32xi32>
    %2 = stablehlo.convert %1 : (tensor<32xi32>) -> tensor<32xi64>
    %3 = stablehlo.get_dimension_size %arg0, dim = 2 : (tensor<32x10x10xf32>) -> tensor<i32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<i32>) -> tensor<32xi32>
    %5 = stablehlo.convert %4 : (tensor<32xi32>) -> tensor<32xi64>
    // Sequential loop over the 32 batch elements; carries (counter, factored
    // matrices, pivot vectors, info scalars).
    %6:4 = stablehlo.while(%iterArg = %c_3, %iterArg_6 = %cst_2, %iterArg_7 = %cst_1, %iterArg_8 = %cst) : tensor<i64>, tensor<32x10x10xf32>, tensor<32x10xi64>, tensor<32xi64>
    cond {
      %7 = stablehlo.compare  LT, %iterArg, %c_0 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %7 : tensor<i1>
    } do {
      %7 = stablehlo.add %iterArg, %c : tensor<i64>
      // Batch index for this iteration: (%iterArg / 1) mod 32 == %iterArg.
      %8 = stablehlo.divide %iterArg, %c : tensor<i64>
      %9 = stablehlo.remainder %8, %c_0 : tensor<i64>
      // Slice out this batch element's n, matrix, m, pivot buffer, and info
      // scalar (the latter two start at the -1 sentinel constants).
      %10 = stablehlo.dynamic_slice %5, %9, sizes = [1] : (tensor<32xi64>, tensor<i64>) -> tensor<1xi64>
      %11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor<i64>
      %12 = stablehlo.dynamic_slice %arg0, %9, %c_3, %c_3, sizes = [1, 10, 10] : (tensor<32x10x10xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x10x10xf32>
      %13 = stablehlo.reshape %12 : (tensor<1x10x10xf32>) -> tensor<10x10xf32>
      %14 = stablehlo.dynamic_slice %2, %9, sizes = [1] : (tensor<32xi64>, tensor<i64>) -> tensor<1xi64>
      %15 = stablehlo.reshape %14 : (tensor<1xi64>) -> tensor<i64>
      %16 = stablehlo.dynamic_slice %c_5, %9, %c_3, sizes = [1, 10] : (tensor<32x10xi64>, tensor<i64>, tensor<i64>) -> tensor<1x10xi64>
      %17 = stablehlo.reshape %16 : (tensor<1x10xi64>) -> tensor<10xi64>
      %18 = stablehlo.dynamic_slice %c_4, %9, sizes = [1] : (tensor<32xi64>, tensor<i64>) -> tensor<1xi64>
      %19 = stablehlo.reshape %18 : (tensor<1xi64>) -> tensor<i64>
      // Invoke the CPU-compiled sgetrf kernel. output_operand_aliases mark
      // the three results as aliasing operands 2/4/5 (in-place factorization);
      // backend_config is an opaque pointer blob baked in by the compiler.
      %20:3 = stablehlo.custom_call @enzymexla_compile_cpu(%15, %11, %13, %15, %17, %19) {api_version = 3 : i32, backend_config = "\00\90CB\CFq\00\00\00\00\00\00\00\00\00\00", operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<i64>, tensor<i64>, tensor<10x10xf32>, tensor<i64>, tensor<10xi64>, tensor<i64>) -> (tensor<10x10xf32>, tensor<10xi64>, tensor<i64>)
      // Scatter this element's results back into the loop-carried accumulators.
      %21 = stablehlo.reshape %20#0 : (tensor<10x10xf32>) -> tensor<1x10x10xf32>
      %22 = stablehlo.dynamic_update_slice %iterArg_6, %21, %9, %c_3, %c_3 : (tensor<32x10x10xf32>, tensor<1x10x10xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<32x10x10xf32>
      %23 = stablehlo.reshape %20#1 : (tensor<10xi64>) -> tensor<1x10xi64>
      %24 = stablehlo.dynamic_update_slice %iterArg_7, %23, %9, %c_3 : (tensor<32x10xi64>, tensor<1x10xi64>, tensor<i64>, tensor<i64>) -> tensor<32x10xi64>
      %25 = stablehlo.reshape %20#2 : (tensor<i64>) -> tensor<1xi64>
      %26 = stablehlo.dynamic_update_slice %iterArg_8, %25, %9 : (tensor<32xi64>, tensor<1xi64>, tensor<i64>) -> tensor<32xi64>
      stablehlo.return %7, %22, %24, %26 : tensor<i64>, tensor<32x10x10xf32>, tensor<32x10xi64>, tensor<32xi64>
    }
    return %6#1, %6#2, %6#3 : tensor<32x10x10xf32>, tensor<32x10xi64>, tensor<32xi64>
  }
  // Solve A x = b for each of the 32 batch elements given A (%arg0) and
  // b (%arg1), using the batched LU factorization above.
  func.func @main(%arg0: tensor<32x10x10xf32> {enzymexla.memory_effects = []}, %arg1: tensor<32x10xf32> {enzymexla.memory_effects = []}) -> tensor<32x10xf32> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %c_0 = stablehlo.constant dense<10> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<32x10xi64>
    %c_2 = stablehlo.constant dense<0> : tensor<i32>
    %c_3 = stablehlo.constant dense<1> : tensor<32x10x1xi64>
    // Transpose each 10x10 matrix before factorization — presumably to match
    // LAPACK's column-major layout; TODO confirm.
    %0 = stablehlo.transpose %arg0, dims = [0, 2, 1] : (tensor<32x10x10xf32>) -> tensor<32x10x10xf32>
    %1:3 = call @batched_enzymexla_lapack_sgetrf_0(%0) : (tensor<32x10x10xf32>) -> (tensor<32x10x10xf32>, tensor<32x10xi64>, tensor<32xi64>)
    // Shift pivots to 0-based (%2) and build the per-batch identity
    // permutation [0..9] (%3).
    %2 = stablehlo.subtract %1#1, %c_1 : tensor<32x10xi64>
    %3 = stablehlo.iota dim = 1 : tensor<32x10xi64>
    // 10-iteration loop that appears to apply the LAPACK pivot swaps in order,
    // turning ipiv into an explicit row permutation per batch — NOTE(review):
    // each step swaps permutation entry i with entry ipiv[i] via the
    // batched gather/scatter pair below; verify against LAPACK getrf semantics.
    %4:2 = stablehlo.while(%iterArg = %c_2, %iterArg_4 = %3) : tensor<i32>, tensor<32x10xi64>
    cond {
      %15 = stablehlo.compare  LT, %iterArg, %c_0 : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %15 : tensor<i1>
    } do {
      %15 = stablehlo.add %iterArg, %c : tensor<i32>
      // %16 = pivot index for column i; %17 = current permutation entry i.
      %16 = stablehlo.dynamic_slice %2, %c_2, %iterArg, sizes = [32, 1] : (tensor<32x10xi64>, tensor<i32>, tensor<i32>) -> tensor<32x1xi64>
      %17 = stablehlo.dynamic_slice %iterArg_4, %c_2, %iterArg, sizes = [32, 1] : (tensor<32x10xi64>, tensor<i32>, tensor<i32>) -> tensor<32x1xi64>
      // perm[i] <- perm[ipiv[i]] (batched gather) ...
      %18 = "stablehlo.gather"(%iterArg_4, %16) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], operand_batching_dims = [0], start_indices_batching_dims = [0], start_index_map = [1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<32x10xi64>, tensor<32x1xi64>) -> tensor<32x1xi64>
      %19 = stablehlo.dynamic_update_slice %iterArg_4, %18, %c_2, %iterArg : (tensor<32x10xi64>, tensor<32x1xi64>, tensor<i32>, tensor<i32>) -> tensor<32x10xi64>
      %20 = stablehlo.reshape %17 : (tensor<32x1xi64>) -> tensor<32xi64>
      // ... and perm[ipiv[i]] <- old perm[i] (batched scatter), completing the swap.
      %21 = "stablehlo.scatter"(%19, %16, %20) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [1], input_batching_dims = [0], scatter_indices_batching_dims = [0], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
      ^bb0(%arg2: tensor<i64>, %arg3: tensor<i64>):
        stablehlo.return %arg3 : tensor<i64>
      }) : (tensor<32x10xi64>, tensor<32x1xi64>, tensor<32xi64>) -> tensor<32x10xi64>
      stablehlo.return %15, %21 : tensor<i32>, tensor<32x10xi64>
    }
    // Round-trip the permutation back to 1-based / i32 and then to 0-based
    // i64 gather indices (%10); %7 reshapes b to a column vector per batch.
    %5 = stablehlo.add %4#1, %c_1 : tensor<32x10xi64>
    %6 = stablehlo.convert %5 : (tensor<32x10xi64>) -> tensor<32x10xi32>
    %7 = stablehlo.reshape %arg1 : (tensor<32x10xf32>) -> tensor<32x10x1xf32>
    %8 = stablehlo.reshape %6 : (tensor<32x10xi32>) -> tensor<32x10x1xi32>
    %9 = stablehlo.convert %8 : (tensor<32x10x1xi32>) -> tensor<32x10x1xi64>
    %10 = stablehlo.subtract %9, %c_3 : tensor<32x10x1xi64>
    // Permute the rows of each batch's RHS by the computed permutation.
    %11 = "stablehlo.gather"(%7, %10) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [1], operand_batching_dims = [0], start_indices_batching_dims = [0], start_index_map = [1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1>}> : (tensor<32x10x1xf32>, tensor<32x10x1xi64>) -> tensor<32x10x1xf32>
    // Forward substitution with the unit lower-triangular L factor, then back
    // substitution with the upper-triangular U factor (both stored in %1#0).
    %12 = "stablehlo.triangular_solve"(%1#0, %11) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = true}> : (tensor<32x10x10xf32>, tensor<32x10x1xf32>) -> tensor<32x10x1xf32>
    %13 = "stablehlo.triangular_solve"(%1#0, %12) <{left_side = true, lower = false, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<32x10x10xf32>, tensor<32x10x1xf32>) -> tensor<32x10x1xf32>
    %14 = stablehlo.reshape %13 : (tensor<32x10x1xf32>) -> tensor<32x10xf32>
    return %14 : tensor<32x10xf32>
  }
}

@avik-pal avik-pal merged commit ff5647e into main Dec 12, 2025
67 of 70 checks passed
@avik-pal avik-pal deleted the ap/bump_jll2 branch December 12, 2025 04:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants