Skip to content

Commit

Permalink
Auto merge of rust-lang#119670 - cjgillot:gvn-arithmetic, r=oli-obk
Browse files Browse the repository at this point in the history
Fold arithmetic identities in GVN

Extracted from rust-lang#111344

This PR implements a few arithmetic folds for unary and binary operations.
This should take care of the missed optimizations introduced by rust-lang#116012.
  • Loading branch information
bors committed Jan 17, 2024
2 parents 6ed31ab + 0cc2102 commit 232aeaf
Show file tree
Hide file tree
Showing 37 changed files with 2,064 additions and 483 deletions.
181 changes: 175 additions & 6 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,20 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
Some(self.insert(Value::Constant { value, disambiguator }))
}

fn insert_bool(&mut self, flag: bool) -> VnIndex {
// Booleans are deterministic.
self.insert(Value::Constant { value: Const::from_bool(self.tcx, flag), disambiguator: 0 })
}

fn insert_scalar(&mut self, scalar: Scalar, ty: Ty<'tcx>) -> VnIndex {
self.insert_constant(Const::from_scalar(self.tcx, scalar, ty))
.expect("scalars are deterministic")
}

fn insert_tuple(&mut self, values: Vec<VnIndex>) -> VnIndex {
self.insert(Value::Aggregate(AggregateTy::Tuple, VariantIdx::from_u32(0), values))
}

#[instrument(level = "trace", skip(self), ret)]
fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> {
use Value::*;
Expand Down Expand Up @@ -767,10 +776,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}

