Skip to content

Commit 7440dbd

Browse files
committed
call specialized method instance when encountering unspecialized sparams
In some instances, the preferred compilation signature will require sparams to be provided at runtime. When we build the cache around these, we need to make sure the method instance we are calling has those values computed for the current signature, rather than using the widened signature. But we can still compile for the widened signature; we just need to make sure we create a cache entry for every narrower call signature. Fix #47476
1 parent d1b9c38 commit 7440dbd

File tree

3 files changed

+82
-17
lines changed

3 files changed

+82
-17
lines changed

src/gf.c

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -826,15 +826,6 @@ static void jl_compilation_sig(
826826
jl_svecset(limited, i, lastdeclt);
827827
}
828828
*newparams = limited;
829-
// now there is a problem: the widened signature is more
830-
// general than just the given arguments, so it might conflict
831-
// with another definition that doesn't have cache instances yet.
832-
// to fix this, we insert guard cache entries for all intersections
833-
// of this signature and definitions. those guard entries will
834-
// supersede this one in conflicted cases, alerting us that there
835-
// should actually be a cache miss.
836-
// TODO: the above analysis assumes that there will never
837-
// be a call attempted that should throw a no-method error
838829
JL_GC_POP();
839830
}
840831
}
@@ -1078,18 +1069,35 @@ static jl_method_instance_t *cache_method(
10781069
jl_svec_t *newparams = NULL;
10791070
JL_GC_PUSH5(&temp, &temp2, &temp3, &newmeth, &newparams);
10801071

1072+
// Consider if we can cache with the preferred compile signature
1073+
// so that we can minimize the number of required cache entries.
10811074
int cache_with_orig = 1;
10821075
jl_tupletype_t *compilationsig = tt;
10831076
jl_methtable_t *kwmt = mt == jl_kwcall_mt ? jl_kwmethod_table_for(definition->sig) : mt;
10841077
intptr_t nspec = (kwmt == NULL || kwmt == jl_type_type_mt || kwmt == jl_nonfunction_mt || kwmt == jl_kwcall_mt ? definition->nargs + 1 : kwmt->max_args + 2 + 2 * (mt == jl_kwcall_mt));
10851078
jl_compilation_sig(tt, sparams, definition, nspec, &newparams);
10861079
if (newparams) {
1087-
compilationsig = jl_apply_tuple_type(newparams);
1088-
temp2 = (jl_value_t*)compilationsig;
1089-
// In most cases `!jl_isa_compileable_sig(tt, definition))`,
1080+
temp2 = (jl_value_t*)jl_apply_tuple_type(newparams);
1081+
// Now there may be a problem: the widened signature is more general
1082+
// than just the given arguments, so it might conflict with another
1083+
// definition that does not have cache instances yet. To fix this, we
1084+
// may insert guard cache entries for all intersections of this
1085+
// signature and definitions. Those guard entries will supersede this
1086+
// one in conflicted cases, alerting us that there should actually be a
1087+
// cache miss. Alternatively, we may use the original signature in the
1088+
// cache, but use this return for compilation.
1089+
//
1090+
// In most cases `!jl_isa_compileable_sig(tt, definition)`,
10901091
// although for some cases, (notably Varargs)
10911092
// we might choose a replacement type that's preferable but not strictly better
1092-
cache_with_orig = !jl_subtype((jl_value_t*)compilationsig, definition->sig);
1093+
int issubty;
1094+
temp = jl_type_intersection_env_s(temp2, (jl_value_t*)definition->sig, &newparams, &issubty);
1095+
assert(temp != (jl_value_t*)jl_bottom_type); (void)temp;
1096+
if (jl_egal((jl_value_t*)newparams, (jl_value_t*)sparams)) {
1097+
cache_with_orig = !issubty;
1098+
compilationsig = (jl_datatype_t*)temp2;
1099+
}
1100+
newparams = NULL;
10931101
}
10941102
// TODO: maybe assert(jl_isa_compileable_sig(compilationsig, definition));
10951103
newmeth = jl_specializations_get_linfo(definition, (jl_value_t*)compilationsig, sparams);
@@ -1110,6 +1118,8 @@ static jl_method_instance_t *cache_method(
11101118
size_t i, l = jl_array_len(temp);
11111119
for (i = 0; i < l; i++) {
11121120
jl_method_match_t *matc = (jl_method_match_t*)jl_array_ptr_ref(temp, i);
1121+
if (matc->method == definition)
1122+
continue;
11131123
jl_svec_t *env = matc->sparams;
11141124
int k, l;
11151125
for (k = 0, l = jl_svec_len(env); k < l; k++) {
@@ -1128,9 +1138,7 @@ static jl_method_instance_t *cache_method(
11281138
cache_with_orig = 1;
11291139
break;
11301140
}
1131-
if (matc->method != definition) {
1132-
guards++;
1133-
}
1141+
guards++;
11341142
}
11351143
}
11361144
if (!cache_with_orig && guards > 0) {
@@ -2095,11 +2103,35 @@ static void record_precompile_statement(jl_method_instance_t *mi)
20952103
JL_UNLOCK(&precomp_statement_out_lock);
20962104
}
20972105

2106+
jl_method_instance_t *jl_normalize_to_compilable_mi(jl_method_instance_t *mi JL_PROPAGATES_ROOT);
2107+
20982108
jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t world)
20992109
{
2110+
// quick check if we already have a compiled result
21002111
jl_code_instance_t *codeinst = jl_method_compiled(mi, world);
21012112
if (codeinst)
21022113
return codeinst;
2114+
2115+
// if mi has a better (wider) signature for compilation use that instead
2116+
// and just copy it here for caching
2117+
jl_method_instance_t *mi2 = jl_normalize_to_compilable_mi(mi);
2118+
if (mi2 != mi) {
2119+
jl_code_instance_t *codeinst2 = jl_compile_method_internal(mi2, world);
2120+
jl_code_instance_t *codeinst = jl_get_method_inferred(
2121+
mi, codeinst2->rettype,
2122+
codeinst2->min_world, codeinst2->max_world);
2123+
if (jl_atomic_load_relaxed(&codeinst->invoke) == NULL) {
2124+
// once set, don't change invoke-ptr, as that leads to race conditions
2125+
// with the (not) simultaneous updates to invoke and specptr
2126+
codeinst->isspecsig = codeinst2->isspecsig;
2127+
codeinst->rettype_const = codeinst2->rettype_const;
2128+
jl_atomic_store_release(&codeinst->specptr.fptr, jl_atomic_load_relaxed(&codeinst2->specptr.fptr));
2129+
jl_atomic_store_release(&codeinst->invoke, jl_atomic_load_relaxed(&codeinst2->invoke));
2130+
}
2131+
// don't call record_precompile_statement here, since we already compiled it as mi2 which is better
2132+
return codeinst;
2133+
}
2134+
21032135
int compile_option = jl_options.compile_enabled;
21042136
jl_method_t *def = mi->def.method;
21052137
// disabling compilation per-module can override global setting
@@ -2134,6 +2166,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
21342166
}
21352167
}
21362168
}
2169+
21372170
// if that didn't work and compilation is off, try running in the interpreter
21382171
if (compile_option == JL_OPTIONS_COMPILE_OFF ||
21392172
compile_option == JL_OPTIONS_COMPILE_MIN) {
@@ -2254,6 +2287,26 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
22542287
return is_compileable ? (jl_value_t*)tt : jl_nothing;
22552288
}
22562289

