Skip to content

Commit

Permalink
Rollup merge of rust-lang#127856 - RalfJung:interpret-cast-sanity, r=…
Browse files Browse the repository at this point in the history
…oli-obk

interpret: add sanity check in dyn upcast to double-check what codegen does

For dyn receiver calls, we already have two codepaths: look up the function to call by indexing into the vtable, or alternatively resolve the DefId given the dynamic type of the receiver. With debug assertions enabled, the interpreter does both and compares the results. (Without debug assertions we always use the vtable as it is simpler.)

This PR does the same for dyn trait upcasts. However, for casts *not* using the vtable is the easier thing to do, so now the vtable path is the debug-assertion-only path. In particular, there are cases where the vtable does not contain a pointer for upcasts but instead reuses the old pointer: when the supertrait vtable is a prefix of the larger vtable. We don't want to expose this optimization and detect UB if people do a transmute assuming this optimization, so we cannot in general use the vtable indexing path.

r? `@oli-obk`
  • Loading branch information
matthiaskrgr authored Jul 19, 2024
2 parents 161f5b2 + a7b8081 commit 552aa4c
Show file tree
Hide file tree
Showing 15 changed files with 228 additions and 237 deletions.
41 changes: 36 additions & 5 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
(ty::Dynamic(data_a, _, ty::Dyn), ty::Dynamic(data_b, _, ty::Dyn)) => {
let val = self.read_immediate(src)?;
if data_a.principal() == data_b.principal() {
// A NOP cast that doesn't actually change anything, should be allowed even with mismatching vtables.
// (But currently mismatching vtables violate the validity invariant so UB is triggered anyway.)
return self.write_immediate(*val, dest);
}
// Take apart the old pointer, and find the dynamic type.
let (old_data, old_vptr) = val.to_scalar_pair();
let old_data = old_data.to_pointer(self)?;
let old_vptr = old_vptr.to_pointer(self)?;
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;

// Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces
// our destination trait.
if cfg!(debug_assertions) {
let vptr_entry_idx =
self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
let vtable_entries = self.vtable_entries(data_a.principal(), ty);
if let Some(entry_idx) = vptr_entry_idx {
let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
vtable_entries.get(entry_idx)
else {
span_bug!(
self.cur_span(),
"invalid vtable entry index in {} -> {} upcast",
src_pointee_ty,
dest_pointee_ty
);
};
let erased_trait_ref = upcast_trait_ref
.map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r));
assert!(
data_b
.principal()
.is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b))
);
} else {
// In this case codegen would keep using the old vtable. We don't want to do
// that as it has the wrong trait. The reason codegen can do this is that
// one vtable is a prefix of the other, so we double-check that.
let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
};
}

// Get the destination trait vtable and return that.
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
}
Expand Down
30 changes: 30 additions & 0 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ use std::cell::Cell;
use std::{fmt, mem};

use either::{Either, Left, Right};
use rustc_infer::infer::at::ToTrace;
use rustc_infer::traits::ObligationCause;
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::{debug, info, info_span, instrument, trace};

use rustc_errors::DiagCtxtHandle;
use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData};
use rustc_index::IndexVec;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::mir;
use rustc_middle::mir::interpret::{
CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo,
Expand Down Expand Up @@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
}

/// Check if the two things are equal in the current param_env, using an infctx to get proper
/// equality checks.
pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
where
T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
{
// Fast path: compare directly.
if a == b {
return true;
}
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let a = ocx.normalize(&cause, self.param_env, a);
let b = ocx.normalize(&cause, self.param_env, b);
if ocx.eq(&cause, self.param_env, a, b).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return true;
}
}
return false;
}

/// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
/// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
/// and is primarily intended for the panic machinery.
Expand Down
17 changes: 6 additions & 11 deletions compiler/rustc_const_eval/src/interpret/terminator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::borrow::Cow;

use either::Either;
use rustc_middle::ty::TyCtxt;
use tracing::trace;

use rustc_middle::{
Expand Down Expand Up @@ -867,7 +866,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
};

// Obtain the underlying trait we are working on, and the adjusted receiver argument.
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
let (trait_, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
receiver_place.layout.ty.kind()
{
let recv = self.unpack_dyn_star(&receiver_place, data)?;
Expand Down Expand Up @@ -898,20 +897,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
};

// Now determine the actual method to call. We can do that in two different ways and
// compare them to ensure everything fits.
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
};
// Now determine the actual method to call. Usually we use the easy way of just
// looking up the method at index `idx`.
let vtable_entries = self.vtable_entries(trait_, dyn_ty);
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
};
trace!("Virtual call dispatches to {fn_inst:#?}");
// We can also do the lookup based on `def_id` and `dyn_ty`, and check that that
// produces the same result.
if cfg!(debug_assertions) {
let tcx = *self.tcx;

Expand Down
48 changes: 23 additions & 25 deletions compiler/rustc_const_eval/src/interpret/traits.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::ObligationCause;
use rustc_middle::mir::interpret::{InterpResult, Pointer};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
use rustc_target::abi::{Align, Size};
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::trace;

use super::util::ensure_monomorphic_enough;
Expand Down Expand Up @@ -47,35 +44,36 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok((layout.size, layout.align.abi))
}