// Operations.
Rvalue::Len(ref mut place) => {
let place = self.simplify_place_value(place, location)?;
Value::Len(place)
}
Rvalue::Len(ref mut place) => return self.simplify_len(place, location),
Rvalue::Cast(kind, ref mut value, to) => {
let from = value.ty(self.local_decls, self.tcx);
let value = self.simplify_operand(value, location)?;
Expand All @@ -785,17 +791,36 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
Value::Cast { kind, value, from, to }
}
Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
let ty = lhs.ty(self.local_decls, self.tcx);
let lhs = self.simplify_operand(lhs, location);
let rhs = self.simplify_operand(rhs, location);
Value::BinaryOp(op, lhs?, rhs?)
// Only short-circuit options after we called `simplify_operand`
// on both operands for side effect.
let lhs = lhs?;
let rhs = rhs?;
if let Some(value) = self.simplify_binary(op, false, ty, lhs, rhs) {
return Some(value);
}
Value::BinaryOp(op, lhs, rhs)
}
Rvalue::CheckedBinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
let ty = lhs.ty(self.local_decls, self.tcx);
let lhs = self.simplify_operand(lhs, location);
let rhs = self.simplify_operand(rhs, location);
Value::CheckedBinaryOp(op, lhs?, rhs?)
// Only short-circuit options after we called `simplify_operand`
// on both operands for side effect.
let lhs = lhs?;
let rhs = rhs?;
if let Some(value) = self.simplify_binary(op, true, ty, lhs, rhs) {
return Some(value);
}
Value::CheckedBinaryOp(op, lhs, rhs)
}
Rvalue::UnaryOp(op, ref mut arg) => {
let arg = self.simplify_operand(arg, location)?;
if let Some(value) = self.simplify_unary(op, arg) {
return Some(value);
}
Value::UnaryOp(op, arg)
}
Rvalue::Discriminant(ref mut place) => {
Expand Down Expand Up @@ -894,6 +919,150 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {

Some(self.insert(Value::Aggregate(ty, variant_index, fields)))
}

#[instrument(level = "trace", skip(self), ret)]
fn simplify_unary(&mut self, op: UnOp, value: VnIndex) -> Option<VnIndex> {
let value = match (op, self.get(value)) {
(UnOp::Not, Value::UnaryOp(UnOp::Not, inner)) => return Some(*inner),
(UnOp::Neg, Value::UnaryOp(UnOp::Neg, inner)) => return Some(*inner),
(UnOp::Not, Value::BinaryOp(BinOp::Eq, lhs, rhs)) => {
Value::BinaryOp(BinOp::Ne, *lhs, *rhs)
}
(UnOp::Not, Value::BinaryOp(BinOp::Ne, lhs, rhs)) => {
Value::BinaryOp(BinOp::Eq, *lhs, *rhs)
}
_ => return None,
};

Some(self.insert(value))
}

#[instrument(level = "trace", skip(self), ret)]
fn simplify_binary(
&mut self,
op: BinOp,
checked: bool,
lhs_ty: Ty<'tcx>,
lhs: VnIndex,
rhs: VnIndex,
) -> Option<VnIndex> {
// Floats are weird enough that none of the logic below applies.
let reasonable_ty =
lhs_ty.is_integral() || lhs_ty.is_bool() || lhs_ty.is_char() || lhs_ty.is_any_ptr();
if !reasonable_ty {
return None;
}

let layout = self.ecx.layout_of(lhs_ty).ok()?;

let as_bits = |value| {
let constant = self.evaluated[value].as_ref()?;
if layout.abi.is_scalar() {
let scalar = self.ecx.read_scalar(constant).ok()?;
scalar.to_bits(constant.layout.size).ok()
} else {
// `constant` is a wide pointer. Do not evaluate to bits.
None
}
};

// Represent the values as `Left(bits)` or `Right(VnIndex)`.
use Either::{Left, Right};
let a = as_bits(lhs).map_or(Right(lhs), Left);
let b = as_bits(rhs).map_or(Right(rhs), Left);
let result = match (op, a, b) {
// Neutral elements.
(BinOp::Add | BinOp::BitOr | BinOp::BitXor, Left(0), Right(p))
| (
BinOp::Add
| BinOp::BitOr
| BinOp::BitXor
| BinOp::Sub
| BinOp::Offset
| BinOp::Shl
| BinOp::Shr,
Right(p),
Left(0),
)
| (BinOp::Mul, Left(1), Right(p))
| (BinOp::Mul | BinOp::Div, Right(p), Left(1)) => p,
// Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size.
(BinOp::BitAnd, Right(p), Left(ones)) | (BinOp::BitAnd, Left(ones), Right(p))
if ones == layout.size.truncate(u128::MAX)
|| (layout.ty.is_bool() && ones == 1) =>
{
p
}
// Absorbing elements.
(BinOp::Mul | BinOp::BitAnd, _, Left(0))
| (BinOp::Rem, _, Left(1))
| (
BinOp::Mul | BinOp::Div | BinOp::Rem | BinOp::BitAnd | BinOp::Shl | BinOp::Shr,
Left(0),
_,
) => self.insert_scalar(Scalar::from_uint(0u128, layout.size), lhs_ty),
// Attempt to simplify `x | ALL_ONES` to `ALL_ONES`.
(BinOp::BitOr, _, Left(ones)) | (BinOp::BitOr, Left(ones), _)
if ones == layout.size.truncate(u128::MAX)
|| (layout.ty.is_bool() && ones == 1) =>
{
self.insert_scalar(Scalar::from_uint(ones, layout.size), lhs_ty)
}
// Sub/Xor with itself.
(BinOp::Sub | BinOp::BitXor, a, b) if a == b => {
self.insert_scalar(Scalar::from_uint(0u128, layout.size), lhs_ty)
}
// Comparison:
// - if both operands can be computed as bits, just compare the bits;
// - if we proved that both operands have the same value, we can insert true/false;
// - otherwise, do nothing, as we do not try to prove inequality.
(BinOp::Eq, Left(a), Left(b)) => self.insert_bool(a == b),
(BinOp::Eq, a, b) if a == b => self.insert_bool(true),
(BinOp::Ne, Left(a), Left(b)) => self.insert_bool(a != b),
(BinOp::Ne, a, b) if a == b => self.insert_bool(false),
_ => return None,
};

if checked {
let false_val = self.insert_bool(false);
Some(self.insert_tuple(vec![result, false_val]))
} else {
Some(result)
}
}

fn simplify_len(&mut self, place: &mut Place<'tcx>, location: Location) -> Option<VnIndex> {
// Trivial case: we are fetching a statically known length.
let place_ty = place.ty(self.local_decls, self.tcx).ty;
if let ty::Array(_, len) = place_ty.kind() {
return self.insert_constant(Const::from_ty_const(*len, self.tcx));
}

let mut inner = self.simplify_place_value(place, location)?;

// The length information is stored in the fat pointer.
// Reborrowing copies length information from one pointer to the other.
while let Value::Address { place: borrowed, .. } = self.get(inner)
&& let [PlaceElem::Deref] = borrowed.projection[..]
&& let Some(borrowed) = self.locals[borrowed.local]
{
inner = borrowed;
}

// We have an unsizing cast, which assigns the length to fat pointer metadata.
if let Value::Cast { kind, from, to, .. } = self.get(inner)
&& let CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize) = kind
&& let Some(from) = from.builtin_deref(true)
&& let ty::Array(_, len) = from.ty.kind()
&& let Some(to) = to.builtin_deref(true)
&& let ty::Slice(..) = to.ty.kind()
{
return self.insert_constant(Const::from_ty_const(*len, self.tcx));
}

// Fallback: a symbolic `Len`.
Some(self.insert(Value::Len(inner)))
}
}

