Skip to content

Commit

Permalink
Auto merge of #116551 - RalfJung:nondet-nan, r=oli-obk
Browse files Browse the repository at this point in the history
miri: make NaN generation non-deterministic

This implements the [LLVM semantics for NaN generation](https://llvm.org/docs/LangRef.html#behavior-of-floating-point-nan-values). I will soon submit an RFC to make this also officially the Rust semantics, but it has been our de-facto semantics for a long time so there's no reason Miri has to wait for that RFC. This PR just better aligns Miri with codegen.

This PR does that just for the operations that have MIR primitives; a future PR will adjust the intrinsics.
  • Loading branch information
bors committed Oct 10, 2023
2 parents 5c37696 + 08deb0d commit 061c330
Show file tree
Hide file tree
Showing 7 changed files with 531 additions and 29 deletions.
23 changes: 21 additions & 2 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,21 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
F: Float + Into<Scalar<M::Provenance>> + FloatConvert<Single> + FloatConvert<Double>,
{
use rustc_type_ir::sty::TyKind::*;

fn adjust_nan<
'mir,
'tcx: 'mir,
M: Machine<'mir, 'tcx>,
F1: rustc_apfloat::Float + FloatConvert<F2>,
F2: rustc_apfloat::Float,
>(
ecx: &InterpCx<'mir, 'tcx, M>,
f1: F1,
f2: F2,
) -> F2 {
if f2.is_nan() { M::generate_nan(ecx, &[f1]) } else { f2 }
}

match *dest_ty.kind() {
// float -> uint
Uint(t) => {
Expand All @@ -330,9 +345,13 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Scalar::from_int(v, size)
}
// float -> f32
Float(FloatTy::F32) => Scalar::from_f32(f.convert(&mut false).value),
Float(FloatTy::F32) => {
Scalar::from_f32(adjust_nan(self, f, f.convert(&mut false).value))
}
// float -> f64
Float(FloatTy::F64) => Scalar::from_f64(f.convert(&mut false).value),
Float(FloatTy::F64) => {
Scalar::from_f64(adjust_nan(self, f, f.convert(&mut false).value))
}
// That's it.
_ => span_bug!(self.cur_span(), "invalid float to {} cast", dest_ty),
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
b: &ImmTy<'tcx, M::Provenance>,
dest: &PlaceTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx> {
assert_eq!(a.layout.ty, b.layout.ty);
assert!(matches!(a.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));

// Performs an exact division, resulting in undefined behavior where
// `x % y != 0` or `y == 0` or `x == T::MIN && y == -1`.
// First, check x % y != 0 (or if that computation overflows).
Expand All @@ -522,7 +525,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
l: &ImmTy<'tcx, M::Provenance>,
r: &ImmTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, Scalar<M::Provenance>> {
assert_eq!(l.layout.ty, r.layout.ty);
assert!(matches!(l.layout.ty.kind(), ty::Int(..) | ty::Uint(..)));
assert!(matches!(mir_op, BinOp::Add | BinOp::Sub));

let (val, overflowed) = self.overflowing_binary_op(mir_op, l, r)?;
Ok(if overflowed {
let size = l.layout.size;
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_const_eval/src/interpret/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::borrow::{Borrow, Cow};
use std::fmt::Debug;
use std::hash::Hash;

use rustc_apfloat::{Float, FloatConvert};
use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece};
use rustc_middle::mir;
use rustc_middle::ty::layout::TyAndLayout;
Expand Down Expand Up @@ -240,6 +241,16 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
right: &ImmTy<'tcx, Self::Provenance>,
) -> InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)>;

/// Generate the NaN returned by a float operation, given the list of inputs.
/// (This is all inputs, not just NaN inputs!)
fn generate_nan<F1: Float + FloatConvert<F2>, F2: Float>(
_ecx: &InterpCx<'mir, 'tcx, Self>,
_inputs: &[F1],
) -> F2 {
// By default we always return the preferred NaN.
F2::NAN
}

/// Called before writing the specified `local` of the `frame`.
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
/// for ZST reads.
Expand Down
20 changes: 13 additions & 7 deletions compiler/rustc_const_eval/src/interpret/operator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_apfloat::Float;
use rustc_apfloat::{Float, FloatConvert};
use rustc_middle::mir;
use rustc_middle::mir::interpret::{InterpResult, Scalar};
use rustc_middle::ty::layout::TyAndLayout;
Expand Down Expand Up @@ -104,7 +104,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
(ImmTy::from_bool(res, *self.tcx), false)
}

fn binary_float_op<F: Float + Into<Scalar<M::Provenance>>>(
fn binary_float_op<F: Float + FloatConvert<F> + Into<Scalar<M::Provenance>>>(
&self,
bin_op: mir::BinOp,
layout: TyAndLayout<'tcx>,
Expand All @@ -113,18 +113,23 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
) -> (ImmTy<'tcx, M::Provenance>, bool) {
use rustc_middle::mir::BinOp::*;

// Performs appropriate non-deterministic adjustments of NaN results.
let adjust_nan = |f: F| -> F {
if f.is_nan() { M::generate_nan(self, &[l, r]) } else { f }
};

let val = match bin_op {
Eq => ImmTy::from_bool(l == r, *self.tcx),
Ne => ImmTy::from_bool(l != r, *self.tcx),
Lt => ImmTy::from_bool(l < r, *self.tcx),
Le => ImmTy::from_bool(l <= r, *self.tcx),
Gt => ImmTy::from_bool(l > r, *self.tcx),
Ge => ImmTy::from_bool(l >= r, *self.tcx),
Add => ImmTy::from_scalar((l + r).value.into(), layout),
Sub => ImmTy::from_scalar((l - r).value.into(), layout),
Mul => ImmTy::from_scalar((l * r).value.into(), layout),
Div => ImmTy::from_scalar((l / r).value.into(), layout),
Rem => ImmTy::from_scalar((l % r).value.into(), layout),
Add => ImmTy::from_scalar(adjust_nan((l + r).value).into(), layout),
Sub => ImmTy::from_scalar(adjust_nan((l - r).value).into(), layout),
Mul => ImmTy::from_scalar(adjust_nan((l * r).value).into(), layout),
Div => ImmTy::from_scalar(adjust_nan((l / r).value).into(), layout),
Rem => ImmTy::from_scalar(adjust_nan((l % r).value).into(), layout),
_ => span_bug!(self.cur_span(), "invalid float op: `{:?}`", bin_op),
};
(val, false)
Expand Down Expand Up @@ -456,6 +461,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Ok((ImmTy::from_bool(res, *self.tcx), false))
}
ty::Float(fty) => {
// No NaN adjustment here, `-` is a bitwise operation!
let res = match (un_op, fty) {
(Neg, FloatTy::F32) => Scalar::from_f32(-val.to_f32()?),
(Neg, FloatTy::F64) => Scalar::from_f64(-val.to_f64()?),
Expand Down
8 changes: 8 additions & 0 deletions src/tools/miri/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,14 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {
ecx.binary_ptr_op(bin_op, left, right)
}

#[inline(always)]
fn generate_nan<F1: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F2>, F2: rustc_apfloat::Float>(
ecx: &InterpCx<'mir, 'tcx, Self>,
inputs: &[F1],
) -> F2 {
ecx.generate_nan(inputs)
}

fn thread_local_static_base_pointer(
ecx: &mut MiriInterpCx<'mir, 'tcx>,
def_id: DefId,
Expand Down
78 changes: 58 additions & 20 deletions src/tools/miri/src/operator.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use std::iter;

use log::trace;

use rand::{seq::IteratorRandom, Rng};
use rustc_apfloat::{Float, FloatConvert};
use rustc_middle::mir;
use rustc_target::abi::Size;

use crate::*;

pub trait EvalContextExt<'tcx> {
fn binary_ptr_op(
&self,
bin_op: mir::BinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)>;
}

impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn binary_ptr_op(
&self,
bin_op: mir::BinOp,
Expand All @@ -23,12 +19,13 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)> {
use rustc_middle::mir::BinOp::*;

let this = self.eval_context_ref();
trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right);

Ok(match bin_op {
Eq | Ne | Lt | Le | Gt | Ge => {
assert_eq!(left.layout.abi, right.layout.abi); // types an differ, e.g. fn ptrs with different `for`
let size = self.pointer_size();
let size = this.pointer_size();
// Just compare the bits. ScalarPairs are compared lexicographically.
// We thus always compare pairs and simply fill scalars up with 0.
let left = match **left {
Expand All @@ -50,34 +47,75 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> {
Ge => left >= right,
_ => bug!(),
};
(ImmTy::from_bool(res, *self.tcx), false)
(ImmTy::from_bool(res, *this.tcx), false)
}

// Some more operations are possible with atomics.
// The return value always has the provenance of the *left* operand.
Add | Sub | BitOr | BitAnd | BitXor => {
assert!(left.layout.ty.is_unsafe_ptr());
assert!(right.layout.ty.is_unsafe_ptr());
let ptr = left.to_scalar().to_pointer(self)?;
let ptr = left.to_scalar().to_pointer(this)?;
// We do the actual operation with usize-typed scalars.
let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize);
let left = ImmTy::from_uint(ptr.addr().bytes(), this.machine.layouts.usize);
let right = ImmTy::from_uint(
right.to_scalar().to_target_usize(self)?,
self.machine.layouts.usize,
right.to_scalar().to_target_usize(this)?,
this.machine.layouts.usize,
);
let (result, overflowing) = self.overflowing_binary_op(bin_op, &left, &right)?;
let (result, overflowing) = this.overflowing_binary_op(bin_op, &left, &right)?;
// Construct a new pointer with the provenance of `ptr` (the LHS).
let result_ptr = Pointer::new(
ptr.provenance,
Size::from_bytes(result.to_scalar().to_target_usize(self)?),
Size::from_bytes(result.to_scalar().to_target_usize(this)?),
);
(
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, self), left.layout),
ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, this), left.layout),
overflowing,
)
}

_ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
_ => span_bug!(this.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
})
}

fn generate_nan<F1: Float + FloatConvert<F2>, F2: Float>(&self, inputs: &[F1]) -> F2 {
/// Make the given NaN a signaling NaN.
/// Returns `None` if this would not result in a NaN.
fn make_signaling<F: Float>(f: F) -> Option<F> {
// The quiet/signaling bit is the leftmost bit in the mantissa.
// That's position `PRECISION-1`, since `PRECISION` includes the fixed leading 1 bit,
// and then we subtract 1 more since this is 0-indexed.
let quiet_bit_mask = 1 << (F::PRECISION - 2);
// Unset the bit. Double-check that this wasn't the last bit set in the payload.
// (which would turn the NaN into an infinity).
let f = F::from_bits(f.to_bits() & !quiet_bit_mask);
if f.is_nan() { Some(f) } else { None }
}

let this = self.eval_context_ref();
let mut rand = this.machine.rng.borrow_mut();
// Assemble an iterator of possible NaNs: preferred, quieting propagation, unchanged propagation.
// On some targets there are more possibilities; for now we just generate those options that
// are possible everywhere.
let preferred_nan = F2::qnan(Some(0));
let nans = iter::once(preferred_nan)
.chain(inputs.iter().filter(|f| f.is_nan()).map(|&f| {
// Regular apfloat cast is quieting.
f.convert(&mut false).value
}))
.chain(inputs.iter().filter(|f| f.is_signaling()).filter_map(|&f| {
let f: F2 = f.convert(&mut false).value;
// We have to de-quiet this again for unchanged propagation.
make_signaling(f)
}));
// Pick one of the NaNs.
let nan = nans.choose(&mut *rand).unwrap();
// Non-deterministically flip the sign.
if rand.gen() {
// This will properly flip even for NaN.
-nan
} else {
nan
}
}
}
Loading

0 comments on commit 061c330

Please sign in to comment.