Skip to content

feat: implement several direct hooks into luaC API #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rivets-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ darling = "0.20"

[lib]
proc-macro = true
path = "src/macros.rs"
75 changes: 34 additions & 41 deletions rivets-macros/src/lib.rs → rivets-macros/src/macros.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,39 @@
#![warn(missing_docs)]
#![feature(proc_macro_diagnostic)]
#![warn(missing_docs)]

//! Contains the proc macros for rivets.

use anyhow::{bail, Result};
use darling::FromDeriveInput;
use lazy_regex::regex;
use proc_macro::{self, Diagnostic, Level, Span, TokenStream};
use proc_macro::{Diagnostic, Level, Span, TokenStream};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::sync::{atomic::AtomicBool, LazyLock, Mutex};
use syn::{parse_macro_input, Abi, DeriveInput, Error, Expr, FnArg, Ident, ItemFn, Variant};
use syn::{parse_macro_input, Abi, DeriveInput, Expr, FnArg, Ident, ItemFn, Variant};

static IS_FINALIZED: AtomicBool = AtomicBool::new(false);
static MANGLED_NAMES: LazyLock<Mutex<Vec<(String, String)>>> = LazyLock::new(|| Mutex::new(vec![]));
static CPP_IMPORTS: LazyLock<Mutex<Vec<(String, String)>>> = LazyLock::new(|| Mutex::new(vec![]));

macro_rules! derive_error {
($string: tt) => {
Error::new(proc_macro2::Span::call_site(), $string)
.to_compile_error()
.into()
};
}