pub(super) fn vtable_entries(
&self,
trait_: Option<ty::PolyExistentialTraitRef<'tcx>>,
dyn_ty: Ty<'tcx>,
) -> &'tcx [VtblEntry<'tcx>] {
if let Some(trait_) = trait_ {
let trait_ref = trait_.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
}
}

/// Check that the given vtable trait is valid for a pointer/reference/place with the given
/// expected trait type.
pub(super) fn check_vtable_for_type(
&self,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx> {
// Fast path: if they are equal, it's all fine.
if expected_trait.principal() == vtable_trait {
return Ok(());
}
if let (Some(expected_trait), Some(vtable_trait)) =
(expected_trait.principal(), vtable_trait)
{
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return Ok(());
}
}
let eq = match (expected_trait.principal(), vtable_trait) {
(Some(a), Some(b)) => self.eq_in_param_env(a, b),
(None, None) => true,
_ => false,
};
if !eq {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
Ok(())
}

/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
Expand Down
20 changes: 12 additions & 8 deletions compiler/rustc_trait_selection/src/traits/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ pub(crate) fn first_method_vtable_slot<'tcx>(tcx: TyCtxt<'tcx>, key: ty::TraitRe
}

/// Given a `dyn Subtrait` and `dyn Supertrait` trait object, find the slot of
/// // the trait vptr in the subtrait's vtable.
/// the trait vptr in the subtrait's vtable.
///
/// A return value of `None` means that the original vtable can be reused.
pub(crate) fn supertrait_vtable_slot<'tcx>(
tcx: TyCtxt<'tcx>,
key: (
Expand All @@ -373,20 +375,22 @@ pub(crate) fn supertrait_vtable_slot<'tcx>(
),
) -> Option<usize> {
debug_assert!(!key.has_non_region_infer() && !key.has_non_region_param());

let (source, target) = key;
let ty::Dynamic(source, _, _) = *source.kind() else {

// If the target principal is `None`, we can just return `None`.
let ty::Dynamic(target, _, _) = *target.kind() else {
bug!();
};
let source_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), source.principal().unwrap())
let target_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), target.principal()?)
.with_self_ty(tcx, tcx.types.trait_object_dummy_self);

let ty::Dynamic(target, _, _) = *target.kind() else {
// Given that we have a target principal, it is a bug for there not to be a source principal.
let ty::Dynamic(source, _, _) = *source.kind() else {
bug!();
};
let target_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), target.principal().unwrap())
let source_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), source.principal().unwrap())
.with_self_ty(tcx, tcx.types.trait_object_dummy_self);

