Skip to content

Commit

Permalink
Auto merge of #79209 - spastorino:trait-inheritance-self, r=nikomatsakis
Browse files Browse the repository at this point in the history
Allow Trait inheritance with cycles on associated types

Fixes #35237

r? `@nikomatsakis`

cc `@estebank`
  • Loading branch information
bors committed Nov 29, 2020
2 parents b776d1c + ada7c1f commit 349b3b3
Show file tree
Hide file tree
Showing 28 changed files with 498 additions and 151 deletions.
34 changes: 33 additions & 1 deletion compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use smallvec::smallvec;

use crate::traits::{Obligation, ObligationCause, PredicateObligation};
use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_middle::ty::outlives::Component;
use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
use rustc_span::symbol::Ident;

pub fn anonymize_predicate<'tcx>(
tcx: TyCtxt<'tcx>,
Expand Down Expand Up @@ -287,6 +288,37 @@ pub fn transitive_bounds<'tcx>(
elaborate_trait_refs(tcx, bounds).filter_to_traits()
}

/// A specialized variant of `elaborate_trait_refs` that only elaborates trait references that may
/// define the given associated type `assoc_name`. It uses the
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
tcx: TyCtxt<'tcx>,
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
assoc_name: Ident,
) -> FxIndexSet<ty::PolyTraitRef<'tcx>> {
let mut stack: Vec<_> = bounds.collect();
let mut trait_refs = FxIndexSet::default();

while let Some(trait_ref) = stack.pop() {
if trait_refs.insert(trait_ref) {
let super_predicates =
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name)));
for (super_predicate, _) in super_predicates.predicates {
let bound_predicate = super_predicate.bound_atom();
let subst_predicate = super_predicate
.subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
stack.push(binder.value);
}
}
}
}

trait_refs
}

///////////////////////////////////////////////////////////////////////////
// Other
///////////////////////////////////////////////////////////////////////////
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,23 @@ rustc_queries! {
/// full predicates are available (note that supertraits have
/// additional acyclicity requirements).
query super_predicates_of(key: DefId) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
desc { |tcx| "computing the super predicates of `{}`", tcx.def_path_str(key) }
}

/// The `Option<Ident>` is the name of an associated type. If it is `None`, then this query
/// returns the full set of predicates. If `Some<Ident>`, then the query returns only the
/// subset of super-predicates that reference traits that define the given associated type.
/// This is used to avoid cycles in resolving types like `T::Item`.
query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Ident>)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the super traits of `{}`{}",
tcx.def_path_str(key.0),
if let Some(assoc_name) = key.1 { format!(" with associated type name `{}`", assoc_name) } else { "".to_string() },
}
}

/// To avoid cycles within the predicates of a single item we compute
/// per-type-parameter predicates for resolving `T::AssocTy`.
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the bounds for type parameter `{}`", {
let id = tcx.hir().local_def_id_to_hir_id(key.1);
tcx.hir().ty_param_name(id)
Expand Down
38 changes: 37 additions & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use rustc_session::config::{BorrowckMode, CrateType, OutputFilenames};
use rustc_session::lint::{Level, Lint};
use rustc_session::Session;
use rustc_span::source_map::MultiSpan;
use rustc_span::symbol::{kw, sym, Symbol};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{Layout, TargetDataLayout, VariantIdx};
use rustc_target::spec::abi;
Expand Down Expand Up @@ -2085,6 +2085,42 @@ impl<'tcx> TyCtxt<'tcx> {
self.mk_fn_ptr(sig.map_bound(|sig| ty::FnSig { unsafety: hir::Unsafety::Unsafe, ..sig }))
}

/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
self.super_traits_of(trait_def_id).any(|trait_did| {
self.associated_items(trait_did)
.find_by_name_and_kind(self, assoc_name, ty::AssocKind::Type, trait_did)
.is_some()
})
}

/// Computes the def-ids of the transitive super-traits of `trait_def_id`. This (intentionally)
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
/// to identify which traits may define a given associated type to help avoid cycle errors.
/// Returns a `DefId` iterator.
fn super_traits_of(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
let mut set = FxHashSet::default();
let mut stack = vec![trait_def_id];

set.insert(trait_def_id);

iter::from_fn(move || -> Option<DefId> {
let trait_did = stack.pop()?;
let generic_predicates = self.super_predicates_of(trait_did);

for (predicate, _) in generic_predicates.predicates {
if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}

Some(trait_did)
})
}

/// Given a closure signature, returns an equivalent fn signature. Detuples
/// and so forth -- so e.g., if we have a sig with `Fn<(u32, i32)>` then
/// you would get a `fn(u32, i32)`.
Expand Down
24 changes: 23 additions & 1 deletion compiler/rustc_middle/src/ty/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
use crate::ty::{self, Ty, TyCtxt};
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
use rustc_query_system::query::DefaultCacheSelector;
use rustc_span::symbol::Symbol;
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};