2290+
jl_method_instance_t *jl_normalize_to_compilable_mi(jl_method_instance_t *mi JL_PROPAGATES_ROOT)
2291+
{
2292+
jl_method_t *def = mi->def.method;
2293+
if (!jl_is_method(def))
2294+
return mi;
2295+
jl_methtable_t *mt = jl_method_get_table(def);
2296+
if ((jl_value_t*)mt == jl_nothing)
2297+
return mi;
2298+
jl_value_t *compilationsig = jl_normalize_to_compilable_sig(mt, (jl_datatype_t*)mi->specTypes, mi->sparam_vals, def);
2299+
if (compilationsig == jl_nothing || jl_egal(compilationsig, mi->specTypes))
2300+
return mi;
2301+
jl_svec_t *env = NULL;
2302+
JL_GC_PUSH2(&compilationsig, &env);
2303+
jl_value_t *ti = jl_type_intersection_env((jl_value_t*)mi->specTypes, (jl_value_t*)def->sig, &env);
2304+
assert(ti != jl_bottom_type); (void)ti;
2305+
mi = jl_specializations_get_linfo(def, (jl_value_t*)compilationsig, env);
2306+
JL_GC_POP();
2307+
return mi;
2308+
}
2309+
22572310
// return a MethodInstance for a compileable method_match
22582311
jl_method_instance_t *jl_method_match_to_mi(jl_method_match_t *match, size_t world, size_t min_valid, size_t max_valid, int mt_cache)
22592312
{

src/jitlayers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ static jl_callptr_t _jl_compile_codeinst(
267267
// hack to export this pointer value to jl_dump_method_disasm
268268
jl_atomic_store_release(&this_code->specptr.fptr, (void*)getAddressForFunction(decls.specFunctionObject));
269269
}
270-
if (this_code== codeinst)
270+
if (this_code == codeinst)
271271
fptr = addr;
272272
}
273273

test/core.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7873,3 +7873,15 @@ let # https://github.com/JuliaLang/julia/issues/46918
78737873
@test isempty(String(take!(stderr))) # make sure no error has happened
78747874
@test String(take!(stdout)) == "nothing IO IO"
78757875
end
7876+
7877+
# issue #47476
7878+
f47476(::Union{Int, NTuple{N,Int}}...) where {N} = N
7879+
# force it to populate the MethodInstance specializations cache
7880+
# with the correct sparams
7881+
code_typed(f47476, (Vararg{Union{Int, NTuple{2,Int}}},));
7882+
code_typed(f47476, (Int, Vararg{Union{Int, NTuple{2,Int}}},));
7883+
code_typed(f47476, (Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
7884+
code_typed(f47476, (Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
7885+
code_typed(f47476, (Int, Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
7886+
@test f47476(1, 2, 3, 4, 5, 6, (7, 8)) === 2
7887+
@test_throws UndefVarError(:N) f47476(1, 2, 3, 4, 5, 6, 7)

0 commit comments

Comments
 (0)