diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index cf5badf7b99c1..f28af4ba49bad 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -616,7 +616,9 @@ pub(crate) fn run_pass_manager( } let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?; + // We will run this again with different values in the context of automatic differentiation. + let first_run = true; + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; } debug!("lto done"); Ok(()) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 2bbe1f03f91d8..e09aec69ec44c 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -549,6 +549,7 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + first_run: bool, ) -> Result<(), FatalError> { // Enzyme: // We want to simplify / optimize functions before AD. @@ -556,16 +557,23 @@ pub(crate) unsafe fn llvm_optimize( // tend to reduce AD performance. Therefore activate them first, then differentiate the code // and finally re-optimize the module, now with all optimizations available. // RIP compile time. - // let unroll_loops = - // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops; + let vectorize_slp; + let vectorize_loop; - - let _unroll_loops = + if first_run { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; - let unroll_loops = false; - let vectorize_slp = false; - let vectorize_loop = false; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + dbg!("Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", unroll_loops, vectorize_slp, vectorize_loop); + } + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -987,11 +995,13 @@ pub(crate) unsafe fn enzyme_ad( item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); - let mut opt = 1; - if std::env::var("ENZYME_DISABLE_OPTS").is_ok() { - opt = 0; + let mut fnc_opt = false; + if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() { + dbg!("Disabling optimizations for Enzyme"); + fnc_opt = true; } - let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); + + let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt as u8); let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); @@ -1051,7 +1061,7 @@ pub(crate) unsafe fn differentiate( cgcx: &CodegenContext, diff_items: Vec, _typetrees: FxHashMap, - _config: &ModuleConfig, + config: &ModuleConfig, ) -> Result<(), FatalError> { for item in &diff_items { trace!("{}", item); @@ -1086,6 +1096,7 @@ pub(crate) unsafe fn differentiate( llvm::set_max_type_offset(width); } + let differentiate = !diff_items.is_empty(); for item in diff_items { let res = enzyme_ad(llmod, llcx, &diag_handler, item); assert!(res.is_ok()); @@ -1126,6 +1137,23 @@ pub(crate) unsafe fn differentiate( } } + + if std::env::var("ENZYME_NO_MOD_OPT_AFTER").is_ok() || !differentiate { + trace!("Skipping module optimization after automatic differentiation"); + } else { + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let first_run = false; + dbg!("Running Module Optimization after differentiation"); + llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run)?; + } + } + Ok(()) } @@ -1198,7 +1226,9 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage); + // Second run only relevant for AD + let first_run = true; + return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run); } Ok(()) }