/// The `Key` trait controls what types can legally be used as the key
Expand Down Expand Up @@ -149,6 +149,28 @@ impl Key for (LocalDefId, DefId) {
}
}

impl Key for (DefId, Option<Ident>) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
tcx.def_span(self.0)
}
}

impl Key for (DefId, LocalDefId, Ident) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.1.default_span(tcx)
}
}

impl Key for (CrateNum, DefId) {
type CacheSelector = DefaultCacheSelector;

Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pub use self::util::{
get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
};
pub use self::util::{
supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
SupertraitDefIds, Supertraits,
};

pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;
Expand Down
68 changes: 57 additions & 11 deletions compiler/rustc_typeck/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {

fn default_constness_for_trait_bounds(&self) -> Constness;

/// Returns predicates in scope of the form `X: Foo`, where `X` is
/// a type parameter `X` with the given id `def_id`. This is a
/// subset of the full set of predicates.
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
/// is a type parameter `X` with the given id `def_id` and T
/// matches `assoc_name`. This is a subset of the full set of
/// predicates.
///
/// This is used for one specific purpose: resolving "short-hand"
/// associated type references like `T::Item`. In principle, we
Expand All @@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
/// but this can lead to cycle errors. The problem is that we have
/// to do this resolution *in order to create the predicates in
/// the first place*. Hence, we have this "special pass".
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
fn get_type_parameter_bounds(
&self,
span: Span,
def_id: DefId,
assoc_name: Ident,
) -> ty::GenericPredicates<'tcx>;

/// Returns the lifetime to use when a lifetime is omitted (and not elided).
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
Expand Down Expand Up @@ -762,7 +768,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
}

// Returns `true` if a bounds list includes `?Sized`.
pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
let tcx = self.tcx();

// Try to find an unbound in bounds.
Expand Down Expand Up @@ -820,7 +826,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
fn add_bounds(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
ast_bounds: &[&hir::GenericBound<'_>],
bounds: &mut Bounds<'tcx>,
) {
let constness = self.default_constness_for_trait_bounds();
Expand All @@ -835,7 +841,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
.instantiate_lang_item_trait_ref(
lang_item, span, hir_id, args, param_ty, bounds,
*lang_item, *span, *hir_id, args, param_ty, bounds,
),
hir::GenericBound::Outlives(ref l) => {
bounds.region_bounds.push((self.ast_region_to_region(l, None), l.span))
Expand Down Expand Up @@ -866,6 +872,42 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ast_bounds: &[hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
) -> Bounds<'tcx> {
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
self.compute_bounds_inner(param_ty, &ast_bounds, sized_by_default, span)
}

/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
/// named `assoc_name` into ty::Bounds. Ignore the rest.
pub fn compute_bounds_that_match_assoc_type(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
assoc_name: Ident,
) -> Bounds<'tcx> {
let mut result = Vec::new();

for ast_bound in ast_bounds {
if let Some(trait_ref) = ast_bound.trait_ref() {
if let Some(trait_did) = trait_ref.trait_def_id() {
if self.tcx().trait_may_define_assoc_type(trait_did, assoc_name) {
result.push(ast_bound);
}
}
}
}

self.compute_bounds_inner(param_ty, &result, sized_by_default, span)
}

fn compute_bounds_inner(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[&hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
) -> Bounds<'tcx> {
let mut bounds = Bounds::default();

Expand Down Expand Up @@ -1035,7 +1077,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
// Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty`
// parameter to have a skipped binder.
let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs);
self.add_bounds(param_ty, ast_bounds, bounds);
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
self.add_bounds(param_ty, &ast_bounds, bounds);
}
}
Ok(())
Expand Down Expand Up @@ -1352,21 +1395,24 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ty_param_def_id, assoc_name, span,
);

let predicates =
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
let predicates = &self
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
.predicates;

debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);

let param_hir_id = tcx.hir().local_def_id_to_hir_id(ty_param_def_id);
let param_name = tcx.hir().ty_param_name(param_hir_id);
self.one_bound_for_assoc_type(
|| {
traits::transitive_bounds(
traits::transitive_bounds_that_define_assoc_type(
tcx,
predicates.iter().filter_map(|(p, _)| {
p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
}),
assoc_name,
)
.into_iter()
},
|| param_name.to_string(),
assoc_name,
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::subst::GenericArgKind;
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_session::Session;
use rustc_span::symbol::Ident;
use rustc_span::{self, Span};
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};

Expand Down Expand Up @@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
}
}

fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
fn get_type_parameter_bounds(
&self,
_: Span,
def_id: DefId,
_: Ident,
) -> ty::GenericPredicates<'tcx> {
let tcx = self.tcx;
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
let item_id = tcx.hir().ty_param_owner(hir_id);
Expand Down
Loading

0 comments on commit 349b3b3

Please sign in to comment.