Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EXPERIMENT] Don't monomorphize things that are unused due to if <T as Trait>::CONST #91222

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
self.codegen_terminator(bx, bb, data.terminator());
}

/// Emits the backend block for `bb` containing nothing but an `unreachable`
/// instruction.
///
/// None of the block's MIR statements or its terminator are translated; the
/// block is only created so that it exists in the backend function.
pub fn codegen_block_as_unreachable(&mut self, bb: mir::BasicBlock) {
    let mut builder = self.build_block(bb);
    debug!("codegen_block_as_unreachable({:?})", bb);
    builder.unreachable();
}

fn codegen_terminator(
&mut self,
mut bx: Bx,
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,19 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
// Apply debuginfo to the newly allocated locals.
fx.debug_introduce_locals(&mut bx);

let reachable_blocks = mir.reachable_blocks_in_mono(cx.tcx(), instance);

// Codegen the body of each block using reverse postorder
// FIXME(eddyb) reuse RPO iterator between `analysis` and this.
for (bb, _) in traversal::reverse_postorder(&mir) {
fx.codegen_block(bb);
if reachable_blocks.contains(bb) {
fx.codegen_block(bb);
} else {
// This may have references to things we didn't monomorphize, so we
// don't actually codegen the body. We still create the block so
// terminators in other blocks can reference it without worry.
fx.codegen_block_as_unreachable(bb);
}
}

// For backends that support CFI using type membership (i.e., testing whether a given pointer
Expand Down
72 changes: 70 additions & 2 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::ty::fold::{TypeFoldable, TypeFolder, TypeVisitor};
use crate::ty::print::{FmtPrinter, Printer};
use crate::ty::subst::{Subst, SubstsRef};
use crate::ty::{self, List, Ty, TyCtxt};
use crate::ty::{AdtDef, InstanceDef, Region, ScalarInt, UserTypeAnnotationIndex};
use crate::ty::{AdtDef, Instance, InstanceDef, Region, ScalarInt, UserTypeAnnotationIndex};
use rustc_hir::def::{CtorKind, Namespace};
use rustc_hir::def_id::{DefId, CRATE_DEF_INDEX};
use rustc_hir::{self, GeneratorKind};
Expand All @@ -23,7 +23,7 @@ pub use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::graph::dominators::{dominators, Dominators};
use rustc_data_structures::graph::{self, GraphSuccessors};
use rustc_index::bit_set::BitMatrix;
use rustc_index::bit_set::{BitMatrix, BitSet};
use rustc_index::vec::{Idx, IndexVec};
use rustc_serialize::{Decodable, Encodable};
use rustc_span::symbol::Symbol;
Expand Down Expand Up @@ -517,6 +517,73 @@ impl<'tcx> Body<'tcx> {
pub fn generator_kind(&self) -> Option<GeneratorKind> {
self.generator.as_ref().map(|generator| generator.generator_kind)
}

/// Computes the set of basic blocks that may be reached when this body is
/// instantiated for `instance`.
///
/// The result is an over-approximation: a block being in the set does not
/// prove that it can execute, so returning a completely filled set is always
/// a valid answer.
///
/// The [`BitSet::domain_size`] of the returned set always equals the number
/// of basic blocks in the body, so `contains` can be called with any block
/// of this body without panicking.
///
/// The motivating case is `if <T as Trait>::CONST`: the untaken arm cannot
/// be removed from generic MIR, but it *can* be ignored once the concrete
/// `T` is known.
///
/// Used by the monomorphization collector and by codegen.
pub fn reachable_blocks_in_mono(
    &self,
    tcx: TyCtxt<'tcx>,
    instance: Instance<'tcx>,
) -> BitSet<BasicBlock> {
    let block_count = self.basic_blocks().len();

    // A non-generic body has already been through mir-opt const prop, so
    // further filtering is probably not worthwhile: treat every block as
    // reachable.
    if instance.substs.is_noop() {
        return BitSet::new_filled(block_count);
    }

    let mut reachable = BitSet::new_empty(block_count);
    self.reachable_blocks_in_mono_from(tcx, instance, &mut reachable, START_BLOCK);
    reachable
}

/// Inserts `bb`, and every block reachable from it for this `instance`, into `set`.
///
/// Blocks already present in `set` are not explored again. When a block ends in
/// a `SwitchInt` on a constant that evaluates to concrete bits after
/// substituting the instance's generics, only the selected target is followed.
/// In every other case all successors are followed.
///
/// The result is allowed to over-approximate, so a constant that cannot be
/// evaluated is handled conservatively (all successors are kept) rather than
/// aborting compilation.
fn reachable_blocks_in_mono_from(
    &self,
    tcx: TyCtxt<'tcx>,
    instance: Instance<'tcx>,
    set: &mut BitSet<BasicBlock>,
    bb: BasicBlock,
) {
    let env = ty::ParamEnv::reveal_all();

    // Use an explicit worklist instead of recursion so that bodies with very
    // long chains of blocks cannot overflow the stack.
    let mut worklist = vec![bb];

    while let Some(bb) = worklist.pop() {
        if !set.insert(bb) {
            // Already visited.
            continue;
        }

        let terminator = self.basic_blocks()[bb].terminator();

        if let TerminatorKind::SwitchInt {
            discr: Operand::Constant(constant),
            switch_ty,
            targets,
        } = &terminator.kind
        {
            let mono_literal =
                instance.subst_mir_and_normalize_erasing_regions(tcx, env, constant.literal);
            if let Some(bits) = mono_literal.try_eval_bits(tcx, env, switch_ty) {
                // The branch is statically known: only its target is reachable.
                worklist.push(targets.target_for_value(bits));
                continue;
            }
            // The constant could not be evaluated. Over-approximating is
            // always legal, so fall through and keep every successor.
        }

        for &target in terminator.successors() {
            worklist.push(target);
        }
    }
}
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, TyEncodable, TyDecodable, HashStable)]
Expand Down Expand Up @@ -1504,6 +1571,7 @@ impl Statement<'_> {
}