let vtable_segment_callback = {
Expand Down
10 changes: 6 additions & 4 deletions src/tools/miri/tests/fail/dyn-upcast-trait-mismatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ impl Baz for i32 {
}

fn main() {
let baz: &dyn Baz = &1;
let baz_fake: *const dyn Bar = unsafe { std::mem::transmute(baz) };
let _err = baz_fake as *const dyn Foo;
//~^ERROR: using vtable for trait `Baz` but trait `Bar` was expected
unsafe {
let baz: &dyn Baz = &1;
let baz_fake: *const dyn Bar = std::mem::transmute(baz);
let _err = baz_fake as *const dyn Foo;
//~^ERROR: using vtable for trait `Baz` but trait `Bar` was expected
}
}
4 changes: 2 additions & 2 deletions src/tools/miri/tests/fail/dyn-upcast-trait-mismatch.stderr
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error: Undefined Behavior: using vtable for trait `Baz` but trait `Bar` was expected
--> $DIR/dyn-upcast-trait-mismatch.rs:LL:CC
|
LL | let _err = baz_fake as *const dyn Foo;
| ^^^^^^^^ using vtable for trait `Baz` but trait `Bar` was expected
LL | let _err = baz_fake as *const dyn Foo;
| ^^^^^^^^ using vtable for trait `Baz` but trait `Bar` was expected
|
= help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior
= help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information
Expand Down
42 changes: 16 additions & 26 deletions tests/ui/consts/const-eval/raw-bytes.32bit.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -436,30 +436,20 @@ LL | const TRAIT_OBJ_CONTENT_INVALID: &dyn Trait = unsafe { mem::transmute::<_,
╾ALLOC_ID╼ ╾ALLOC_ID╼ │ ╾──╼╾──╼
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:196:1
error[E0080]: evaluation of constant value failed
--> $DIR/raw-bytes.rs:196:62
|
LL | const RAW_TRAIT_OBJ_VTABLE_NULL: *const dyn Trait = unsafe { mem::transmute((&92u8, 0usize)) };
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: encountered null pointer, but expected a vtable pointer
|
= note: The rules on what exactly is undefined behavior aren't clear, so this check might be overzealous. Please open an issue on the rustc repository if you believe it should not be considered undefined behavior.
= note: the raw bytes of the constant (size: 8, align: 4) {
╾ALLOC_ID╼ 00 00 00 00 │ ╾──╼....
}
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ out-of-bounds pointer use: null pointer is a dangling pointer (it has no provenance)

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:198:1
error[E0080]: evaluation of constant value failed
--> $DIR/raw-bytes.rs:199:65
|
LL | const RAW_TRAIT_OBJ_VTABLE_INVALID: *const dyn Trait = unsafe { mem::transmute((&92u8, &3u64)) };
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: encountered ALLOC27<imm>, but expected a vtable pointer
|
= note: The rules on what exactly is undefined behavior aren't clear, so this check might be overzealous. Please open an issue on the rustc repository if you believe it should not be considered undefined behavior.
= note: the raw bytes of the constant (size: 8, align: 4) {
╾ALLOC_ID╼ ╾ALLOC_ID╼ │ ╾──╼╾──╼
}
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ using ALLOC32 as vtable pointer but it does not point to a vtable

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:202:1
--> $DIR/raw-bytes.rs:204:1
|
LL | const _: &[!; 1] = unsafe { &*(1_usize as *const [!; 1]) };
| ^^^^^^^^^^^^^^^^ constructing invalid value: encountered a reference pointing to uninhabited type [!; 1]
Expand All @@ -470,7 +460,7 @@ LL | const _: &[!; 1] = unsafe { &*(1_usize as *const [!; 1]) };
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:203:1
--> $DIR/raw-bytes.rs:205:1
|
LL | const _: &[!] = unsafe { &*(1_usize as *const [!; 1]) };
| ^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered a value of the never type `!`
Expand All @@ -481,7 +471,7 @@ LL | const _: &[!] = unsafe { &*(1_usize as *const [!; 1]) };
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:204:1
--> $DIR/raw-bytes.rs:206:1
|
LL | const _: &[!] = unsafe { &*(1_usize as *const [!; 42]) };
| ^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered a value of the never type `!`
Expand All @@ -492,7 +482,7 @@ LL | const _: &[!] = unsafe { &*(1_usize as *const [!; 42]) };
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:208:1
--> $DIR/raw-bytes.rs:210:1
|
LL | pub static S4: &[u8] = unsafe { from_raw_parts((&D1) as *const _ as _, 1) };
| ^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered uninitialized memory, but expected an integer
Expand All @@ -503,7 +493,7 @@ LL | pub static S4: &[u8] = unsafe { from_raw_parts((&D1) as *const _ as _, 1) }
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:211:1
--> $DIR/raw-bytes.rs:213:1
|
LL | pub static S5: &[u8] = unsafe { from_raw_parts((&D3) as *const _ as _, mem::size_of::<&u32>()) };
| ^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered a pointer, but expected an integer
Expand All @@ -516,7 +506,7 @@ LL | pub static S5: &[u8] = unsafe { from_raw_parts((&D3) as *const _ as _, mem:
= help: the absolute address of a pointer is not known at compile-time, so such operations are not supported

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:214:1
--> $DIR/raw-bytes.rs:216:1
|
LL | pub static S6: &[bool] = unsafe { from_raw_parts((&D0) as *const _ as _, 4) };
| ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered 0x11, but expected a boolean
Expand All @@ -527,7 +517,7 @@ LL | pub static S6: &[bool] = unsafe { from_raw_parts((&D0) as *const _ as _, 4)
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:218:1
--> $DIR/raw-bytes.rs:220:1
|
LL | pub static S7: &[u16] = unsafe {
| ^^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[1]: encountered uninitialized memory, but expected an integer
Expand All @@ -538,7 +528,7 @@ LL | pub static S7: &[u16] = unsafe {
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:225:1
--> $DIR/raw-bytes.rs:227:1
|
LL | pub static R4: &[u8] = unsafe {
| ^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered uninitialized memory, but expected an integer
Expand All @@ -549,7 +539,7 @@ LL | pub static R4: &[u8] = unsafe {
}

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:230:1
--> $DIR/raw-bytes.rs:232:1
|
LL | pub static R5: &[u8] = unsafe {
| ^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered a pointer, but expected an integer
Expand All @@ -562,7 +552,7 @@ LL | pub static R5: &[u8] = unsafe {
= help: the absolute address of a pointer is not known at compile-time, so such operations are not supported

error[E0080]: it is undefined behavior to use this value
--> $DIR/raw-bytes.rs:235:1
--> $DIR/raw-bytes.rs:237:1
|
LL | pub static R6: &[bool] = unsafe {
| ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value at .<deref>[0]: encountered 0x11, but expected a boolean
Expand Down
Loading

0 comments on commit 552aa4c

Please sign in to comment.