From a1659576fa6ef8b0236dfdfa16eb98f46f4e9d44 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 5 Jul 2025 13:07:09 +0000 Subject: [PATCH 1/3] WIP typetree impl --- .../rustc_ast/src/expand/autodiff_attrs.rs | 12 +- compiler/rustc_builtin_macros/src/autodiff.rs | 15 +- compiler/rustc_builtin_macros/src/lib.rs | 1 + compiler/rustc_builtin_macros/src/typetree.rs | 330 ++++++++++++++++++ .../src/builder/autodiff.rs | 33 +- compiler/rustc_codegen_llvm/src/lib.rs | 33 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 + compiler/rustc_codegen_llvm/src/typetree.rs | 33 ++ 8 files changed, 451 insertions(+), 9 deletions(-) create mode 100644 compiler/rustc_builtin_macros/src/typetree.rs create mode 100644 compiler/rustc_codegen_llvm/src/typetree.rs diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 2f918faaf752b..b615398b4ed09 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -9,6 +9,7 @@ use std::str::FromStr; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::ptr::P; use crate::{Ty, TyKind}; +use crate::expand::typetree::TypeTree; /// Forward and Reverse Mode are well known names for automatic differentiation implementations. /// Enzyme does support both, but with different semantics, see DiffActivity. The First variants @@ -85,6 +86,9 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, + // --- TypeTree support --- + pub inputs: Vec, + pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] @@ -112,6 +116,10 @@ impl AutoDiffAttrs { pub fn has_primal_ret(&self) -> bool { matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual) } + /// New constructor for type tree support + pub fn into_item(self, source: String, target: String, inputs: Vec, output: TypeTree) -> AutoDiffItem { + AutoDiffItem { source, target, attrs: self, inputs, output } + } } impl DiffMode { @@ -284,6 +292,8 @@ impl AutoDiffAttrs { impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs) + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) } } diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index c784477833279..dd5f3d5aa3237 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -11,7 +11,9 @@ mod llvm_enzyme { AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity, }; + use rustc_ast::expand::typetree::{TypeTree, Type, Kind}; use rustc_ast::ptr::P; + use crate::typetree::construct_typetree_from_fnsig; use rustc_ast::token::{Lit, LitKind, Token, TokenKind}; use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; @@ -324,6 +326,17 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); + // Construct real type trees from function signature + let (inputs, output) = construct_typetree_from_fnsig(&sig); + + // Use the new into_item method to construct the AutoDiffItem + let autodiff_item = x.clone().into_item( + primal.to_string(), + first_ident(&meta_item_vec[0]).to_string(), + inputs, + output, + ); + let n_active: u32 = x .input_activity .iter() @@ -1045,5 +1058,3 @@ mod llvm_enzyme { (d_sig, new_inputs, idents, false) } } - -pub(crate) use llvm_enzyme::{expand_forward, expand_reverse}; diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index 0594f7e86c333..f5dc409fde4d1 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -51,6 +51,7 @@ mod pattern_type; mod source_util; mod test; mod trace_macros; +mod typetree; pub mod asm; pub mod cmdline_attrs; diff --git a/compiler/rustc_builtin_macros/src/typetree.rs b/compiler/rustc_builtin_macros/src/typetree.rs new file mode 100644 index 0000000000000..f33efe3b22b06 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/typetree.rs @@ -0,0 +1,330 @@ +use rustc_ast as ast; +use rustc_ast::FnRetTy; +use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree}; +use rustc_middle::ty::{Ty, TyCtxt, ParamEnv, ParamEnvAnd, Adt}; +use rustc_middle::ty::layout::{FieldsShape, LayoutOf}; +use rustc_middle::hir; +use rustc_span::Span; +use rustc_ast::expand::autodiff_attrs::DiffActivity; + +#[cfg(llvm_enzyme)] +pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + let mut visited = vec![]; + let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; + return TypeTree(vec![tt]); +} + +// This function combines three tasks. To avoid traversing each type 3x, we combine them. +// 1. Create a TypeTree from a Ty. This is the main task. +// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM +// lowering. E.g. fat ptr are going to introduce an extra int. +// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an +// autodiff macro on top). Here we want to make sure that shadows are mutable internally. +// We know the outermost ref/ptr indirection is mutability - we generate it like that. +// We now have to make sure that inner ptr/ref are mutable too, or issue a warning. +// Not an error, becaues it only causes issues if they are actually read, which we don't check +// yet. We should add such analysis to relibably either issue an error or accept without warning. +// If there only were some reasearch to do that... +#[cfg(llvm_enzyme)] +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec, span: Option) -> FncTree { + if !fn_ty.is_fn() { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // If rustc compiles the unmodified primal, we know that this copy of the function + // also has correct lifetimes. We know that Enzyme won't free the shadow too early + // (or actually at all), so let's strip lifetimes when computing the layout. + // Recommended by compiler-errors: + // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751 + let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); + + let mut new_activities = vec![]; + let mut new_positions = vec![]; + let mut visited = vec![]; + let mut args = vec![]; + for (i, ty) in x.inputs().iter().enumerate() { + // We care about safety checks, if an argument get's duplicated and we write into the + // shadow. That's equivalent to Duplicated or DuplicatedOnly. + let safety = if !da.is_empty() { + assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len()); + // If we have Activities, we also have spans + assert!(span.is_some()); + match da[i] { + DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true, + _ => false, + } + } else { + false + }; + + visited.clear(); + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + let inner_ty = ty.builtin_deref(true).unwrap().ty; + if inner_ty.is_slice() { + // We know that the lenght will be passed as extra arg. + let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + args.push(TypeTree(vec![tt])); + let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() }; + args.push(TypeTree(vec![i64_tt])); + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice get's duplicated, we want to know to later check the + // size. So we mark the new size argument as FakeActivitySize. + let activity = match da[i] { + DiffActivity::DualOnly | DiffActivity::Dual | + DiffActivity::DuplicatedOnly | DiffActivity::Duplicated + => DiffActivity::FakeActivitySize, + DiffActivity::Const => DiffActivity::Const, + _ => panic!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + trace!("ABI MATCHING!"); + continue; + } + } + let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span); + args.push(arg_tt); + } + + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); + } + + visited.clear(); + let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span); + + FncTree { args, ret } +} + + +// Error type for warnings +#[derive(Debug)] +pub struct AutodiffUnsafeInnerConstRef { + pub span: Span, + pub ty: String, +} + +#[cfg(llvm_enzyme)] +fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec>, span: Option) -> TypeTree { + if depth > 20 { + trace!("depth > 20 for ty: {}", &ty); + } + if visited.contains(&ty) { + // recursive type + trace!("recursive type: {}", &ty); + return TypeTree::new(); + } + visited.push(ty); + + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + + let inner_ty_and_mut = ty.builtin_deref(true).unwrap(); + let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut; + let inner_ty = inner_ty_and_mut.ty; + + // Now account for inner mutability. + if !is_mut && depth > 0 && safety { + let ptr_ty: String = if ty.is_ref() { + "ref" + } else if ty.is_unsafe_ptr() { + "ptr" + } else { + assert!(ty.is_box()); + "box" + }.to_string(); + + // If we have mutability, we also have a span + assert!(span.is_some()); + let span = span.unwrap(); + + tcx.sess + .dcx() + .emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty}); + } + + let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + visited.pop(); + return TypeTree(vec![tt]); + } + + if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() { + visited.pop(); + return TypeTree::new(); + } + + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + } else { + panic!("scalar that is neither integral nor floating point"); + }; + visited.pop(); + return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]); + } + + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; + + let layout = tcx.layout_of(param_env_and); + assert!(layout.is_ok()); + + let layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + if ty.is_adt() && !ty.is_simd() { + let adt_def = ty.ty_adt_def().unwrap(); + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + // Manuel TODO: + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later + FieldsShape::Union(_) => {return TypeTree::new();}, + FieldsShape::Primitive => {return TypeTree::new();}, + }; + + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let field_ty: Ty<'_> = + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); + + if field_ty.is_phantom_data() { + return None; + } + + let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + Some(child) + }) + .flatten() + .collect::>(); + + visited.pop(); + let ret_tt = TypeTree(fields); + return ret_tt; + } else if adt_def.is_enum() { + // Enzyme can't represent enums, so let it figure it out itself, without seeeding + // typetree + //unimplemented!("adt that is an enum"); + } else { + //let ty_name = tcx.def_path_debug_str(adt_def.did()); + //tcx.sess.emit_fatal(UnsupportedUnion { ty_name }); + } + } + + if ty.is_simd() { + trace!("simd"); + let (_size, inner_ty) = ty.simd_size_and_type(tcx); + let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); + // TODO + visited.pop(); + return TypeTree::new(); + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size = max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + if (*count as usize) == 0 { + return TypeTree::new(); + } + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); + + // calculate size of subtree + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; + let tt = TypeTree( + std::iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + visited.pop(); + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); + + visited.pop(); + return subtt; + } + + visited.pop(); + TypeTree::new() +} + +// AST-based type tree construction (simplified fallback) +#[cfg(llvm_enzyme)] +pub fn construct_typetree_from_ty(ty: &ast::Ty) -> TypeTree { + // For now, return empty type tree to let Enzyme figure out layout + // In a full implementation, we'd need to convert AST types to Ty<'tcx> + // and use the layout-based approach from the old code + TypeTree::new() +} + +#[cfg(llvm_enzyme)] +pub fn construct_typetree_from_fnsig(sig: &ast::FnSig) -> (Vec, TypeTree) { + // For now, return empty type trees + // This will be replaced with proper layout-based construction + let inputs: Vec = sig.decl.inputs.iter() + .map(|_| TypeTree::new()) + .collect(); + + let output = match &sig.decl.output { + FnRetTy::Default(_) => TypeTree::new(), + FnRetTy::Ty(_) => TypeTree::new(), + }; + + (inputs, output) +} diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index b07d9a5cfca8c..a709f528e9fba 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,6 +1,7 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::back::write::ModuleConfig; use rustc_codegen_ssa::common::TypeKind; @@ -16,8 +17,10 @@ use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True}; +use crate::typetree::to_enzyme_typetree; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; +use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm, DiffTypeTree}; +use rustc_data_structures::fx::FxHashMap; fn get_params(fnc: &Value) -> Vec<&Value> { let param_num = llvm::LLVMCountParams(fnc) as usize; @@ -294,6 +297,7 @@ fn generate_enzyme_call<'ll>( fn_to_diff: &'ll Value, outer_fn: &'ll Value, attrs: AutoDiffAttrs, + fnc_tree: Option, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -361,6 +365,15 @@ fn generate_enzyme_call<'ll>( let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); attributes::apply_to_llfn(ad_fn, Function, &[attr]); + // TODO(KMJ-007): Add type tree metadata if available + // This requires adding CreateTypeTreeAttribute to LLVM bindings + // if let Some(tree) = fnc_tree { + // let data_layout = cx.data_layout(); + // let enzyme_tree = to_enzyme_typetree(tree, data_layout, cx.llcx); + // let tt_attr = llvm::CreateTypeTreeAttribute(cx.llcx, enzyme_tree); + // attributes::apply_to_llfn(ad_fn, Function, &[tt_attr]); + // } + // We add a made-up attribute just such that we can recognize it after AD to update // (no)-inline attributes. We'll then also remove this attribute. let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); @@ -461,6 +474,7 @@ pub(crate) fn differentiate<'ll>( module: &'ll ModuleCodegen, cgcx: &CodegenContext, diff_items: Vec, + typetrees: FxHashMap, _config: &ModuleConfig, ) -> Result<(), FatalError> { for item in &diff_items { @@ -505,7 +519,22 @@ pub(crate) fn differentiate<'ll>( )); }; - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + // Use type trees from the typetrees map if available, otherwise construct from item + let fnc_tree = if let Some(diff_tt) = typetrees.get(&item.source) { + Some(FncTree { + inputs: diff_tt.input_tt.clone(), + output: diff_tt.ret_tt.clone(), + }) + } else if !item.inputs.is_empty() || !item.output.0.is_empty() { + Some(FncTree { + inputs: item.inputs.clone(), + output: item.output.clone(), + }) + } else { + None + }; + + generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone(), fnc_tree); } // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index cdfffbe47bfa5..541cc6ca8718f 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -27,6 +27,7 @@ use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; use context::SimpleCx; use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; +use llvm::TypeTree; use llvm_util::target_config; use rustc_ast::expand::allocator::AllocatorKind; use rustc_ast::expand::autodiff_attrs::AutoDiffItem; @@ -36,7 +37,7 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CodegenResults, CompiledModule, ModuleCodegen, TargetConfig}; -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; @@ -74,6 +75,7 @@ mod llvm_util; mod mono_item; mod type_; mod type_of; +mod typetree; mod va_arg; mod value; @@ -159,6 +161,7 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; fn print_pass_timings(&self) { let timings = llvm::build_string(|s| unsafe { llvm::LLVMRustPrintPassTimings(s) }).unwrap(); print!("{timings}"); @@ -232,13 +235,20 @@ impl WriteBackendMethods for LlvmCodegenBackend { cgcx: &CodegenContext, module: &ModuleCodegen, diff_fncs: Vec, + typetrees: FxHashMap, config: &ModuleConfig, ) -> Result<(), FatalError> { if cgcx.lto != Lto::Fat { let dcx = cgcx.create_dcx(); return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); } - builder::autodiff::differentiate(module, cgcx, diff_fncs, config) + builder::autodiff::differentiate(module, cgcx, diff_fncs, typetrees, config) + } + + // The typetrees contain all information, their order therefore is irrelevant. + #[allow(rustc::potential_query_instability)] + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() } } @@ -386,6 +396,13 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, @@ -393,6 +410,7 @@ pub struct ModuleLlvm { // This field is `ManuallyDrop` because it is important that the `TargetMachine` // is disposed prior to the `Context` being disposed otherwise UAFs can occur. tm: ManuallyDrop, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -407,6 +425,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_target_machine(tcx, mod_name)), + typetrees: Default::default(), } } } @@ -418,7 +437,8 @@ impl ModuleLlvm { ModuleLlvm { llmod_raw, llcx, - tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess, false)), + tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess)), + typetrees: Default::default(), } } } @@ -440,7 +460,12 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }) + Ok(ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 91ada856d5977..5e0cc8a0b6316 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2670,4 +2670,7 @@ unsafe extern "C" { pub(crate) fn LLVMRustSetNoSanitizeAddress(Global: &Value); pub(crate) fn LLVMRustSetNoSanitizeHWAddress(Global: &Value); + + // Type Tree Attribute Functions + pub fn CreateTypeTreeAttribute<'a>(llcx: &'a Context, typetree: &'a TypeTree) -> &'a Attribute; } diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..3d688f443524c --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_ast::expand::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} \ No newline at end of file From 1c3075b71f16e3d6c67f623fe90b8a0160de454a Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 12 Jul 2025 11:19:37 +0000 Subject: [PATCH 2/3] it should work now Signed-off-by: Karan Janthe --- compiler/rustc_codegen_llvm/src/builder.rs | 125 +++++++- .../src/builder/autodiff.rs | 9 +- compiler/rustc_codegen_llvm/src/lib.rs | 72 ++--- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 285 ++++++++++++++++++ compiler/rustc_codegen_ssa/src/back/lto.rs | 5 +- compiler/rustc_codegen_ssa/src/mir/block.rs | 8 +- .../rustc_codegen_ssa/src/mir/intrinsic.rs | 19 +- compiler/rustc_codegen_ssa/src/mir/mod.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/operand.rs | 20 +- compiler/rustc_codegen_ssa/src/mir/place.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/rvalue.rs | 16 +- .../rustc_codegen_ssa/src/traits/builder.rs | 14 +- .../rustc_codegen_ssa/src/traits/write.rs | 29 +- 13 files changed, 494 insertions(+), 112 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index d0aa7320b4b68..8ec1b167dae04 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -28,6 +28,9 @@ use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target}; use smallvec::SmallVec; use tracing::{debug, instrument}; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; +use crate::typetree::to_enzyme_typetree; + use crate::abi::FnAbiLlvmExt; use crate::attributes; use crate::common::Funclet; @@ -548,13 +551,17 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { - unsafe { + fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align, tt: Option) -> &'ll Value { + let load = unsafe { let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); let align = align.min(self.cx().tcx.sess.target.max_reliable_alignment()); llvm::LLVMSetAlignment(load, align.bytes() as c_uint); load + }; + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, load, tt); } + load } fn volatile_load(&mut self, ty: &'ll Type, ptr: &'ll Value) -> &'ll Value { @@ -659,7 +666,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } let llval = const_llval.unwrap_or_else(|| { - let load = self.load(llty, place.val.llval, place.val.align); + let load = self.load(llty, place.val.llval, place.val.align, None); if let abi::BackendRepr::Scalar(scalar) = place.layout.backend_repr { scalar_load_metadata(self, load, scalar, place.layout, Size::ZERO); self.to_immediate_scalar(load, scalar) @@ -678,7 +685,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { self.inbounds_ptradd(place.val.llval, self.const_usize(b_offset.bytes())) }; let llty = place.layout.scalar_pair_element_llvm_type(self, i, false); - let load = self.load(llty, llptr, align); + let load = self.load(llty, llptr, align, None); scalar_load_metadata(self, load, scalar, layout, offset); self.to_immediate_scalar(load, scalar) }; @@ -750,8 +757,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value { - self.store_with_flags(val, ptr, align, MemFlags::empty()) + fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align, tt: Option) -> &'ll Value { + let store = self.store_with_flags(val, ptr, align, MemFlags::empty()); + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, store, tt); + } + store } fn store_with_flags( @@ -1050,11 +1061,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memcpy = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -1063,7 +1075,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); } } @@ -1075,11 +1090,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memmove = unsafe { llvm::LLVMRustBuildMemMove( self.llbuilder, dst, @@ -1088,7 +1104,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, memmove, tt); } } @@ -1099,10 +1118,11 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { size: &'ll Value, align: Align, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memset not supported"); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memset = unsafe { llvm::LLVMRustBuildMemSet( self.llbuilder, ptr, @@ -1110,7 +1130,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fill_byte, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, memset, tt); } } @@ -1842,14 +1865,86 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> { #[instrument(level = "debug", skip(self))] pub(crate) fn mcdc_condbitmap_reset(&mut self, mcdc_temp: &'ll Value) { - self.store(self.const_i32(0), mcdc_temp, self.tcx.data_layout.i32_align.abi); + self.store(self.const_i32(0), mcdc_temp, self.tcx.data_layout.i32_align.abi, None); } #[instrument(level = "debug", skip(self))] pub(crate) fn mcdc_condbitmap_update(&mut self, cond_index: &'ll Value, mcdc_temp: &'ll Value) { let align = self.tcx.data_layout.i32_align.abi; - let current_tv_index = self.load(self.cx.type_i32(), mcdc_temp, align); + let current_tv_index = self.load(self.cx.type_i32(), mcdc_temp, align, None); let new_tv_index = self.add(current_tv_index, cond_index); - self.store(new_tv_index, mcdc_temp, align); + self.store(new_tv_index, mcdc_temp, align, None); + } +} + +// Type tree helper functions for autodiff support +fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, tt: FncTree) { + let inputs = tt.inputs; + let _ret: TypeTree = tt.output; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddParamAttr(val, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } +} + +fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { + let inputs = tt.inputs; + let ret_tt: TypeTree = tt.output; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } + let ret_attr = unsafe { + let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + attr + }; + unsafe { + llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr); } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index a709f528e9fba..4adaf8cbefdcb 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -365,14 +365,7 @@ fn generate_enzyme_call<'ll>( let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); attributes::apply_to_llfn(ad_fn, Function, &[attr]); - // TODO(KMJ-007): Add type tree metadata if available - // This requires adding CreateTypeTreeAttribute to LLVM bindings - // if let Some(tree) = fnc_tree { - // let data_layout = cx.data_layout(); - // let enzyme_tree = to_enzyme_typetree(tree, data_layout, cx.llcx); - // let tt_attr = llvm::CreateTypeTreeAttribute(cx.llcx, enzyme_tree); - // attributes::apply_to_llfn(ad_fn, Function, &[tt_attr]); - // } + // We add a made-up attribute just such that we can recognize it after AD to update // (no)-inline attributes. We'll then also remove this attribute. diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 541cc6ca8718f..b4c26d78b7ce3 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -38,7 +38,7 @@ use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CodegenResults, CompiledModule, ModuleCodegen, TargetConfig}; use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; -use rustc_errors::{DiagCtxtHandle, FatalError}; +use rustc_errors::{DiagCtxt, FatalError}; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; @@ -47,6 +47,12 @@ use rustc_session::Session; use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::Symbol; +#[derive(Clone)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + mod back { pub(crate) mod archive; pub(crate) mod lto; @@ -162,17 +168,9 @@ impl WriteBackendMethods for LlvmCodegenBackend { type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; type TypeTree = DiffTypeTree; - fn print_pass_timings(&self) { - let timings = llvm::build_string(|s| unsafe { llvm::LLVMRustPrintPassTimings(s) }).unwrap(); - print!("{timings}"); - } - fn print_statistics(&self) { - let stats = llvm::build_string(|s| unsafe { llvm::LLVMRustPrintStatistics(s) }).unwrap(); - print!("{stats}"); - } fn run_link( cgcx: &CodegenContext, - dcx: DiagCtxtHandle<'_>, + dcx: &DiagCtxt, modules: Vec>, ) -> Result, FatalError> { back::write::link(cgcx, dcx, modules) @@ -191,44 +189,47 @@ impl WriteBackendMethods for LlvmCodegenBackend { ) -> Result<(Vec>, Vec), FatalError> { back::lto::run_thin(cgcx, modules, cached_modules) } - fn optimize( + fn print_pass_timings(&self) { + let timings = llvm::build_string(|s| unsafe { llvm::LLVMRustPrintPassTimings(s) }).unwrap(); + print!("{timings}"); + } + fn print_statistics(&self) { + let stats = llvm::build_string(|s| unsafe { llvm::LLVMRustPrintStatistics(s) }).unwrap(); + print!("{stats}"); + } + unsafe fn optimize( cgcx: &CodegenContext, - dcx: DiagCtxtHandle<'_>, - module: &mut ModuleCodegen, + dcx: &DiagCtxt, + module: &ModuleCodegen, config: &ModuleConfig, ) -> Result<(), FatalError> { back::write::optimize(cgcx, dcx, module, config) } fn optimize_fat( cgcx: &CodegenContext, - module: &mut ModuleCodegen, + llmod: &mut ModuleCodegen, ) -> Result<(), FatalError> { - let dcx = cgcx.create_dcx(); - let dcx = dcx.handle(); - back::lto::run_pass_manager(cgcx, dcx, module, false) + back::write::optimize_fat(cgcx, llmod) } - fn optimize_thin( + unsafe fn optimize_thin( cgcx: &CodegenContext, thin: ThinModule, ) -> Result, FatalError> { - back::lto::optimize_thin_module(thin, cgcx) + back::lto::optimize_thin(cgcx, thin) } - fn codegen( + unsafe fn codegen( cgcx: &CodegenContext, - dcx: DiagCtxtHandle<'_>, + dcx: &DiagCtxt, module: ModuleCodegen, config: &ModuleConfig, ) -> Result { back::write::codegen(cgcx, dcx, module, config) } - fn prepare_thin( - module: ModuleCodegen, - emit_summary: bool, - ) -> (String, Self::ThinBuffer) { - back::lto::prepare_thin(module, emit_summary) + fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer) { + back::lto::prepare_thin(module) } fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { - (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) + back::lto::serialize_module(module) } /// Generate autodiff rules fn autodiff( @@ -238,17 +239,10 @@ impl WriteBackendMethods for LlvmCodegenBackend { typetrees: FxHashMap, config: &ModuleConfig, ) -> Result<(), FatalError> { - if cgcx.lto != Lto::Fat { - let dcx = cgcx.create_dcx(); - return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); - } builder::autodiff::differentiate(module, cgcx, diff_fncs, typetrees, config) } - - // The typetrees contain all information, their order therefore is irrelevant. - #[allow(rustc::potential_query_instability)] fn typetrees(module: &mut Self::Module) -> FxHashMap { - module.typetrees.drain().collect() + module.typetrees.clone() } } @@ -396,12 +390,6 @@ impl CodegenBackend for LlvmCodegenBackend { } } -#[derive(Clone, Debug)] -pub struct DiffTypeTree { - pub ret_tt: TypeTree, - pub input_tt: Vec, -} - #[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, @@ -447,7 +435,7 @@ impl ModuleLlvm { cgcx: &CodegenContext, name: &CStr, buffer: &[u8], - dcx: DiagCtxtHandle<'_>, + dcx: DiagCtxt<'_>, ) -> Result { unsafe { let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 5e0cc8a0b6316..f378b590c98ee 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2673,4 +2673,289 @@ unsafe extern "C" { // Type Tree Attribute Functions pub fn CreateTypeTreeAttribute<'a>(llcx: &'a Context, typetree: &'a TypeTree) -> &'a Attribute; + + // Enzyme Type Tree Functions + pub fn EnzymeNewTypeTree() -> CTypeTreeRef; + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + pub fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + pub fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + pub fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + pub fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + pub fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + + // Enzyme Configuration Functions + pub fn EnzymeSetCLBool(arg1: *mut c_void, arg2: u8); + pub fn EnzymeSetCLInteger(arg1: *mut c_void, arg2: i64); + + // Enzyme Autodiff Functions + pub fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, + _callerCtx: *const u8, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + + pub fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, + _callerCtx: *const u8, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; + + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; + + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +// Type Tree Support for Autodiff +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], +} +pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], +} +pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], +} +pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct IntList { + pub data: *mut i64, + pub size: size_t, +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} + +pub type CTypeTreeRef = *mut EnzymeTypeTree; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeTypeTree { + _unused: [u8; 0], +} + +pub struct TypeTree { + pub inner: CTypeTreeRef, +} + +impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = std::ffi::CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl std::fmt::Display for TypeTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } +} + +impl std::fmt::Debug for TypeTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct CFnTypeInfo { + pub Arguments: *mut CTypeTreeRef, + pub Return: CTypeTreeRef, + pub KnownValues: *mut IntList, +} + +pub type CustomRuleType = Option< + unsafe extern "C" fn( + direction: c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const c_void, + ) -> u8, +>; + +// Helper functions for type tree support +pub fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::FakeActivitySize(_) => panic!("Implementation error"), + } +} + +pub fn is_size(act: DiffActivity) -> bool { + matches!(act, DiffActivity::FakeActivitySize(_)) +} + +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, } diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index ce6fe8a191b3b..9e6964c04a0b2 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -2,7 +2,7 @@ use std::ffi::CString; use std::sync::Arc; use rustc_ast::expand::autodiff_attrs::AutoDiffItem; -use rustc_data_structures::memmap::Mmap; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; use super::write::CodegenContext; @@ -84,11 +84,12 @@ impl LtoModuleCodegen { self, cgcx: &CodegenContext, diff_fncs: Vec, + typetrees: FxHashMap, config: &ModuleConfig, ) -> Result, FatalError> { match &self { LtoModuleCodegen::Fat(module) => { - B::autodiff(cgcx, &module, diff_fncs, config)?; + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; } _ => panic!("autodiff called with non-fat LTO module"), } diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index bde63fd501aa2..43045f859eade 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1552,7 +1552,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { MemFlags::empty(), ); // ...and then load it with the ABI type. - llval = load_cast(bx, cast, llscratch, scratch_align); + llval = bx.load(bx.backend_type(arg.layout), llval, align, None); bx.lifetime_end(llscratch, scratch_size); } else { // We can't use `PlaceRef::load` here because the argument @@ -1560,7 +1560,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // used for this call is passing it by-value. In that case, // the load would just produce `OperandValue::Ref` instead // of the `OperandValue::Immediate` we need for the call. - llval = bx.load(bx.backend_type(arg.layout), llval, align); + llval = bx.load(bx.backend_type(arg.layout), llval, align, None); if let BackendRepr::Scalar(scalar) = arg.layout.backend_repr { if scalar.is_bool() { bx.range_metadata(llval, WrappingRange { start: 0, end: 1 }); @@ -1915,9 +1915,9 @@ fn load_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( assert_eq!(cast.rest.unit.size, cast.rest.total); let first_ty = bx.reg_backend_type(&cast.prefix[0].unwrap()); let second_ty = bx.reg_backend_type(&cast.rest.unit); - let first = bx.load(first_ty, ptr, align); + let first = bx.load(first_ty, ptr, align, None); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); - let second = bx.load(second_ty, second_ptr, align.restrict_for_offset(offset_from_start)); + let second = bx.load(second_ty, second_ptr, align.restrict_for_offset(offset_from_start), None); let res = bx.cx().const_poison(cast_ty); let res = bx.insert_value(res, first, 0); bx.insert_value(res, second, 1) diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index fc95f62b4a43d..29cf5b3dbfa0c 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -13,6 +13,8 @@ use crate::errors::InvalidMonomorphization; use crate::traits::*; use crate::{MemFlags, meth, size_of_val}; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; + fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, allow_overlap: bool, @@ -22,15 +24,20 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( src: Bx::Value, count: Bx::Value, ) { + let fnc_tree: FncTree = FncTree { + args: vec![TypeTree::new(), TypeTree::new(), TypeTree::all_ints()], + ret: TypeTree::new(), + }; + let layout = bx.layout_of(ty); let size = layout.size; let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; if allow_overlap { - bx.memmove(dst, align, src, align, size, flags); + bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree)); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree)); } } @@ -42,12 +49,18 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( val: Bx::Value, count: Bx::Value, ) { + // Create a simple type tree for the memset operation + let fnc_tree: FncTree = FncTree { + args: vec![TypeTree::new(), TypeTree::new(), TypeTree::all_ints()], + ret: TypeTree::new(), + }; + let layout = bx.layout_of(ty); let size = layout.size; let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; - bx.memset(dst, val, size, align, flags); + bx.memset(dst, val, size, align, flags, Some(fnc_tree)); } impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs index 10b44a1faf087..ec22de1ee2969 100644 --- a/compiler/rustc_codegen_ssa/src/mir/mod.rs +++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs @@ -509,7 +509,7 @@ fn arg_local_refs<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let indirect_operand = OperandValue::Pair(llarg, llextra); let tmp = PlaceRef::alloca_unsized_indirect(bx, arg.layout); - indirect_operand.store(bx, tmp); + indirect_operand.store(bx, tmp, None); LocalRef::UnsizedPlace(tmp) } _ => { diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index da615cc9a003d..353d4a439c67f 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -336,7 +336,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { // This is being deprecated, but for now stdarch still needs it for // Newtype vector of array, e.g. #[repr(simd)] struct S([i32; 4]); let place = PlaceRef::alloca(bx, field); - self.val.store(bx, place.val.with_type(self.layout)); + self.val.store(bx, place.val.with_type(self.layout), None); return bx.load_operand(place); } else { // Part of https://github.com/rust-lang/compiler-team/issues/838 @@ -694,8 +694,9 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { self, bx: &mut Bx, dest: PlaceRef<'tcx, V>, + tt: Option, ) { - self.store_with_flags(bx, dest, MemFlags::empty()); + self.store_with_flags(bx, dest, MemFlags::empty(), tt); } pub fn volatile_store>( @@ -703,7 +704,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, ) { - self.store_with_flags(bx, dest, MemFlags::VOLATILE); + self.store_with_flags(bx, dest, MemFlags::VOLATILE, None); } pub fn unaligned_volatile_store>( @@ -711,7 +712,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, ) { - self.store_with_flags(bx, dest, MemFlags::VOLATILE | MemFlags::UNALIGNED); + self.store_with_flags(bx, dest, MemFlags::VOLATILE | MemFlags::UNALIGNED, None); } pub fn nontemporal_store>( @@ -719,7 +720,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, ) { - self.store_with_flags(bx, dest, MemFlags::NONTEMPORAL); + self.store_with_flags(bx, dest, MemFlags::NONTEMPORAL, None); } pub(crate) fn store_with_flags>( @@ -727,6 +728,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, flags: MemFlags, + tt: Option, ) { debug!("OperandRef::store: operand={:?}, dest={:?}", self, dest); match self { @@ -743,7 +745,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { } OperandValue::Immediate(s) => { let val = bx.from_immediate(s); - bx.store_with_flags(val, dest.val.llval, dest.val.align, flags); + bx.store(val, dest.val.llval, dest.val.align, tt); } OperandValue::Pair(a, b) => { let BackendRepr::ScalarPair(a_scalar, b_scalar) = dest.layout.backend_repr else { @@ -753,12 +755,12 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { let val = bx.from_immediate(a); let align = dest.val.align; - bx.store_with_flags(val, dest.val.llval, align, flags); + bx.store(val, dest.val.llval, align, tt); let llptr = bx.inbounds_ptradd(dest.val.llval, bx.const_usize(b_offset.bytes())); let val = bx.from_immediate(b); let align = dest.val.align.restrict_for_offset(b_offset); - bx.store_with_flags(val, llptr, align, flags); + bx.store(val, llptr, align, tt); } } } @@ -798,7 +800,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { // Store the allocated region and the extra to the indirect place. let indirect_operand = OperandValue::Pair(dst, llextra); - indirect_operand.store(bx, indirect_dest); + indirect_operand.store(bx, indirect_dest, None); } } diff --git a/compiler/rustc_codegen_ssa/src/mir/place.rs b/compiler/rustc_codegen_ssa/src/mir/place.rs index 937063c24a63d..a75c008f5cf75 100644 --- a/compiler/rustc_codegen_ssa/src/mir/place.rs +++ b/compiler/rustc_codegen_ssa/src/mir/place.rs @@ -283,7 +283,7 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { scalar, niche_llty, ); - OperandValue::Immediate(niche_llval).store(bx, niche); + OperandValue::Immediate(niche_llval).store(bx, niche, None); } } } diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 60cf4e28b5a09..4c02cc574ec42 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -29,7 +29,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let cg_operand = self.codegen_operand(bx, operand); // FIXME: consider not copying constants through stack. (Fixable by codegen'ing // constants into `OperandValue::Ref`; why don’t we do that yet if we don’t?) - cg_operand.val.store(bx, dest); + cg_operand.val.store(bx, dest, None); } mir::Rvalue::Cast( @@ -43,7 +43,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // Into-coerce of a thin pointer to a wide pointer -- just // use the operand path. let temp = self.codegen_rvalue_operand(bx, rvalue); - temp.val.store(bx, dest); + temp.val.store(bx, dest, None); return; } @@ -63,7 +63,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { debug!("codegen_rvalue: creating ugly alloca"); let scratch = PlaceRef::alloca(bx, operand.layout); scratch.storage_live(bx); - operand.val.store(bx, scratch); + operand.val.store(bx, scratch, None); base::coerce_unsized_into(bx, scratch, dest); scratch.storage_dead(bx); } @@ -176,7 +176,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { } else { variant_dest.project_field(bx, field_index.as_usize()) }; - op.val.store(bx, field); + op.val.store(bx, field, None); } } dest.codegen_set_discr(bx, variant_index); @@ -185,7 +185,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { _ => { assert!(self.rvalue_creates_operand(rvalue, DUMMY_SP)); let temp = self.codegen_rvalue_operand(bx, rvalue); - temp.val.store(bx, dest); + temp.val.store(bx, dest, None); } } } @@ -201,7 +201,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { assert!(dst.layout.is_sized()); if let Some(val) = self.codegen_transmute_operand(bx, src, dst.layout) { - val.store(bx, dst); + val.store(bx, dst, None); return; } @@ -216,7 +216,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { OperandValue::Immediate(..) | OperandValue::Pair(..) => { // When we have immediate(s), the alignment of the source is irrelevant, // so we can store them using the destination's alignment. - src.val.store(bx, dst.val.with_type(src.layout)); + src.val.store(bx, dst.val.with_type(src.layout), None); } } } @@ -454,7 +454,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { ) => { bug!("{kind:?} is for borrowck, and should never appear in codegen"); } - mir::CastKind::PtrToPtr + mir::CastKind::PointerCoercion(PointerCoercion::PtrToPtr if bx.cx().is_backend_scalar_pair(operand.layout) => { if let OperandValue::Pair(data_ptr, meta) = operand.val { diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 9d367748c2a8a..7bdda2a4c0fd7 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,6 +2,7 @@ use std::assert_matches::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; +use rustc_ast::expand::typetree::FncTree; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; @@ -226,7 +227,7 @@ pub trait BuilderMethods<'a, 'tcx>: fn alloca(&mut self, size: Size, align: Align) -> Self::Value; fn dynamic_alloca(&mut self, size: Self::Value, align: Align) -> Self::Value; - fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value; + fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align, tt: Option) -> Self::Value; fn volatile_load(&mut self, ty: Self::Type, ptr: Self::Value) -> Self::Value; fn atomic_load( &mut self, @@ -237,7 +238,7 @@ pub trait BuilderMethods<'a, 'tcx>: ) -> Self::Value; fn load_from_place(&mut self, ty: Self::Type, place: PlaceValue) -> Self::Value { assert_eq!(place.llextra, None); - self.load(ty, place.llval, place.align) + self.load(ty, place.llval, place.align, None) } fn load_operand(&mut self, place: PlaceRef<'tcx, Self::Value>) -> OperandRef<'tcx, Self::Value>; @@ -287,10 +288,10 @@ pub trait BuilderMethods<'a, 'tcx>: fn range_metadata(&mut self, load: Self::Value, range: WrappingRange); fn nonnull_metadata(&mut self, load: Self::Value); - fn store(&mut self, val: Self::Value, ptr: Self::Value, align: Align) -> Self::Value; + fn store(&mut self, val: Self::Value, ptr: Self::Value, align: Align, tt: Option) -> Self::Value; fn store_to_place(&mut self, val: Self::Value, place: PlaceValue) -> Self::Value { assert_eq!(place.llextra, None); - self.store(val, place.llval, place.align) + self.store(val, place.llval, place.align, None) } fn store_with_flags( &mut self, @@ -415,6 +416,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memmove( &mut self, @@ -424,6 +426,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memset( &mut self, @@ -432,6 +435,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + tt: Option, ); /// *Typed* copy for non-overlapping places. @@ -471,7 +475,7 @@ pub trait BuilderMethods<'a, 'tcx>: temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); } } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 07a0609fda1a1..e89814b409073 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,23 +1,25 @@ use rustc_ast::expand::autodiff_attrs::AutoDiffItem; -use rustc_errors::{DiagCtxtHandle, FatalError}; +use rustc_data_structures::fx::FxHashMap; +use rustc_errors::{DiagCtxt, FatalError}; use rustc_middle::dep_graph::WorkProduct; use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; -pub trait WriteBackendMethods: Clone + 'static { +pub trait WriteBackendMethods: 'static + Sized + Clone { type Module: Send + Sync; type TargetMachine; type TargetMachineError; type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( cgcx: &CodegenContext, - dcx: DiagCtxtHandle<'_>, + dcx: &DiagCtxt, modules: Vec>, ) -> Result, FatalError>; /// Performs fat LTO by merging all modules into a single one and returning it @@ -37,42 +39,41 @@ pub trait WriteBackendMethods: Clone + 'static { ) -> Result<(Vec>, Vec), FatalError>; fn print_pass_timings(&self); fn print_statistics(&self); - fn optimize( + unsafe fn optimize( cgcx: &CodegenContext, - dcx: DiagCtxtHandle<'_>, - module: &mut ModuleCodegen, + dcx: &DiagCtxt, + module: &ModuleCodegen, config: &ModuleConfig, ) -> Result<(), FatalError>; fn optimize_fat( cgcx: &CodegenContext, llmod: &mut ModuleCodegen, ) -> Result<(), FatalError>; - fn optimize_thin( + unsafe fn optimize_thin( cgcx: &CodegenContext, thin: ThinModule, ) -> Result, FatalError>; - fn codegen( + unsafe fn codegen( cgcx: &CodegenContext, - dcx: DiagCtxtHandle<'_>, + dcx: &DiagCtxt, module: ModuleCodegen, config: &ModuleConfig, ) -> Result; - fn prepare_thin( - module: ModuleCodegen, - want_summary: bool, - ) -> (String, Self::ThinBuffer); + fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules fn autodiff( cgcx: &CodegenContext, module: &ModuleCodegen, diff_fncs: Vec, + typetrees: FxHashMap, config: &ModuleConfig, ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { fn data(&self) -> &[u8]; - fn thin_link_data(&self) -> &[u8]; } pub trait ModuleBufferMethods: Send + Sync { From d108f800abde876379904f3653d3da1275199685 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sun, 13 Jul 2025 02:33:35 +0000 Subject: [PATCH 3/3] it is not working Signed-off-by: Karan Janthe --- .../rustc_ast/src/expand/autodiff_attrs.rs | 13 ++++-- compiler/rustc_builtin_macros/Cargo.toml | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 5 ++- compiler/rustc_builtin_macros/src/typetree.rs | 29 ++----------- compiler/rustc_codegen_ssa/src/back/lto.rs | 2 +- compiler/rustc_codegen_ssa/src/back/write.rs | 12 +++--- compiler/rustc_codegen_ssa/src/base.rs | 6 +-- compiler/rustc_codegen_ssa/src/meth.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/block.rs | 21 +++++----- .../rustc_codegen_ssa/src/mir/debuginfo.rs | 2 +- .../rustc_codegen_ssa/src/mir/intrinsic.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/operand.rs | 5 ++- compiler/rustc_codegen_ssa/src/mir/rvalue.rs | 42 ++++++++++++------- .../rustc_codegen_ssa/src/mir/statement.rs | 2 +- .../rustc_codegen_ssa/src/traits/builder.rs | 4 +- .../src/partitioning/autodiff.rs | 3 +- 16 files changed, 75 insertions(+), 76 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index b615398b4ed09..fff3ef1735ae4 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -116,8 +116,13 @@ impl AutoDiffAttrs { pub fn has_primal_ret(&self) -> bool { matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual) } - /// New constructor for type tree support - pub fn into_item(self, source: String, target: String, inputs: Vec, output: TypeTree) -> AutoDiffItem { + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { AutoDiffItem { source, target, attrs: self, inputs, output } } } @@ -284,8 +289,8 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { - AutoDiffItem { source, target, attrs: self } + pub fn into_item_legacy(self, source: String, target: String) -> AutoDiffItem { + AutoDiffItem { source, target, attrs: self, inputs: vec![], output: TypeTree::new() } } } diff --git a/compiler/rustc_builtin_macros/Cargo.toml b/compiler/rustc_builtin_macros/Cargo.toml index 4c1264c6f1ce1..a910d2b7f8bd9 100644 --- a/compiler/rustc_builtin_macros/Cargo.toml +++ b/compiler/rustc_builtin_macros/Cargo.toml @@ -22,6 +22,7 @@ rustc_index = { path = "../rustc_index" } rustc_lexer = { path = "../rustc_lexer" } rustc_lint_defs = { path = "../rustc_lint_defs" } rustc_macros = { path = "../rustc_macros" } +rustc_middle = { path = "../rustc_middle" } rustc_parse = { path = "../rustc_parse" } rustc_parse_format = { path = "../rustc_parse_format" } # We must use the proc_macro version that we will compile proc-macros against, diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index dd5f3d5aa3237..a73652561ec08 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -11,7 +11,6 @@ mod llvm_enzyme { AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity, }; - use rustc_ast::expand::typetree::{TypeTree, Type, Kind}; use rustc_ast::ptr::P; use crate::typetree::construct_typetree_from_fnsig; use rustc_ast::token::{Lit, LitKind, Token, TokenKind}; @@ -330,7 +329,7 @@ mod llvm_enzyme { let (inputs, output) = construct_typetree_from_fnsig(&sig); // Use the new into_item method to construct the AutoDiffItem - let autodiff_item = x.clone().into_item( + let _autodiff_item = x.clone().into_item( primal.to_string(), first_ident(&meta_item_vec[0]).to_string(), inputs, @@ -1058,3 +1057,5 @@ mod llvm_enzyme { (d_sig, new_inputs, idents, false) } } + +pub(crate) use crate::autodiff::llvm_enzyme::{expand_forward, expand_reverse}; diff --git a/compiler/rustc_builtin_macros/src/typetree.rs b/compiler/rustc_builtin_macros/src/typetree.rs index f33efe3b22b06..8adf3295c044c 100644 --- a/compiler/rustc_builtin_macros/src/typetree.rs +++ b/compiler/rustc_builtin_macros/src/typetree.rs @@ -6,6 +6,7 @@ use rustc_middle::ty::layout::{FieldsShape, LayoutOf}; use rustc_middle::hir; use rustc_span::Span; use rustc_ast::expand::autodiff_attrs::DiffActivity; +use tracing::trace; #[cfg(llvm_enzyme)] pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { @@ -31,7 +32,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec> = fn_ty.fn_sig(tcx); + let fnc_binder: rustc_middle::ty::Binder<'_, rustc_middle::ty::FnSig<'_>> = fn_ty.fn_sig(tcx); // If rustc compiles the unmodified primal, we know that this copy of the function // also has correct lifetimes. We know that Enzyme won't free the shadow too early @@ -80,7 +81,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec DiffActivity::FakeActivitySize, + => DiffActivity::FakeActivitySize(None), DiffActivity::Const => DiffActivity::Const, _ => panic!("unexpected activity for ptr/ref"), }; @@ -304,27 +305,3 @@ fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, TypeTree::new() } -// AST-based type tree construction (simplified fallback) -#[cfg(llvm_enzyme)] -pub fn construct_typetree_from_ty(ty: &ast::Ty) -> TypeTree { - // For now, return empty type tree to let Enzyme figure out layout - // In a full implementation, we'd need to convert AST types to Ty<'tcx> - // and use the layout-based approach from the old code - TypeTree::new() -} - -#[cfg(llvm_enzyme)] -pub fn construct_typetree_from_fnsig(sig: &ast::FnSig) -> (Vec, TypeTree) { - // For now, return empty type trees - // This will be replaced with proper layout-based construction - let inputs: Vec = sig.decl.inputs.iter() - .map(|_| TypeTree::new()) - .collect(); - - let output = match &sig.decl.output { - FnRetTy::Default(_) => TypeTree::new(), - FnRetTy::Ty(_) => TypeTree::new(), - }; - - (inputs, output) -} diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index 9e6964c04a0b2..a2a25abdb33d0 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -65,7 +65,7 @@ impl LtoModuleCodegen { B::optimize_fat(cgcx, &mut module)?; Ok(module) } - LtoModuleCodegen::Thin(thin) => B::optimize_thin(cgcx, thin), + LtoModuleCodegen::Thin(thin) => unsafe { B::optimize_thin(cgcx, thin) }, } } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 8330e4f7af0c7..d1473488afd0c 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -412,7 +412,7 @@ fn generate_lto_work( B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); if cgcx.lto == Lto::Fat && !autodiff.is_empty() { let config = cgcx.config(ModuleKind::Regular); - module = module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise()); + module = module.autodiff(cgcx, autodiff, rustc_data_structures::fx::FxHashMap::default(), config).unwrap_or_else(|e| e.raise()); } // We are adding a single work item, so the cost doesn't matter. vec![(WorkItem::LTO(module), 0)] @@ -870,7 +870,7 @@ fn execute_optimize_work_item( let dcx = cgcx.create_dcx(); let dcx = dcx.handle(); - B::optimize(cgcx, dcx, &mut module, module_config)?; + unsafe { B::optimize(cgcx, &dcx, &mut module, module_config)? }; // After we've done the initial round of optimizations we need to // decide whether to synchronously codegen this module or ship it @@ -891,7 +891,7 @@ fn execute_optimize_work_item( match lto_type { ComputedLtoType::No => finish_intra_module_work(cgcx, module, module_config), ComputedLtoType::Thin => { - let (name, thin_buffer) = B::prepare_thin(module, false); + let (name, thin_buffer) = B::prepare_thin(module); if let Some(path) = bitcode { fs::write(&path, thin_buffer.data()).unwrap_or_else(|e| { panic!("Error writing pre-lto-bitcode file `{}`: {}", path.display(), e); @@ -1014,7 +1014,7 @@ fn finish_intra_module_work( let dcx = dcx.handle(); if !cgcx.opts.unstable_opts.combine_cgu || module.kind == ModuleKind::Allocator { - let module = B::codegen(cgcx, dcx, module, module_config)?; + let module = unsafe { B::codegen(cgcx, &dcx, module, module_config)? }; Ok(WorkItemResult::Finished(module)) } else { Ok(WorkItemResult::NeedsLink(module)) @@ -1700,9 +1700,9 @@ fn start_executing_work( assert!(compiled_modules.is_empty()); let dcx = cgcx.create_dcx(); let dcx = dcx.handle(); - let module = B::run_link(&cgcx, dcx, needs_link).map_err(|_| ())?; + let module = B::run_link(&cgcx, &dcx, needs_link).map_err(|_| ())?; let module = - B::codegen(&cgcx, dcx, module, cgcx.config(ModuleKind::Regular)).map_err(|_| ())?; + unsafe { B::codegen(&cgcx, &dcx, module, cgcx.config(ModuleKind::Regular)) }.map_err(|_| ())?; compiled_modules.push(module); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 102d4ea2fa6cf..024f7e66ec421 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -278,7 +278,7 @@ pub(crate) fn coerce_unsized_into<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( OperandValue::Immediate(base) => unsize_ptr(bx, base, src_ty, dst_ty, None), OperandValue::Ref(..) | OperandValue::ZeroSized => bug!(), }; - OperandValue::Pair(base, info).store(bx, dst); + OperandValue::Pair(base, info).store(bx, dst, None); } (&ty::Adt(def_a, _), &ty::Adt(def_b, _)) => { @@ -581,9 +581,9 @@ fn get_argc_argv<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(bx: &mut Bx) -> (Bx::Va let ptr_align = bx.tcx().data_layout.pointer_align.abi; let arg_argc = bx.const_int(bx.cx().type_isize(), 2); let arg_argv = bx.alloca(2 * ptr_size, ptr_align); - bx.store(param_handle, arg_argv, ptr_align); + bx.store(param_handle, arg_argv, ptr_align, None); let arg_argv_el1 = bx.inbounds_ptradd(arg_argv, bx.const_usize(ptr_size.bytes())); - bx.store(param_system_table, arg_argv_el1, ptr_align); + bx.store(param_system_table, arg_argv_el1, ptr_align, None); (arg_argc, arg_argv) } else if bx.cx().sess().target.main_needs_argc_argv { // Params from native `main()` used as args for rust start function diff --git a/compiler/rustc_codegen_ssa/src/meth.rs b/compiler/rustc_codegen_ssa/src/meth.rs index 3a11ce6befb36..f258dcb9c5859 100644 --- a/compiler/rustc_codegen_ssa/src/meth.rs +++ b/compiler/rustc_codegen_ssa/src/meth.rs @@ -148,7 +148,7 @@ pub(crate) fn load_vtable<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( } let gep = bx.inbounds_ptradd(llvtable, bx.const_usize(vtable_byte_offset)); - let ptr = bx.load(llty, gep, ptr_align); + let ptr = bx.load(llty, gep, ptr_align, None); // VTable loads are invariant. bx.set_invariant_load(ptr); if nonnull { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 43045f859eade..a3842a30d9a7e 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -546,7 +546,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let llslot = match op.val { Immediate(_) | Pair(..) => { let scratch = PlaceRef::alloca(bx, self.fn_abi.ret.layout); - op.val.store(bx, scratch); + op.val.store(bx, scratch, None); scratch.val.llval } Ref(place_val) => { @@ -1077,7 +1077,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { | (&mir::Operand::Constant(_), Ref(PlaceValue { llextra: None, .. })) => { let tmp = PlaceRef::alloca(bx, op.layout); bx.lifetime_start(tmp.val.llval, tmp.layout.size); - op.val.store(bx, tmp); + op.val.store(bx, tmp, None); op.val = Ref(tmp.val); lifetime_ends_after_call.push((tmp.val.llval, tmp.layout.size)); } @@ -1475,13 +1475,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { }; let scratch = PlaceValue::alloca(bx, arg.layout.size, required_align); bx.lifetime_start(scratch.llval, arg.layout.size); - op.val.store(bx, scratch.with_type(arg.layout)); + op.val.store(bx, scratch.with_type(arg.layout), None); lifetime_ends_after_call.push((scratch.llval, arg.layout.size)); (scratch.llval, scratch.align, true) } PassMode::Cast { .. } => { let scratch = PlaceRef::alloca(bx, arg.layout); - op.val.store(bx, scratch); + op.val.store(bx, scratch, None); (scratch.val.llval, scratch.val.align, true) } _ => (op.immediate_or_packed_pair(bx), arg.layout.align.abi, false), @@ -1543,13 +1543,14 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let llscratch = bx.alloca(scratch_size, scratch_align); bx.lifetime_start(llscratch, scratch_size); // ...memcpy the value... - bx.memcpy( + bx.memcpy( llscratch, scratch_align, llval, align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); // ...and then load it with the ABI type. llval = bx.load(bx.backend_type(arg.layout), llval, align, None); @@ -1662,7 +1663,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let slot = self.get_personality_slot(&mut cleanup_bx); slot.storage_live(&mut cleanup_bx); - Pair(exn0, exn1).store(&mut cleanup_bx, slot); + Pair(exn0, exn1).store(&mut cleanup_bx, slot, None); cleanup_bx.br(llbb); cleanup_llbb @@ -1922,7 +1923,7 @@ fn load_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let res = bx.insert_value(res, first, 0); bx.insert_value(res, second, 1) } else { - bx.load(cast_ty, ptr, align) + bx.load(cast_ty, ptr, align, None) } } @@ -1939,10 +1940,10 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( assert!(cast.prefix[0].is_some()); let first = bx.extract_value(value, 0); let second = bx.extract_value(value, 1); - bx.store(first, ptr, align); + bx.store(first, ptr, align, None); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); - bx.store(second, second_ptr, align.restrict_for_offset(offset_from_start)); + bx.store(second, second_ptr, align.restrict_for_offset(offset_from_start), None); } else { - bx.store(value, ptr, align); + bx.store(value, ptr, align, None); }; } diff --git a/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs b/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs index 025f5fb54f428..529090ba3960f 100644 --- a/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs +++ b/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs @@ -249,7 +249,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { if let Some(name) = name { bx.set_var_name(spill_slot.val.llval, &(name + ".dbg.spill")); } - operand.val.store(bx, spill_slot); + operand.val.store(bx, spill_slot, None); spill_slot } diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index 29cf5b3dbfa0c..52a29337b0fa7 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -140,7 +140,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { sym::caller_location => { let location = self.get_caller_location(bx, source_info); - location.val.store(bx, result); + location.val.store(bx, result, None); return Ok(()); } diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index 353d4a439c67f..ef2d3125250f8 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -18,6 +18,7 @@ use super::{FunctionCx, LocalRef}; use crate::common::IntPredicate; use crate::traits::*; use crate::{MemFlags, size_of_val}; +use rustc_ast::expand::typetree::FncTree; /// The representation of a Rust value. The enum variant is in fact /// uniquely determined by the value's type, but is kept as a @@ -755,7 +756,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { let val = bx.from_immediate(a); let align = dest.val.align; - bx.store(val, dest.val.llval, align, tt); + bx.store(val, dest.val.llval, align, tt.clone()); let llptr = bx.inbounds_ptradd(dest.val.llval, bx.const_usize(b_offset.bytes())); let val = bx.from_immediate(b); @@ -796,7 +797,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { let neg_address = bx.neg(address); let offset = bx.and(neg_address, align_minus_1); let dst = bx.inbounds_ptradd(alloca, offset); - bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty()); + bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty(), None); // Store the allocated region and the extra to the indirect place. let indirect_operand = OperandValue::Pair(dst, llextra); diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 4c02cc574ec42..6752dcc8bb82e 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -102,6 +102,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { size, dest.val.align, MemFlags::empty(), + None, ); return; } @@ -119,7 +120,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let first = bytes[0]; if bytes[1..].iter().all(|&b| b == first) { let fill = bx.cx().const_u8(first); - bx.memset(start, fill, size, dest.val.align, MemFlags::empty()); + bx.memset(start, fill, size, dest.val.align, MemFlags::empty(), None); return true; } } @@ -127,7 +128,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // Use llvm.memset.p0i8.* to initialize byte arrays let v = bx.from_immediate(v); if bx.cx().val_ty(v) == bx.cx().type_i8() { - bx.memset(start, v, size, dest.val.align, MemFlags::empty()); + bx.memset(start, v, size, dest.val.align, MemFlags::empty(), None); return true; } false @@ -454,21 +455,32 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { ) => { bug!("{kind:?} is for borrowck, and should never appear in codegen"); } - mir::CastKind::PointerCoercion(PointerCoercion::PtrToPtr - if bx.cx().is_backend_scalar_pair(operand.layout) => - { - if let OperandValue::Pair(data_ptr, meta) = operand.val { - if bx.cx().is_backend_scalar_pair(cast) { - OperandValue::Pair(data_ptr, meta) - } else { - // Cast of wide-ptr to thin-ptr is an extraction of data-ptr. - OperandValue::Immediate(data_ptr) - } - } else { - bug!("unexpected non-pair operand"); + mir::CastKind::PointerCoercion(_, _) => { + let imm = operand.immediate(); + let operand_kind = self.value_kind(operand.layout); + let OperandValueKind::Immediate(from_scalar) = operand_kind else { + bug!("Found {operand_kind:?} for operand {operand:?}"); + }; + let from_backend_ty = bx.cx().immediate_backend_type(operand.layout); + + assert!(bx.cx().is_backend_immediate(cast)); + let to_backend_ty = bx.cx().immediate_backend_type(cast); + if operand.layout.is_uninhabited() { + let val = OperandValue::Immediate(bx.cx().const_poison(to_backend_ty)); + return OperandRef { val, layout: cast }; } + let cast_kind = self.value_kind(cast); + let OperandValueKind::Immediate(to_scalar) = cast_kind else { + bug!("Found {cast_kind:?} for operand {cast:?}"); + }; + + self.cast_immediate(bx, imm, from_scalar, from_backend_ty, to_scalar, to_backend_ty) + .map(OperandValue::Immediate) + .unwrap_or_else(|| { + bug!("Unsupported cast of {operand:?} to {cast:?}"); + }) } - | mir::CastKind::IntToInt + mir::CastKind::IntToInt | mir::CastKind::FloatToInt | mir::CastKind::FloatToFloat | mir::CastKind::IntToFloat diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index cd55a838a7561..f213212d7065d 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -85,7 +85,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 7bdda2a4c0fd7..cb6be64031b77 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -472,7 +472,7 @@ pub trait BuilderMethods<'a, 'tcx>: // If we're not optimizing, the aliasing information from `memcpy` // isn't useful, so just load-store the value for smaller code. let temp = self.load_operand(src.with_type(layout)); - temp.val.store_with_flags(self, dst.with_type(layout), flags); + temp.val.store_with_flags(self, dst.with_type(layout), flags, None); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); @@ -500,7 +500,7 @@ pub trait BuilderMethods<'a, 'tcx>: temp = self.load_operand(alloca); } self.typed_place_copy(left, right, layout); - temp.val.store(self, right.with_type(layout)); + temp.val.store(self, right.with_type(layout), None); } fn select( diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index 22d593b80b895..c59cf34da5313 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -1,4 +1,5 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; +use rustc_ast::expand::typetree::TypeTree; use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::bug; use rustc_middle::mir::mono::MonoItem; @@ -128,7 +129,7 @@ pub(crate) fn find_autodiff_source_functions<'tcx>( let mut new_target_attrs = target_attrs.clone(); new_target_attrs.input_activity = input_activities; - let itm = new_target_attrs.into_item(symb, target_symbol); + let itm = new_target_attrs.into_item(symb, target_symbol, Vec::new(), TypeTree::new()); autodiff_items.push(itm); }