Skip to content

Commit

Permalink
Auto merge of rust-lang#131911 - lcnr:probe-no-more-leak-2, r=compile…
Browse files Browse the repository at this point in the history
…r-errors

refactor fudge_inference, handle effect vars

this makes it easier to use fudging outside of `fudge_inference_if_ok`, which is likely necessary to handle inference variable leaks on rollback.

We now also uses exhaustive matches where possible and improve the code to handle effect vars.

r? `@compiler-errors` `@BoxyUwU`
  • Loading branch information
bors committed Oct 20, 2024
2 parents b596184 + d836d35 commit 54791ef
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 124 deletions.
11 changes: 8 additions & 3 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,13 @@ impl<'tcx> InferCtxt<'tcx> {
ty::Const::new_var(self.tcx, vid)
}

fn next_effect_var(&self) -> ty::Const<'tcx> {
let effect_vid =
self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;

ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(effect_vid))
}

pub fn next_int_var(&self) -> Ty<'tcx> {
let next_int_var_id =
self.inner.borrow_mut().int_unification_table().new_key(ty::IntVarValue::Unknown);
Expand Down Expand Up @@ -1001,15 +1008,13 @@ impl<'tcx> InferCtxt<'tcx> {
}

pub fn var_for_effect(&self, param: &ty::GenericParamDef) -> GenericArg<'tcx> {
let effect_vid =
self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;
let ty = self
.tcx
.type_of(param.def_id)
.no_bound_vars()
.expect("const parameter types cannot be generic");
debug_assert_eq!(self.tcx.types.bool, ty);
ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(effect_vid)).into()
self.next_effect_var().into()
}

/// Given a set of generics defined on a type or impl, returns the generic parameters mapping
Expand Down
267 changes: 146 additions & 121 deletions compiler/rustc_infer/src/infer/snapshot/fudge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use rustc_data_structures::{snapshot_vec as sv, unify as ut};
use rustc_middle::infer::unify_key::{ConstVariableValue, ConstVidKey};
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};
use rustc_type_ir::EffectVid;
use rustc_type_ir::visit::TypeVisitableExt;
use tracing::instrument;
use ut::UnifyKey;

use super::VariableLengths;
use crate::infer::type_variable::TypeVariableOrigin;
use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};

Expand Down Expand Up @@ -40,26 +43,7 @@ fn const_vars_since_snapshot<'tcx>(
)
}

struct VariableLengths {
type_var_len: usize,
const_var_len: usize,
int_var_len: usize,
float_var_len: usize,
region_constraints_len: usize,
}

impl<'tcx> InferCtxt<'tcx> {
fn variable_lengths(&self) -> VariableLengths {
let mut inner = self.inner.borrow_mut();
VariableLengths {
type_var_len: inner.type_variables().num_vars(),
const_var_len: inner.const_unification_table().len(),
int_var_len: inner.int_unification_table().len(),
float_var_len: inner.float_unification_table().len(),
region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
}
}

/// This rather funky routine is used while processing expected
/// types. What happens here is that we want to propagate a
/// coercion through the return type of a fn to its
Expand Down Expand Up @@ -106,78 +90,94 @@ impl<'tcx> InferCtxt<'tcx> {
T: TypeFoldable<TyCtxt<'tcx>>,
{
let variable_lengths = self.variable_lengths();
let (mut fudger, value) = self.probe(|_| {
match f() {
Ok(value) => {
let value = self.resolve_vars_if_possible(value);

// At this point, `value` could in principle refer
// to inference variables that have been created during
// the snapshot. Once we exit `probe()`, those are
// going to be popped, so we will have to
// eliminate any references to them.

let mut inner = self.inner.borrow_mut();
let type_vars =
inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
let int_vars = vars_since_snapshot(
&inner.int_unification_table(),
variable_lengths.int_var_len,
);
let float_vars = vars_since_snapshot(
&inner.float_unification_table(),
variable_lengths.float_var_len,
);
let region_vars = inner
.unwrap_region_constraints()
.vars_since_snapshot(variable_lengths.region_constraints_len);
let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
variable_lengths.const_var_len,
);

let fudger = InferenceFudger {
infcx: self,
type_vars,
int_vars,
float_vars,
region_vars,
const_vars,
};

Ok((fudger, value))
}
Err(e) => Err(e),
}
let (snapshot_vars, value) = self.probe(|_| {
let value = f()?;
// At this point, `value` could in principle refer
// to inference variables that have been created during
// the snapshot. Once we exit `probe()`, those are
// going to be popped, so we will have to
// eliminate any references to them.
let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
})?;

// At this point, we need to replace any of the now-popped
// type/region variables that appear in `value` with a fresh
// variable of the appropriate kind. We can't do this during
// the probe because they would just get popped then too. =)
Ok(self.fudge_inference(snapshot_vars, value))
}

fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
&self,
snapshot_vars: SnapshotVarData,
value: T,
) -> T {
// Micro-optimization: if no variables have been created, then
// `value` can't refer to any of them. =) So we can just return it.
if fudger.type_vars.0.is_empty()
&& fudger.int_vars.is_empty()
&& fudger.float_vars.is_empty()
&& fudger.region_vars.0.is_empty()
&& fudger.const_vars.0.is_empty()
{
Ok(value)
if snapshot_vars.is_empty() {
value
} else {
Ok(value.fold_with(&mut fudger))
value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
}
}
}

struct InferenceFudger<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
struct SnapshotVarData {
region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
int_vars: Range<IntVid>,
float_vars: Range<FloatVid>,
region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
effect_vars: Range<EffectVid>,
}

impl SnapshotVarData {
fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
let mut inner = infcx.inner.borrow_mut();
let region_vars = inner
.unwrap_region_constraints()
.vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
let int_vars =
vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
let float_vars =
vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);

let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
vars_pre_snapshot.const_var_len,
);
let effect_vars = vars_since_snapshot(
&inner.effect_unification_table(),
vars_pre_snapshot.effect_var_len,
);
let effect_vars = effect_vars.start.vid..effect_vars.end.vid;

SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars, effect_vars }
}

