The most definitive way to answer this is to use ahead-of-time (AOT) compilation to print the compiled HLO for your function. For example:

print(add_scalar.lower(vec_in, 1, 2).compile().as_text())
HloModule jit__add_scalar, is_scheduled=true, entry_computation_layout={(f32[100]{0}, s32[], s32[])->f32[100]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0: f32[100], param_1.3: s32[], param_2: s32[]) -> f32[100] {
  %param_0 = f32[100]{0} parameter(0)
  %param_1.3 = s32[] parameter(1)
  %param_2 = s32[] parameter(2)
  %add.2 = s32[] add(%param_1.3, %param_2), metadata={op_name="jit(_add_scala…
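Here is a minimal, self-contained sketch of the same workflow. The body of `_add_scalar` and the input `vec_in` are hypothetical reconstructions inferred from the entry layout `(f32[100], s32[], s32[]) -> f32[100]` above, not the code from the original question. `lowered.as_text()` shows the StableHLO that JAX emits before XLA optimizes it. `compiled.as_text()` shows the optimized HLO (including any fusions) that actually runs.

```python
import jax
import jax.numpy as jnp

# Hypothetical reconstruction: a float32 vector plus two integer scalars,
# matching the entry layout (f32[100], s32[], s32[]) -> f32[100].
def _add_scalar(vec, a, b):
    return vec + (a + b)

add_scalar = jax.jit(_add_scalar)

vec_in = jnp.zeros(100, dtype=jnp.float32)

# Stage 1: trace and lower to StableHLO (before XLA optimization passes).
lowered = add_scalar.lower(vec_in, 1, 2)
stablehlo_text = lowered.as_text()

# Stage 2: compile with XLA; this text reflects fusion and other optimizations.
compiled = lowered.compile()
hlo_text = compiled.as_text()

print(stablehlo_text)
print(hlo_text)
```

Comparing the two printouts shows which transformations XLA applied. The exact compiled text depends on the backend and the JAX version.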

Participants: @jadball, @jakevdp
Answer selected by jadball
Category: Q&A