Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ mod sub;
pub mod type_variable;
mod undo_log;

use crate::infer::canonical::OriginalQueryValues;
pub use rustc_middle::infer::unify_key;

#[must_use]
Expand Down Expand Up @@ -695,14 +694,19 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
) -> bool {
// Reject any attempt to unify two unevaluated constants that contain inference
// variables, since inference variables in queries lead to ICEs.
if a.substs.has_infer_types_or_consts() || b.substs.has_infer_types_or_consts() {
debug!("a or b contain infer vars in its substs -> cannot unify");
if a.substs.has_infer_types_or_consts()
|| b.substs.has_infer_types_or_consts()
|| param_env.has_infer_types_or_consts()
{
debug!("a or b or param_env contain infer vars in its substs -> cannot unify");
return false;
}

let canonical = self.canonicalize_query((a, b), &mut OriginalQueryValues::default());
let erased_args = self.tcx.erase_regions((a, b));
let erased_param_env = self.tcx.erase_regions(param_env);
debug!("after erase_regions args: {:?}, param_env: {:?}", erased_args, param_env);

self.tcx.try_unify_abstract_consts(param_env.and(canonical.value))
self.tcx.try_unify_abstract_consts(erased_param_env.and(erased_args))
}

pub fn is_in_snapshot(&self) -> bool {
Expand Down Expand Up @@ -1619,9 +1623,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
// variables
if substs.has_infer_types_or_consts() {
debug!("substs have infer types or consts: {:?}", substs);
if substs.has_infer_types_or_consts() {
return Err(ErrorHandled::TooGeneric);
}
return Err(ErrorHandled::TooGeneric);
}

let param_env_erased = self.tcx.erase_regions(param_env);
Expand Down
33 changes: 12 additions & 21 deletions compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,13 @@ fn satisfied_from_param_env<'tcx>(
match pred.kind().skip_binder() {
ty::PredicateKind::ConstEvaluatable(uv) => {
if let Some(b_ct) = AbstractConst::new(tcx, uv)? {
let const_unify_ctxt = ConstUnifyCtxt::new(tcx, param_env);

// Try to unify with each subtree in the AbstractConst to allow for
// `N + 1` being const evaluatable even if theres only a `ConstEvaluatable`
// predicate for `(N + 1) * 2`
let result = walk_abstract_const(tcx, b_ct, |b_ct| {
match try_unify(tcx, ct, b_ct, param_env) {
match const_unify_ctxt.try_unify(ct, b_ct) {
true => ControlFlow::BREAK,
false => ControlFlow::CONTINUE,
}
Expand Down Expand Up @@ -569,18 +571,6 @@ pub(super) fn thir_abstract_const<'tcx>(
}
}

/// Tries to unify two abstract constants using structural equality.
#[instrument(skip(tcx), level = "debug")]
pub(super) fn try_unify<'tcx>(
tcx: TyCtxt<'tcx>,
a: AbstractConst<'tcx>,
b: AbstractConst<'tcx>,
param_env: ty::ParamEnv<'tcx>,
) -> bool {
let const_unify_ctxt = ConstUnifyCtxt::new(tcx, param_env);
const_unify_ctxt.try_unify_inner(a, b)
}

pub(super) fn try_unify_abstract_consts<'tcx>(
tcx: TyCtxt<'tcx>,
(a, b): (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>),
Expand All @@ -589,7 +579,8 @@ pub(super) fn try_unify_abstract_consts<'tcx>(
(|| {
if let Some(a) = AbstractConst::new(tcx, a)? {
if let Some(b) = AbstractConst::new(tcx, b)? {
return Ok(try_unify(tcx, a, b, param_env));
let const_unify_ctxt = ConstUnifyCtxt::new(tcx, param_env);
return Ok(const_unify_ctxt.try_unify(a, b));
}
}

Expand Down Expand Up @@ -666,7 +657,7 @@ impl<'tcx> ConstUnifyCtxt<'tcx> {

/// Tries to unify two abstract constants using structural equality.
#[instrument(skip(self), level = "debug")]
fn try_unify_inner(&self, a: AbstractConst<'tcx>, b: AbstractConst<'tcx>) -> bool {
fn try_unify(&self, a: AbstractConst<'tcx>, b: AbstractConst<'tcx>) -> bool {
let a = if let Some(a) = self.try_replace_substs_in_root(a) {
a
} else {
Expand Down Expand Up @@ -723,23 +714,23 @@ impl<'tcx> ConstUnifyCtxt<'tcx> {
}
}
(Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => {
self.try_unify_inner(a.subtree(al), b.subtree(bl))
&& self.try_unify_inner(a.subtree(ar), b.subtree(br))
self.try_unify(a.subtree(al), b.subtree(bl))
&& self.try_unify(a.subtree(ar), b.subtree(br))
}
(Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => {
self.try_unify_inner(a.subtree(av), b.subtree(bv))
self.try_unify(a.subtree(av), b.subtree(bv))
}
(Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args))
if a_args.len() == b_args.len() =>
{
self.try_unify_inner(a.subtree(a_f), b.subtree(b_f))
self.try_unify(a.subtree(a_f), b.subtree(b_f))
&& iter::zip(a_args, b_args)
.all(|(&an, &bn)| self.try_unify_inner(a.subtree(an), b.subtree(bn)))
.all(|(&an, &bn)| self.try_unify(a.subtree(an), b.subtree(bn)))
}
(Node::Cast(a_kind, a_operand, a_ty), Node::Cast(b_kind, b_operand, b_ty))
if (a_ty == b_ty) && (a_kind == b_kind) =>
{
self.try_unify_inner(a.subtree(a_operand), b.subtree(b_operand))
self.try_unify(a.subtree(a_operand), b.subtree(b_operand))
}
// use this over `_ => false` to make adding variants to `Node` less error prone
(Node::Cast(..), _)
Expand Down