diff --git a/docs/src/tutorials/raising.md b/docs/src/tutorials/raising.md
index ab7ab48486..38e39eb3f9 100644
--- a/docs/src/tutorials/raising.md
+++ b/docs/src/tutorials/raising.md
@@ -128,11 +128,13 @@ raising).
 ```
 
 This IR has a nested loop, but that won't work nicely for GPUs/TPUs. Even for CPUs, XLA
-often doens't do a great job with loops. By default, we will attempt to raise loops to a
+often doesn't do a great job with loops. We can opt in to raising loops to a
 tensor IR.
 
 ```@example raising_stablehlo
-hlo = @code_hlo compute_attractive_force(positions_ra, masses_ra, 2.0f0)
+hlo = @code_hlo compile_options=CompileOptions(;
+    disable_auto_batching_passes=false
+) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
 @assert !contains(repr(hlo), "stablehlo.while") #hide
 hlo
 ```
@@ -142,7 +144,9 @@ the values are identical.
 
 ```@example raising_stablehlo
 y_jl = compute_attractive_force(positions, masses, 2.0f0)
-y_ra = @jit compute_attractive_force(positions_ra, masses_ra, 2.0f0)
+y_ra = @jit compile_options=CompileOptions(;
+    disable_auto_batching_passes=false
+) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
 maximum(abs, Array(y_ra) .- y_jl)
 ```
 
diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl
index a96d24c8d2..5cf29833f4 100644
--- a/src/CompileOptions.jl
+++ b/src/CompileOptions.jl
@@ -212,7 +212,7 @@ function CompileOptions(;
     disable_scatter_gather_optimization_passes::Bool=false,
     disable_pad_optimization_passes::Bool=false,
     disable_licm_optimization_passes::Bool=false,
-    disable_auto_batching_passes::Bool=false,
+    disable_auto_batching_passes::Bool=true,
 )
     optimization_passes isa Bool &&
         (optimization_passes = ifelse(optimization_passes, :all, :none))
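
For reviewers, a minimal sketch of what the flipped default means at a call site, assuming `CompileOptions` is in scope as the tutorial's bare usage suggests; `scale_all` and the `Reactant.to_rarray` setup are illustrative stand-ins for the tutorial's `compute_attractive_force` example, not part of this patch:

```julia
using Reactant

# Illustrative stand-in for the tutorial's compute_attractive_force.
scale_all(x) = x .* 2.0f0

x_ra = Reactant.to_rarray(rand(Float32, 16))

# With the new default (disable_auto_batching_passes = true), a plain @jit
# compiles without the loop-raising passes:
y_plain = @jit scale_all(x_ra)

# To raise loops to a tensor IR, callers now opt in explicitly, exactly as
# the updated tutorial does:
y_raised = @jit compile_options=CompileOptions(;
    disable_auto_batching_passes=false
) scale_all(x_ra)
```

The same `compile_options` keyword works with `@code_hlo`, per the first hunk.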