fn is_empty(&self) -> bool {
let SnapshotVarData {
region_vars,
type_vars,
int_vars,
float_vars,
const_vars,
effect_vars,
} = self;
region_vars.0.is_empty()
&& type_vars.0.is_empty()
&& int_vars.is_empty()
&& float_vars.is_empty()
&& const_vars.0.is_empty()
&& effect_vars.is_empty()
}
}

struct InferenceFudger<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
snapshot_vars: SnapshotVarData,
}

impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
Expand All @@ -186,68 +186,93 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
match *ty.kind() {
ty::Infer(ty::InferTy::TyVar(vid)) => {
if self.type_vars.0.contains(&vid) {
// This variable was created during the fudging.
// Recreate it with a fresh variable here.
let idx = vid.as_usize() - self.type_vars.0.start.as_usize();
let origin = self.type_vars.1[idx];
self.infcx.next_ty_var_with_origin(origin)
} else {
// This variable was created before the
// "fudging". Since we refresh all type
// variables to their binding anyhow, we know
// that it is unbound, so we can just return
// it.
debug_assert!(
self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
);
ty
if let &ty::Infer(infer_ty) = ty.kind() {
match infer_ty {
ty::TyVar(vid) => {
if self.snapshot_vars.type_vars.0.contains(&vid) {
// This variable was created during the fudging.
// Recreate it with a fresh variable here.
let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
let origin = self.snapshot_vars.type_vars.1[idx];
self.infcx.next_ty_var_with_origin(origin)
} else {
// This variable was created before the
// "fudging". Since we refresh all type
// variables to their binding anyhow, we know
// that it is unbound, so we can just return
// it.
debug_assert!(
self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
);
ty
}
}
}
ty::Infer(ty::InferTy::IntVar(vid)) => {
if self.int_vars.contains(&vid) {
self.infcx.next_int_var()
} else {
ty
ty::IntVar(vid) => {
if self.snapshot_vars.int_vars.contains(&vid) {
self.infcx.next_int_var()
} else {
ty
}
}
}
ty::Infer(ty::InferTy::FloatVar(vid)) => {
if self.float_vars.contains(&vid) {
self.infcx.next_float_var()
} else {
ty
ty::FloatVar(vid) => {
if self.snapshot_vars.float_vars.contains(&vid) {
self.infcx.next_float_var()
} else {
ty
}
}
ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
unreachable!("unexpected fresh infcx var")
}
}
_ => ty.super_fold_with(self),
} else if ty.has_infer() {
ty.super_fold_with(self)
} else {
ty
}
}

fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
if let ty::ReVar(vid) = *r
&& self.region_vars.0.contains(&vid)
{
let idx = vid.index() - self.region_vars.0.start.index();
let origin = self.region_vars.1[idx];
return self.infcx.next_region_var(origin);
if let ty::ReVar(vid) = r.kind() {
if self.snapshot_vars.region_vars.0.contains(&vid) {
let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
let origin = self.snapshot_vars.region_vars.1[idx];
self.infcx.next_region_var(origin)
} else {
r
}
} else {
r
}
r
}

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
if let ty::ConstKind::Infer(ty::InferConst::Var(vid)) = ct.kind() {
if self.const_vars.0.contains(&vid) {
// This variable was created during the fudging.
// Recreate it with a fresh variable here.
let idx = vid.index() - self.const_vars.0.start.index();
let origin = self.const_vars.1[idx];
self.infcx.next_const_var_with_origin(origin)
} else {
ct
if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
match infer_ct {
ty::InferConst::Var(vid) => {
if self.snapshot_vars.const_vars.0.contains(&vid) {
let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
let origin = self.snapshot_vars.const_vars.1[idx];
self.infcx.next_const_var_with_origin(origin)
} else {
ct
}
}
ty::InferConst::EffectVar(vid) => {
if self.snapshot_vars.effect_vars.contains(&vid) {
self.infcx.next_effect_var()
} else {
ct
}
}
ty::InferConst::Fresh(_) => {
unreachable!("unexpected fresh infcx var")
}
}
} else {
} else if ct.has_infer() {
ct.super_fold_with(self)
} else {
ct
}
}
}
21 changes: 21 additions & 0 deletions compiler/rustc_infer/src/infer/snapshot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,28 @@ pub struct CombinedSnapshot<'tcx> {
universe: ty::UniverseIndex,
}

struct VariableLengths {
region_constraints_len: usize,
type_var_len: usize,
int_var_len: usize,
float_var_len: usize,
const_var_len: usize,
effect_var_len: usize,
}

impl<'tcx> InferCtxt<'tcx> {
fn variable_lengths(&self) -> VariableLengths {
let mut inner = self.inner.borrow_mut();
VariableLengths {
region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
type_var_len: inner.type_variables().num_vars(),
int_var_len: inner.int_unification_table().len(),
float_var_len: inner.float_unification_table().len(),
const_var_len: inner.const_unification_table().len(),
effect_var_len: inner.effect_unification_table().len(),
}
}

pub fn in_snapshot(&self) -> bool {
UndoLogs::<UndoLog<'tcx>>::in_snapshot(&self.inner.borrow_mut().undo_log)
}
Expand Down

0 comments on commit 54791ef

Please sign in to comment.