Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DeepRejectCtxt to quickly reject ParamEnv candidates #128776

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions compiler/rustc_hir_analysis/src/coherence/inherent_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ impl<'tcx> InherentCollect<'tcx> {
}
}

if let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::AsCandidateKey) {
if let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::InstantiateWithInfer)
{
self.impls_map.incoherent_impls.entry(simp).or_default().push(impl_def_id);
} else {
bug!("unexpected self type: {:?}", self_ty);
Expand Down Expand Up @@ -129,7 +130,7 @@ impl<'tcx> InherentCollect<'tcx> {
}
}

if let Some(simp) = simplify_type(self.tcx, ty, TreatParams::AsCandidateKey) {
if let Some(simp) = simplify_type(self.tcx, ty, TreatParams::InstantiateWithInfer) {
self.impls_map.incoherent_impls.entry(simp).or_default().push(impl_def_id);
} else {
bug!("unexpected primitive type: {:?}", ty);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/method/probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> {
}

fn assemble_inherent_candidates_for_incoherent_ty(&mut self, self_ty: Ty<'tcx>) {
let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::AsCandidateKey) else {
let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::InstantiateWithInfer) else {
bug!("unexpected incoherent type: {:?}", self_ty)
};
for &impl_def_id in self.tcx.incoherent_impls(simp).into_iter().flatten() {
Expand Down
9 changes: 4 additions & 5 deletions compiler/rustc_hir_typeck/src/method/suggest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2234,8 +2234,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let target_ty = self
.autoderef(sugg_span, rcvr_ty)
.find(|(rcvr_ty, _)| {
DeepRejectCtxt::new(self.tcx, TreatParams::ForLookup)
.types_may_unify(*rcvr_ty, impl_ty)
DeepRejectCtxt::relate_rigid_infer(self.tcx).types_may_unify(*rcvr_ty, impl_ty)
})
.map_or(impl_ty, |(ty, _)| ty)
.peel_refs();
Expand Down Expand Up @@ -2497,7 +2496,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
.into_iter()
.any(|info| self.associated_value(info.def_id, item_name).is_some());
let found_assoc = |ty: Ty<'tcx>| {
simplify_type(tcx, ty, TreatParams::AsCandidateKey)
simplify_type(tcx, ty, TreatParams::InstantiateWithInfer)
.and_then(|simp| {
tcx.incoherent_impls(simp)
.into_iter()
Expand Down Expand Up @@ -3962,7 +3961,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// cases where a positive bound implies a negative impl.
(candidates, Vec::new())
} else if let Some(simp_rcvr_ty) =
simplify_type(self.tcx, rcvr_ty, TreatParams::ForLookup)
simplify_type(self.tcx, rcvr_ty, TreatParams::AsRigid)
{
let mut potential_candidates = Vec::new();
let mut explicitly_negative = Vec::new();
Expand All @@ -3980,7 +3979,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
.any(|header| {
let imp = header.trait_ref.instantiate_identity();
let imp_simp =
simplify_type(self.tcx, imp.self_ty(), TreatParams::ForLookup);
simplify_type(self.tcx, imp.self_ty(), TreatParams::AsRigid);
imp_simp.is_some_and(|s| s == simp_rcvr_ty)
})
{
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_metadata/src/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
let simplified_self_ty = fast_reject::simplify_type(
self.tcx,
trait_ref.self_ty(),
TreatParams::AsCandidateKey,
TreatParams::InstantiateWithInfer,
);
trait_impls
.entry(trait_ref.def_id)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
let simp = ty::fast_reject::simplify_type(
tcx,
self_ty,
ty::fast_reject::TreatParams::ForLookup,
ty::fast_reject::TreatParams::AsRigid,
)
.unwrap();
consider_impls_for_simplified_type(simp);
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_middle/src/ty/fast_reject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ pub use rustc_type_ir::fast_reject::*;

use super::TyCtxt;

pub type DeepRejectCtxt<'tcx> = rustc_type_ir::fast_reject::DeepRejectCtxt<TyCtxt<'tcx>>;
pub type DeepRejectCtxt<
'tcx,
const INSTANTIATE_LHS_WITH_INFER: bool,
const INSTANTIATE_RHS_WITH_INFER: bool,
> = rustc_type_ir::fast_reject::DeepRejectCtxt<
TyCtxt<'tcx>,
INSTANTIATE_LHS_WITH_INFER,
INSTANTIATE_RHS_WITH_INFER,
>;

pub type SimplifiedType = rustc_type_ir::fast_reject::SimplifiedType<DefId>;
12 changes: 7 additions & 5 deletions compiler/rustc_middle/src/ty/trait_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ impl<'tcx> TyCtxt<'tcx> {
// whose outer level is not a parameter or projection. Especially for things like
// `T: Clone` this is incredibly useful as we would otherwise look at all the impls
// of `Clone` for `Option<T>`, `Vec<T>`, `ConcreteType` and so on.
// Note that we're using `TreatParams::ForLookup` to query `non_blanket_impls` while using
// `TreatParams::AsCandidateKey` while actually adding them.
if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::ForLookup) {
// Note that we're using `TreatParams::AsRigid` to query `non_blanket_impls` while using
// `TreatParams::InstantiateWithInfer` while actually adding them.
if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::AsRigid) {
if let Some(impls) = impls.non_blanket_impls.get(&simp) {
for &impl_def_id in impls {
f(impl_def_id);
Expand All @@ -190,7 +190,9 @@ impl<'tcx> TyCtxt<'tcx> {
self_ty: Ty<'tcx>,
) -> impl Iterator<Item = DefId> + 'tcx {
let impls = self.trait_impls_of(trait_def_id);
if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::AsCandidateKey) {
if let Some(simp) =
fast_reject::simplify_type(self, self_ty, TreatParams::InstantiateWithInfer)
{
if let Some(impls) = impls.non_blanket_impls.get(&simp) {
return impls.iter().copied();
}
Expand Down Expand Up @@ -239,7 +241,7 @@ pub(super) fn trait_impls_of_provider(tcx: TyCtxt<'_>, trait_id: DefId) -> Trait
let impl_self_ty = tcx.type_of(impl_def_id).instantiate_identity();

if let Some(simplified_self_ty) =
fast_reject::simplify_type(tcx, impl_self_ty, TreatParams::AsCandidateKey)
fast_reject::simplify_type(tcx, impl_self_ty, TreatParams::InstantiateWithInfer)
{
impls.non_blanket_impls.entry(simplified_self_ty).or_default().push(impl_def_id);
} else {
Expand Down
10 changes: 8 additions & 2 deletions compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod inherent;
mod opaque_types;
mod weak_types;

use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::fast_reject::DeepRejectCtxt;
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::{self as ty, Interner, NormalizesTo, Upcast as _};
Expand Down Expand Up @@ -106,6 +106,12 @@ where
if let Some(projection_pred) = assumption.as_projection_clause() {
if projection_pred.projection_def_id() == goal.predicate.def_id() {
let cx = ecx.cx();
if !DeepRejectCtxt::relate_rigid_rigid(ecx.cx()).args_may_unify(
goal.predicate.alias.args,
projection_pred.skip_binder().projection_term.args,
) {
return Err(NoSolution);
}
ecx.probe_trait_candidate(source).enter(|ecx| {
let assumption_projection_pred =
ecx.instantiate_binder_with_infer(projection_pred);
Expand Down Expand Up @@ -144,7 +150,7 @@ where

let goal_trait_ref = goal.predicate.alias.trait_ref(cx);
let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup).args_may_unify(
if !DeepRejectCtxt::relate_rigid_infer(ecx.cx()).args_may_unify(
goal.predicate.alias.trait_ref(cx).args,
impl_trait_ref.skip_binder().args,
) {
Expand Down
11 changes: 9 additions & 2 deletions compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use rustc_ast_ir::Movability;
use rustc_type_ir::data_structures::IndexSet;
use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::fast_reject::DeepRejectCtxt;
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::visit::TypeVisitableExt as _;
Expand Down Expand Up @@ -47,7 +47,7 @@ where
let cx = ecx.cx();

let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup)
if !DeepRejectCtxt::relate_rigid_infer(ecx.cx())
.args_may_unify(goal.predicate.trait_ref.args, impl_trait_ref.skip_binder().args)
{
return Err(NoSolution);
Expand Down Expand Up @@ -124,6 +124,13 @@ where
if trait_clause.def_id() == goal.predicate.def_id()
&& trait_clause.polarity() == goal.predicate.polarity
{
if !DeepRejectCtxt::relate_rigid_rigid(ecx.cx()).args_may_unify(
goal.predicate.trait_ref.args,
trait_clause.skip_binder().trait_ref.args,
) {
return Err(NoSolution);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also add this for NormalizesTo goals

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

ecx.probe_trait_candidate(source).enter(|ecx| {
let assumption_trait_pred = ecx.instantiate_binder_with_infer(trait_clause);
ecx.eq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::print::{FmtPrinter, Printer};
use rustc_middle::ty::{self, suggest_constraining_type_param, Ty};
use rustc_span::def_id::DefId;
Expand Down Expand Up @@ -316,7 +316,7 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.found, |did| {
if DeepRejectCtxt::new(tcx, TreatParams::ForLookup)
if DeepRejectCtxt::relate_rigid_infer(tcx)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
Expand All @@ -337,7 +337,7 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.expected, |did| {
if DeepRejectCtxt::new(tcx, TreatParams::ForLookup)
if DeepRejectCtxt::relate_rigid_infer(tcx)
.types_may_unify(values.expected, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
Expand All @@ -357,7 +357,7 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.found, |did| {
if DeepRejectCtxt::new(tcx, TreatParams::ForLookup)
if DeepRejectCtxt::relate_rigid_infer(tcx)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_middle::bug;
use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::solve::{CandidateSource, Certainty, Goal};
use rustc_middle::traits::specialization_graph::OverlapMode;
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor};
use rustc_middle::ty::{self, Ty, TyCtxt};
pub use rustc_next_trait_solver::coherence::*;
Expand Down Expand Up @@ -94,7 +94,7 @@ pub fn overlapping_impls(
// Before doing expensive operations like entering an inference context, do
// a quick check via fast_reject to tell if the impl headers could possibly
// unify.
let drcx = DeepRejectCtxt::new(tcx, TreatParams::AsCandidateKey);
let drcx = DeepRejectCtxt::relate_infer_infer(tcx);
let impl1_ref = tcx.impl_trait_ref(impl1_def_id);
let impl2_ref = tcx.impl_trait_ref(impl2_def_id);
let may_overlap = match (impl1_ref, impl2_ref) {
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rustc_infer::traits::ObligationCauseCode;
use rustc_middle::traits::select::OverflowError;
pub use rustc_middle::traits::Reveal;
use rustc_middle::traits::{BuiltinImplSource, ImplSource, ImplSourceUserDefinedData};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::visit::{MaxUniverse, TypeVisitable, TypeVisitableExt};
use rustc_middle::ty::{self, Term, Ty, TyCtxt, Upcast};
Expand Down Expand Up @@ -885,6 +886,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
potentially_unnormalized_candidates: bool,
) {
let infcx = selcx.infcx;
let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
for predicate in env_predicates {
let bound_predicate = predicate.kind();
if let ty::ClauseKind::Projection(data) = predicate.kind().skip_binder() {
Expand All @@ -893,6 +895,12 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
continue;
}

if !drcx
.args_may_unify(obligation.predicate.args, data.skip_binder().projection_term.args)
{
continue;
}

let is_match = infcx.probe(|_| {
selcx.match_projection_projections(
obligation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use hir::LangItem;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_hir as hir;
use rustc_infer::traits::{Obligation, ObligationCause, PolyTraitObligation, SelectionError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::{self, ToPolyTraitRef, Ty, TypeVisitableExt};
use rustc_middle::{bug, span_bug};

Expand Down Expand Up @@ -247,11 +247,17 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
.filter(|p| p.def_id() == stack.obligation.predicate.def_id())
.filter(|p| p.polarity() == stack.obligation.predicate.polarity());

let drcx = DeepRejectCtxt::relate_rigid_rigid(self.tcx());
let obligation_args = stack.obligation.predicate.skip_binder().trait_ref.args;
// Keep only those bounds which may apply, and propagate overflow if it occurs.
for bound in bounds {
let bound_trait_ref = bound.map_bound(|t| t.trait_ref);
if !drcx.args_may_unify(obligation_args, bound_trait_ref.skip_binder().args) {
continue;
}
// FIXME(oli-obk): it is suspicious that we are dropping the constness and
// polarity here.
let wc = self.where_clause_may_apply(stack, bound.map_bound(|t| t.trait_ref))?;
let wc = self.where_clause_may_apply(stack, bound_trait_ref)?;
if wc.may_apply() {
candidates.vec.push(ParamCandidate(bound));
}
Expand Down Expand Up @@ -580,7 +586,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
return;
}

let drcx = DeepRejectCtxt::new(self.tcx(), TreatParams::ForLookup);
let drcx = DeepRejectCtxt::relate_rigid_infer(self.tcx());
let obligation_args = obligation.predicate.skip_binder().trait_ref.args;
self.tcx().for_each_relevant_impl(
obligation.predicate.def_id(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl<'tcx> Children {
fn insert_blindly(&mut self, tcx: TyCtxt<'tcx>, impl_def_id: DefId) {
let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().skip_binder();
if let Some(st) =
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey)
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer)
{
debug!("insert_blindly: impl_def_id={:?} st={:?}", impl_def_id, st);
self.non_blanket_impls.entry(st).or_default().push(impl_def_id)
Expand All @@ -57,7 +57,7 @@ impl<'tcx> Children {
let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().skip_binder();
let vec: &mut Vec<DefId>;
if let Some(st) =
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey)
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer)
{
debug!("remove_existing: impl_def_id={:?} st={:?}", impl_def_id, st);
vec = self.non_blanket_impls.get_mut(&st).unwrap();
Expand Down Expand Up @@ -278,7 +278,7 @@ impl<'tcx> Graph {
let mut parent = trait_def_id;
let mut last_lint = None;
let simplified =
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey);
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer);

// Descend the specialization tree, where `parent` is the current parent node.
loop {
Expand Down
Loading
Loading