macro_rules! check_finalized {
() => {
// this check causes issues with rust-analyer. disable during debug builds.
#[cfg(not(debug_assertions))]
if IS_FINALIZED.load(std::sync::atomic::Ordering::Relaxed) {
panic!("The rivets library has already been finalized!");
}
};
fn derive_error(error_message: &str) -> TokenStream {
Diagnostic::spanned(Span::call_site(), Level::Error, error_message).emit();
quote! {}.into()
}

fn failure(callback: proc_macro2::TokenStream, error_message: &str) -> TokenStream {
Diagnostic::spanned(Span::call_site(), Level::Error, error_message).emit();
callback.into()
/// Asserts if the rivets library has already been finalized.
/// This check is preformed in the following proc macros:
/// - `detour`
/// - `import`
/// - `finalize`
///
/// Note that this check is only preformed in release builds.
fn check_finalized() {
// this check causes issues with rust-analyer. disable during debug builds.
#[cfg(not(debug_assertions))]
if IS_FINALIZED.load(std::sync::atomic::Ordering::Relaxed) {
panic!("The rivets library has already been finalized!");
}
}

fn determine_calling_convention(input: &ItemFn, unmangled_name: &str) -> Result<Abi> {
Expand All @@ -47,7 +44,7 @@ fn determine_calling_convention(input: &ItemFn, unmangled_name: &str) -> Result<
let abi = regex!(r" __[a-zA-Z]+ ").find(unmangled_name);
let abi = match abi {
Some(abi) => abi.as_str(),
None => bail!("Failed to automatically determine calling convention for {unmangled_name}. Try specifying the calling convention manually. Example: extern \"C\" fn() {}", "{}"),
None => bail!("Failed to automatically determine calling convention for {unmangled_name}. Try specifying the calling convention manually. Example: extern \"C-unwind\" fn() {}", "{}"),
};
let abi = &abi[1..abi.len() - 1];
if let Some(calling_convention) = rivets_shared::get_calling_convention(abi) {
Expand Down Expand Up @@ -103,7 +100,7 @@ fn determine_calling_convention(input: &ItemFn, unmangled_name: &str) -> Result<
/// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions.
#[proc_macro_attribute]
pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
check_finalized!();
check_finalized();

let mangled_name = attr.to_string();
let unmangled_name =
Expand Down Expand Up @@ -134,8 +131,8 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
let arg_names = quote! { #( #arg_names ),* };

let calling_convention = match determine_calling_convention(&input, &unmangled_name) {
Ok(calling_convention) => calling_convention,
Err(e) => return failure(quote! { #input }, &e.to_string()),
Ok(calling_convention) => Some(calling_convention),
Err(e) => return derive_error(&e.to_string()),
};
input.sig.abi = None;
let callback = quote! { #input };
Expand All @@ -152,12 +149,11 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
static Detour : #cpp_function_header;
}

#[doc = #unmangled_name]
unsafe fn back(#inputs) #return_type {
Detour.call(#arg_names)
}

#[doc = #unmangled_name]
#[allow(unused_variables)]
#callback

pub unsafe fn hook(address: u64) -> Result<(), rivets::retour::Error> {
Expand All @@ -173,8 +169,6 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
.expect("Failed to lock mangled names")
.push((mangled_name.clone(), name.to_string()));

Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit();

result.into()
}

Expand Down Expand Up @@ -216,7 +210,7 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream {
/// Calling any imported function repersents calling into the C++ compiled codebase and thus is inherently unsafe.
#[proc_macro_attribute]
pub fn import(attr: TokenStream, item: TokenStream) -> TokenStream {
check_finalized!();
check_finalized();

let mangled_name = attr.to_string();
let unmangled_name =
Expand All @@ -226,7 +220,7 @@ pub fn import(attr: TokenStream, item: TokenStream) -> TokenStream {

let calling_convention = match determine_calling_convention(&input, &unmangled_name) {
Ok(calling_convention) => Some(calling_convention),
Err(e) => return failure(quote! { #input }, &e.to_string()),
Err(e) => return derive_error(&e.to_string()),
};

let arg_types = input.sig.inputs.iter().map(|arg| match arg {
Expand All @@ -252,12 +246,11 @@ pub fn import(attr: TokenStream, item: TokenStream) -> TokenStream {
.expect("Failed to lock cpp imports")
.push((mangled_name.clone(), name.to_string()));

Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit();

quote! {
#[doc = #unmangled_name]
#[allow(non_upper_case_globals)]
#[allow(missing_docs)]
#attr #vis static mut #name: rivets::UnsafeSummonedFunction<#function_type> = rivets::UnsafeSummonedFunction::Uninitialized;
#attr #vis static mut #name: rivets::UnsafeImportedFunction<#function_type> = rivets::UnsafeImportedFunction::Uninitialized;
}.into()
}

Expand Down Expand Up @@ -295,7 +288,7 @@ fn get_imports() -> Vec<proc_macro2::TokenStream> {
let function = unsafe {
std::mem::transmute(address) // todo: rust documentation recommends casting this to a raw function pointer. address as *const _
};
unsafe { #rust_name = rivets::UnsafeSummonedFunction::Function(function); }
unsafe { #rust_name = rivets::UnsafeImportedFunction::Function(function); }
}
})
.collect()
Expand All @@ -306,7 +299,7 @@ fn get_imports() -> Vec<proc_macro2::TokenStream> {
/// It will finalize the rivets library and inject all of the detours.
#[proc_macro]
pub fn finalize(_: TokenStream) -> TokenStream {
check_finalized!();
check_finalized();
IS_FINALIZED.store(true, std::sync::atomic::Ordering::Relaxed);

let hooks = get_hooks();
Expand Down Expand Up @@ -347,12 +340,12 @@ struct DefineOpts {
pub fn define_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input);
let Ok(DefineOpts { kind }) = DefineOpts::from_derive_input(&input) else {
return derive_error!("Missing #[kind(?)] attribute!");
return derive_error("Missing #[kind(?)] attribute!");
};
let DeriveInput { ident, data, .. } = input;

let syn::Data::Enum(data) = data else {
return derive_error!("FactorioDefine can only be used on enums!");
return derive_error("FactorioDefine can only be used on enums!");
};

let count = data.variants.len();
Expand All @@ -372,7 +365,7 @@ pub fn define_derive(input: TokenStream) -> TokenStream {
}

let Expr::Lit(syn::PatLit { lit, .. }) = &nv.value else {
return derive_error!("All variants must have a #[value(?)] attribute!");
return derive_error("All variants must have a #[value(?)] attribute!");
};

value = Some(lit.clone());
Expand All @@ -381,7 +374,7 @@ pub fn define_derive(input: TokenStream) -> TokenStream {
}

let Some(value) = value else {
return derive_error!("All variants must have a #[value = ?] attribute!");
return derive_error("All variants must have a #[value = ?] attribute!");
};

let Variant { ident, .. } = variant;
Expand Down
6 changes: 3 additions & 3 deletions rivets-shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,23 @@ impl SymbolCache {
/// Represents a function that has been imported from a C++ compiled DLL.
/// Invariant: If the function is not initialized, it is UB to dereference it.
/// The rivets::finalize!() macro should be used to ensure that the function is initialized.
pub enum UnsafeSummonedFunction<T>
pub enum UnsafeImportedFunction<T>
where
T: 'static + Sized,
{
Function(T),
Uninitialized,
}

impl<T> Deref for UnsafeSummonedFunction<T> {
impl<T> Deref for UnsafeImportedFunction<T> {
type Target = T;

#[inline]
#[track_caller]
fn deref(&self) -> &Self::Target {
match self {
Self::Function(x) => x,
Self::Uninitialized => unsafe { std::hint::unreachable_unchecked() },
Self::Uninitialized => unreachable!("Attempted to dereference an uninitialized imported function pointer! This is a bug in rivets core, please report it."),
}
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
#![feature(c_size_t)]

pub mod defines;
pub mod lua;
Expand Down
Loading