/// Changes a statement to a nop and returns the original statement.
#[must_use = "If you don't need the statement, use `make_nop` instead"]
pub fn replace_nop(&mut self) -> Self {
Statement {
source_info: self.source_info,
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ impl SwitchTargets {
pub fn all_targets_mut(&mut self) -> &mut [BasicBlock] {
&mut self.targets
}

/// Returns the `BasicBlock` this `SwitchInt` branches to for `value`.
///
/// The first explicit entry whose value equals `value` wins; if there is no
/// such entry the `otherwise` target is returned, so this never fails.
pub fn target_for_value(&self, value: u128) -> BasicBlock {
    for (candidate, target) in self.iter() {
        if candidate == value {
            return target;
        }
    }
    self.otherwise()
}
}

pub struct SwitchTargetsIter<'a> {
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ mod shim;
mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
mod simplify_if_const;
mod simplify_try;
mod uninhabited_enum_branching;
mod unreachable_prop;
Expand Down Expand Up @@ -456,6 +457,7 @@ fn run_post_borrowck_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tc

let post_borrowck_cleanup: &[&dyn MirPass<'tcx>] = &[
// Remove all things only needed by analysis
&simplify_if_const::SimplifyIfConst,
&simplify_branches::SimplifyBranches::new("initial"),
&remove_noop_landing_pads::RemoveNoopLandingPads,
&cleanup_post_borrowck::CleanupNonCodegenStatements,
Expand Down
11 changes: 2 additions & 9 deletions compiler/rustc_mir_transform/src/simplify_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,8 @@ impl<'tcx> MirPass<'tcx> for SimplifyBranches {
} => {
let constant = c.literal.try_eval_bits(tcx, param_env, switch_ty);
if let Some(constant) = constant {
let otherwise = targets.otherwise();
let mut ret = TerminatorKind::Goto { target: otherwise };
for (v, t) in targets.iter() {
if v == constant {
ret = TerminatorKind::Goto { target: t };
break;
}
}
ret
let target = targets.target_for_value(constant);
TerminatorKind::Goto { target }
} else {
continue;
}
Expand Down
76 changes: 76 additions & 0 deletions compiler/rustc_mir_transform/src/simplify_if_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//! A pass that simplifies branches when their condition is known.

use crate::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

/// Rewrites the MIR shape produced by lowering `if CONST`.
///
/// The lowering yields
/// ```text
/// _1 = Const(...);
/// switchInt (move _1)
/// ```
/// and this pass turns it into
/// ```text
/// switchInt (Const(...))
/// ```
/// which is easier for later MIR consumers to special-case.
///
/// Unlike ConstProp, this also handles generic constants, not only concrete ones.
pub struct SimplifyIfConst;

impl<'tcx> MirPass<'tcx> for SimplifyIfConst {
    /// Applies the rewrite independently to every basic block of `body`.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        body.basic_blocks_mut()
            .iter_mut()
            .for_each(|block_data| simplify_assign_move_switch(tcx, block_data));
    }
}

/// Folds a trailing `_n = const C; switchInt(move _n)` pair at the end of
/// `block` into `switchInt(const C)`.
///
/// The block is left untouched unless all of the following hold:
/// * its terminator is a `SwitchInt` whose discriminant is a `move` of a bare local,
/// * its last statement assigns to that same local,
/// * the assigned value is a plain use of a constant operand, and
/// * `tcx.consider_optimizing` permits the rewrite.
///
/// When it applies, the assignment statement is removed and the constant
/// operand becomes the switch discriminant.
fn simplify_assign_move_switch(tcx: TyCtxt<'_>, block: &mut BasicBlockData<'_>) {
    let Some(Terminator { kind: TerminatorKind::SwitchInt { discr: switch_desc, .. }, .. }) =
        &mut block.terminator
    else { return };

    let &mut Operand::Move(switch_place) = &mut *switch_desc
    else { return };

    let Some(switch_local) = switch_place.as_local()
    else { return };

    // Only a shared borrow is needed for inspection; the statement is removed
    // with `pop` further down once the rewrite has been approved.
    let Some(last_statement) = block.statements.last()
    else { return };

    let StatementKind::Assign(boxed_place_rvalue) = &last_statement.kind
    else { return };

    // The assignment must target exactly the local that the switch moves out of.
    if boxed_place_rvalue.0.as_local() != Some(switch_local) {
        return;
    }

    if !matches!(boxed_place_rvalue.1, Rvalue::Use(Operand::Constant(_))) {
        return;
    }

    // Report under this pass's own name so that optimization-fuel output
    // points at the right pass.
    let should_optimize = tcx.consider_optimizing(|| {
        format!(
            "SimplifyIfConst - Assignment: {:?} SourceInfo: {:?}",
            boxed_place_rvalue, last_statement.source_info
        )
    });

    if should_optimize {
        // Everything below was verified on the borrowed statement above, so the
        // `bug!` arms can only trigger if that invariant is broken.
        let Some(last_statement) = block.statements.pop()
        else { bug!("Somehow the statement disappeared?"); };

        let StatementKind::Assign(boxed_place_rvalue) = last_statement.kind
        else { bug!("Somehow it's not an assignment any more?"); };

        let Rvalue::Use(assigned_constant @ Operand::Constant(_)) = boxed_place_rvalue.1
        else { bug!("Somehow it's not a use of a constant any more?"); };

        *switch_desc = assigned_constant;
    }
}
16 changes: 14 additions & 2 deletions compiler/rustc_monomorphize/src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ use rustc_hir as hir;
use rustc_hir::def_id::{DefId, DefIdMap, LocalDefId, LOCAL_CRATE};
use rustc_hir::itemlikevisit::ItemLikeVisitor;
use rustc_hir::lang_items::LangItem;
use rustc_index::bit_set::GrowableBitSet;
use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_middle::mir::interpret::{AllocId, ConstValue};
use rustc_middle::mir::interpret::{ErrorHandled, GlobalAlloc, Scalar};
use rustc_middle::mir::mono::{InstantiationMode, MonoItem};
Expand Down Expand Up @@ -608,6 +608,7 @@ struct MirNeighborCollector<'a, 'tcx> {
body: &'a mir::Body<'tcx>,
output: &'a mut Vec<Spanned<MonoItem<'tcx>>>,
instance: Instance<'tcx>,
reachable_blocks: BitSet<mir::BasicBlock>,
}

impl<'a, 'tcx> MirNeighborCollector<'a, 'tcx> {
Expand All @@ -625,6 +626,14 @@ impl<'a, 'tcx> MirNeighborCollector<'a, 'tcx> {
}

impl<'a, 'tcx> MirVisitor<'tcx> for MirNeighborCollector<'a, 'tcx> {
/// Visits the contents of `block` only when it is in `reachable_blocks`;
/// other blocks are logged and skipped, so nothing inside them is visited.
fn visit_basic_block_data(&mut self, block: mir::BasicBlock, data: &mir::BasicBlockData<'tcx>) {
    if !self.reachable_blocks.contains(block) {
        debug!("skipping mono-unreachable basic block {:?}", block);
        return;
    }

    self.super_basic_block_data(block, data);
}

fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: Location) {
debug!("visiting rvalue {:?}", *rvalue);

Expand Down Expand Up @@ -1395,7 +1404,10 @@ fn collect_neighbours<'tcx>(
debug!("collect_neighbours: {:?}", instance.def_id());
let body = tcx.instance_mir(instance.def);

MirNeighborCollector { tcx, body: &body, output, instance }.visit_body(&body);
let reachable_blocks = body.reachable_blocks_in_mono(tcx, instance);
let mut collector =
MirNeighborCollector { tcx, body: &body, output, instance, reachable_blocks };
collector.visit_body(&body);
}

fn collect_const_value<'tcx>(
Expand Down
22 changes: 16 additions & 6 deletions library/alloc/src/vec/source_iter_marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ where
// a) no ZSTs as there would be no allocation to reuse and pointer arithmetic would panic
// b) size match as required by Alloc contract
// c) alignments match as required by Alloc contract
if mem::size_of::<T>() == 0
|| mem::size_of::<T>()
!= mem::size_of::<<<I as SourceIter>::Source as AsIntoIter>::Item>()
|| mem::align_of::<T>()
!= mem::align_of::<<<I as SourceIter>::Source as AsIntoIter>::Item>()
{
if <I as LayoutMatcher>::MATCHES {
// fallback to more generic implementations
return SpecFromIterNested::from_iter(iterator);
}
Expand Down Expand Up @@ -154,3 +149,18 @@ where
len
}
}

/// Compile-time comparison between the layout of an iterator's output item
/// and the item type of its underlying source.
///
/// NOTE: despite its name, `MATCHES` is `true` when the layouts are *not*
/// compatible (the output type is zero-sized, or size or alignment differ).
trait LayoutMatcher {
    /// Item type of the underlying source.
    type IN;
    /// `true` if the output item is a ZST or its size/alignment differs from `IN`.
    const MATCHES: bool;
}

impl<I, OUT> LayoutMatcher for I
where
    I: Iterator<Item = OUT> + SourceIter<Source: AsIntoIter>,
{
    type IN = <<I as SourceIter>::Source as AsIntoIter>::Item;
    const MATCHES: bool = {
        let out_is_zst = mem::size_of::<OUT>() == 0;
        let size_differs = mem::size_of::<OUT>() != mem::size_of::<Self::IN>();
        let align_differs = mem::align_of::<OUT>() != mem::align_of::<Self::IN>();
        out_is_zst || size_differs || align_differs
    };
}
9 changes: 8 additions & 1 deletion library/core/src/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,13 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
#[inline]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub(crate) const unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
/// Exposes, as an associated constant, whether a type occupies at least
/// 32 bytes.
trait TypeSizeCheck {
    /// `true` when `size_of::<Self>()` is 32 bytes or more.
    const IS_CHUNK_SIZE_OR_LARGER: bool;
}
impl<U> TypeSizeCheck for U {
    const IS_CHUNK_SIZE_OR_LARGER: bool = 32 <= mem::size_of::<Self>();
}

// NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
// reinterpretation of values as (chunkable) byte arrays, and the loop in the
// block optimization in `swap_nonoverlapping_bytes` is hard to rewrite back
Expand All @@ -442,7 +449,7 @@ pub(crate) const unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
{
// Only apply the block optimization in `swap_nonoverlapping_bytes` for types
// at least as large as the block size, to avoid pessimizing codegen.
if mem::size_of::<T>() >= 32 {
if T::IS_CHUNK_SIZE_OR_LARGER {
// SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
unsafe { swap_nonoverlapping(x, y, 1) };
return;
Expand Down
Loading