fn op_to_prop_const<'tcx>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
_2 = [const 0_u32, const 1_u32, const 2_u32, const 3_u32];
StorageLive(_3);
_3 = const 2_usize;
_4 = Len(_2);
- _4 = Len(_2);
- _5 = Lt(_3, _4);
- assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, _3) -> [success: bb1, unwind unreachable];
+ _5 = Lt(const 2_usize, _4);
+ assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, const 2_usize) -> [success: bb1, unwind unreachable];
+ _4 = const 4_usize;
+ _5 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 4_usize, const 2_usize) -> [success: bb1, unwind unreachable];
}

bb1: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
_2 = [const 0_u32, const 1_u32, const 2_u32, const 3_u32];
StorageLive(_3);
_3 = const 2_usize;
_4 = Len(_2);
- _4 = Len(_2);
- _5 = Lt(_3, _4);
- assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, _3) -> [success: bb1, unwind continue];
+ _5 = Lt(const 2_usize, _4);
+ assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, const 2_usize) -> [success: bb1, unwind continue];
+ _4 = const 4_usize;
+ _5 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 4_usize, const 2_usize) -> [success: bb1, unwind continue];
}

bb1: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
_2 = [const 0_u32, const 1_u32, const 2_u32, const 3_u32];
StorageLive(_3);
_3 = const 2_usize;
_4 = Len(_2);
- _4 = Len(_2);
- _5 = Lt(_3, _4);
- assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, _3) -> [success: bb1, unwind unreachable];
+ _5 = Lt(const 2_usize, _4);
+ assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, const 2_usize) -> [success: bb1, unwind unreachable];
+ _4 = const 4_usize;
+ _5 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 4_usize, const 2_usize) -> [success: bb1, unwind unreachable];
}

bb1: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
_2 = [const 0_u32, const 1_u32, const 2_u32, const 3_u32];
StorageLive(_3);
_3 = const 2_usize;
_4 = Len(_2);
- _4 = Len(_2);
- _5 = Lt(_3, _4);
- assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, _3) -> [success: bb1, unwind continue];
+ _5 = Lt(const 2_usize, _4);
+ assert(move _5, "index out of bounds: the length is {} but the index is {}", move _4, const 2_usize) -> [success: bb1, unwind continue];
+ _4 = const 4_usize;
+ _5 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", const 4_usize, const 2_usize) -> [success: bb1, unwind continue];
}

bb1: {
Expand Down
7 changes: 3 additions & 4 deletions tests/mir-opt/const_prop/boolean_identities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ pub fn test(x: bool, y: bool) -> bool {
// CHECK-LABEL: fn test(
// CHECK: debug a => [[a:_.*]];
// CHECK: debug b => [[b:_.*]];
// FIXME(cjgillot) simplify algebraic identity
// CHECK-NOT: [[a]] = const true;
// CHECK-NOT: [[b]] = const false;
// CHECK-NOT: _0 = const false;
// CHECK: [[a]] = const true;
// CHECK: [[b]] = const false;
// CHECK: _0 = const false;
let a = (y | true);
let b = (x & false);
a & b
Expand Down
12 changes: 7 additions & 5 deletions tests/mir-opt/const_prop/boolean_identities.test.GVN.diff
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,23 @@
StorageLive(_4);
_4 = _2;
- _3 = BitOr(move _4, const true);
+ _3 = BitOr(_2, const true);
+ _3 = const true;
StorageDead(_4);
- StorageLive(_5);
+ nop;
StorageLive(_6);
_6 = _1;
- _5 = BitAnd(move _6, const false);
+ _5 = BitAnd(_1, const false);
+ _5 = const false;
StorageDead(_6);
StorageLive(_7);
_7 = _3;
- _7 = _3;
+ _7 = const true;
StorageLive(_8);
_8 = _5;
- _8 = _5;
- _0 = BitAnd(move _7, move _8);
+ _0 = BitAnd(_3, _5);
+ _8 = const false;
+ _0 = const false;
StorageDead(_8);
StorageDead(_7);
- StorageDead(_5);
Expand Down
9 changes: 6 additions & 3 deletions tests/mir-opt/const_prop/boxes.main.GVN.panic-abort.diff
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

bb0: {
StorageLive(_1);
StorageLive(_2);
- StorageLive(_2);
+ nop;
StorageLive(_3);
- _4 = SizeOf(i32);
- _5 = AlignOf(i32);
Expand All @@ -39,8 +40,10 @@
StorageDead(_7);
_9 = (((_3.0: std::ptr::Unique<i32>).0: std::ptr::NonNull<i32>).0: *const i32);
_2 = (*_9);
_1 = Add(move _2, const 0_i32);
StorageDead(_2);
- _1 = Add(move _2, const 0_i32);
- StorageDead(_2);
+ _1 = _2;
+ nop;
drop(_3) -> [return: bb2, unwind unreachable];
}

Expand Down
9 changes: 6 additions & 3 deletions tests/mir-opt/const_prop/boxes.main.GVN.panic-unwind.diff
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

bb0: {
StorageLive(_1);
StorageLive(_2);
- StorageLive(_2);
+ nop;
StorageLive(_3);
- _4 = SizeOf(i32);
- _5 = AlignOf(i32);
Expand All @@ -39,8 +40,10 @@
StorageDead(_7);
_9 = (((_3.0: std::ptr::Unique<i32>).0: std::ptr::NonNull<i32>).0: *const i32);
_2 = (*_9);
_1 = Add(move _2, const 0_i32);
StorageDead(_2);
- _1 = Add(move _2, const 0_i32);
- StorageDead(_2);
+ _1 = _2;
+ nop;
drop(_3) -> [return: bb2, unwind: bb3];
}

Expand Down
2 changes: 1 addition & 1 deletion tests/mir-opt/const_prop/boxes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn main() {
// CHECK: debug x => [[x:_.*]];
// CHECK: (*{{_.*}}) = const 42_i32;
// CHECK: [[tmp:_.*]] = (*{{_.*}});
// CHECK: [[x]] = Add(move [[tmp]], const 0_i32);
// CHECK: [[x]] = [[tmp]];
let x = *(#[rustc_box]
Box::new(42))
+ 0;
Expand Down
Loading

0 comments on commit 232aeaf

Please sign in to comment.