Merge #231: types: use slab allocator for type bounds
c22dbaa types: drop BoundMutex and instead use references into the type context slab (Andrew Poelstra)
2e24c49 types: pull unify and bind into inference context (Andrew Poelstra)
a26cf7a types: remove set and get methods from BoundRef (Andrew Poelstra)
33f58fa types: introduce BoundRef type, use in place of Arc<BoundMutex> in union-bound (Andrew Poelstra)
021316c types: abstract pointer type in union-bound algorithm (Andrew Poelstra)
eccc332 types: add &Context to recursive type constructors (Andrew Poelstra)
65b35a9 types: add &Context to type constructors (Andrew Poelstra)
8e08900 types: make `bind` and `unify` go through Context (Andrew Poelstra)
8eeab8f types: introduce inference context object, thread it through the API (Andrew Poelstra)
9b0790e cmr: pull Constructible impl on Cmr into an impl on an auxiliary type (Andrew Poelstra)

Pull request description:

  Our existing type inference engine assumes a "global" set of type bounds, which has two bad consequences. One is that if you are constructing multiple programs, there is no way to "firewall" their type bounds, so nothing prevents you from accidentally combining type variables from one program with type variables from another; you just need to be careful. The other is that if you construct infinitely sized types, which are represented as reference cycles, the existing inference engine will leak memory.
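
  To make the leak concrete, here is a minimal sketch of how independently `Arc`-allocated bounds can form a cycle that is never freed. The `Bound` struct below is a hypothetical stand-in, not the crate's actual representation:

  ```rust
  use std::sync::{Arc, Mutex};

  // Hypothetical stand-in for an independently allocated type bound
  // that may refer to other bounds.
  struct Bound {
      children: Mutex<Vec<Arc<Bound>>>,
  }

  fn main() {
      let a = Arc::new(Bound { children: Mutex::new(Vec::new()) });
      let b = Arc::new(Bound { children: Mutex::new(vec![Arc::clone(&a)]) });
      // Tie the knot, as unifying a variable with a type that contains
      // itself would: `a` and `b` now keep each other alive.
      a.children.lock().unwrap().push(Arc::clone(&b));
      // Both strong counts remain nonzero after `a` and `b` go out of
      // scope, so neither allocation is ever freed.
  }
  ```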

  To fix this, we need to stop allocating type bounds as untethered `Arc`s and instead use a slab allocator, which allows all bounds to be dropped at once, regardless of their circularity. This should also improve memory locality and speed, and reduce the total amount of locking and potential mutex contention when type inference is done in a multithreaded context.
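
  A minimal sketch of the slab idea, with illustrative names only (the real `types::Context` and `BoundRef` differ in detail):

  ```rust
  // All bounds live in one arena owned by the inference context;
  // "pointers" between bounds are indices into that arena.
  #[allow(dead_code)]
  struct BoundRef(usize);

  #[allow(dead_code)]
  enum BoundData {
      Free(String),
      Sum(BoundRef, BoundRef),
      Product(BoundRef, BoundRef),
  }

  struct Context {
      slab: Vec<BoundData>,
  }

  impl Context {
      fn new() -> Self {
          Context { slab: Vec::new() }
      }

      fn alloc(&mut self, data: BoundData) -> BoundRef {
          self.slab.push(data);
          BoundRef(self.slab.len() - 1)
      }
  }

  fn main() {
      let mut ctx = Context::new();
      let a = ctx.alloc(BoundData::Free("A".into()));
      let b = ctx.alloc(BoundData::Free("B".into()));
      let _pair = ctx.alloc(BoundData::Product(a, b));
      // Even if bounds refer to one another in a cycle, dropping `ctx`
      // frees the whole slab at once -- no per-bound refcounting, no leak.
      // Indices are only meaningful relative to their own context, which
      // is what provides the "firewall" between programs.
  }
  ```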

  This is a 2000-line diff, but the vast majority of the changes are "API-only": moving types around and threading new parameters through dozens or hundreds of call sites (the resulting calling convention is sketched below). I did my best to break everything up into commits such that the big-diff commits don't do much of anything and the real changes happen in the small-diff ones, to make review easier.
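
  For a sense of what call sites look like after the change, here is a small example; the paths mirror the in-crate imports in the diffs below and may differ slightly for external users:

  ```rust
  use simplicity::jet::Core;
  use simplicity::node::CoreConstructible;
  use simplicity::{types, ConstructNode};
  use std::sync::Arc;

  fn main() {
      // Each program being constructed gets its own inference context...
      let ctx = types::Context::new();
      // ...which every node constructor now takes explicitly
      // (previously: Arc::<ConstructNode<Core>>::unit()).
      let _program = Arc::<ConstructNode<Core>>::unit(&ctx);
  }
  ```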

  By itself, this PR does **not** fix the issue of reference cycles, because it includes an `Arc<Context>` inside the recursive `Type` type itself. Future PRs will:

  * Take a single mutex lock during top-level `bind` and `unify` calls, so that each call happens atomically, including all recursive calls.
  * Add another intermediate type under `Type` which eliminates the `Arc<Context>` and its potential for circular references. Along the way, make the `Bound` type private, since it is not really used outside of the types module anyway.
  * Do "checkpointing" during type inference that makes node construction atomic; this is #226 which is **not** fixed by this PR.
  * (Maybe) move node allocation into the type inference context so that nodes can be slab-allocated as well, which will address #229 "for free" without us figuring out a non-recursive `Drop` impl for `Arc<Node<N>>`.

ACKs for top commit:
  uncomputable:
    ACK c22dbaa

Tree-SHA512: 0fd2fdd9fe3634068d67279d517573df04fafa60b70e432f59417880982ad22e893822362973f946f1deb6279080aec1efdd942dfd8adad81bbddc7d55077336
uncomputable committed Jul 4, 2024
2 parents b55b8e7 + c22dbaa commit 8c17b94
Showing 25 changed files with 1,144 additions and 595 deletions.
6 changes: 3 additions & 3 deletions jets-bench/benches/elements/data_structures.rs

@@ -4,9 +4,8 @@
 use bitcoin::secp256k1;
 use elements::Txid;
 use rand::{thread_rng, RngCore};
-pub use simplicity::hashes::sha256;
 use simplicity::{
-    bitcoin, elements, hashes::Hash, hex::FromHex, types::Type, BitIter, Error, Value,
+    bitcoin, elements, hashes::Hash, hex::FromHex, types::{self, Type}, BitIter, Error, Value,
 };
 use std::sync::Arc;
@@ -57,7 +56,8 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Arc<Value>, Error> {
     assert!(n < 16);
     assert!(v.len() < (1 << (n + 1)));
     let mut iter = BitIter::new(v.iter().copied());
-    let types = Type::powers_of_two(n); // size n + 1
+    let ctx = types::Context::new();
+    let types = Type::powers_of_two(&ctx, n); // size n + 1
     let mut res = None;
     while n > 0 {
         let v = if v.len() >= (1 << (n + 1)) {

6 changes: 3 additions & 3 deletions jets-bench/benches/elements/main.rs

@@ -93,8 +93,8 @@ impl ElementsBenchEnvType {
 }

 fn jet_arrow(jet: Elements) -> (Arc<types::Final>, Arc<types::Final>) {
-    let src_ty = jet.source_ty().to_type().final_data().unwrap();
-    let tgt_ty = jet.target_ty().to_type().final_data().unwrap();
+    let src_ty = jet.source_ty().to_final();
+    let tgt_ty = jet.target_ty().to_final();
     (src_ty, tgt_ty)
 }
@@ -302,7 +302,7 @@ fn bench(c: &mut Criterion) {
     let keypair = bitcoin::key::Keypair::new(&secp_ctx, &mut thread_rng());
     let xpk = bitcoin::key::XOnlyPublicKey::from_keypair(&keypair);

-    let msg = bitcoin::secp256k1::Message::from_slice(&rand::random::<[u8; 32]>()).unwrap();
+    let msg = bitcoin::secp256k1::Message::from_digest_slice(&rand::random::<[u8; 32]>()).unwrap();
     let sig = secp_ctx.sign_schnorr(&msg, &keypair);
     let xpk_value = Value::u256_from_slice(&xpk.0.serialize());
     let sig_value = Value::u512_from_slice(sig.as_ref());

3 changes: 2 additions & 1 deletion src/bit_encoding/bitwriter.rs

@@ -117,12 +117,13 @@ mod tests {
     use super::*;
     use crate::jet::Core;
     use crate::node::CoreConstructible;
+    use crate::types;
     use crate::ConstructNode;
     use std::sync::Arc;

     #[test]
     fn vec() {
-        let program = Arc::<ConstructNode<Core>>::unit();
+        let program = Arc::<ConstructNode<Core>>::unit(&types::Context::new());
         let _ = write_to_vec(|w| program.encode(w));
     }

14 changes: 8 additions & 6 deletions src/bit_encoding/decode.rs

@@ -12,6 +12,7 @@ use crate::node::{
     ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible, NoWitness,
     WitnessConstructible,
 };
+use crate::types;
 use crate::{BitIter, FailEntropy, Value};
 use std::collections::HashSet;
 use std::sync::Arc;
@@ -178,6 +179,7 @@ pub fn decode_expression<I: Iterator<Item = u8>, J: Jet>(
         return Err(Error::TooManyNodes(len));
     }

+    let inference_context = types::Context::new();
     let mut nodes = Vec::with_capacity(len);
     for _ in 0..len {
         let new_node = decode_node(bits, nodes.len())?;
@@ -195,8 +197,8 @@
         }

         let new = match nodes[data.node.0] {
-            DecodeNode::Unit => Node(ArcNode::unit()),
-            DecodeNode::Iden => Node(ArcNode::iden()),
+            DecodeNode::Unit => Node(ArcNode::unit(&inference_context)),
+            DecodeNode::Iden => Node(ArcNode::iden(&inference_context)),
             DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)),
             DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)),
             DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)),
@@ -222,16 +224,16 @@
                     converted[i].get()?,
                     &Some(Arc::clone(converted[j].get()?)),
                 )?),
-            DecodeNode::Witness => Node(ArcNode::witness(NoWitness)),
-            DecodeNode::Fail(entropy) => Node(ArcNode::fail(entropy)),
+            DecodeNode::Witness => Node(ArcNode::witness(&inference_context, NoWitness)),
+            DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)),
             DecodeNode::Hidden(cmr) => {
                 if !hidden_set.insert(cmr) {
                     return Err(Error::SharingNotMaximal);
                 }
                 Hidden(cmr)
             }
-            DecodeNode::Jet(j) => Node(ArcNode::jet(j)),
-            DecodeNode::Word(ref w) => Node(ArcNode::const_word(Arc::clone(w))),
+            DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)),
+            DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))),
         };
         converted.push(new);
     }

74 changes: 47 additions & 27 deletions src/human_encoding/named_node.rs

@@ -9,7 +9,7 @@ use crate::node::{
     self, Commit, CommitData, CommitNode, Converter, Inner, NoDisconnect, NoWitness, Node, Witness,
     WitnessData,
 };
-use crate::node::{Construct, ConstructData, Constructible};
+use crate::node::{Construct, ConstructData, Constructible as _, CoreConstructible as _};
 use crate::types;
 use crate::types::arrow::{Arrow, FinalArrow};
 use crate::{encode, Value, WitnessNode};
@@ -116,6 +116,7 @@ impl<J: Jet> NamedCommitNode<J> {
         struct Populator<'a, J: Jet> {
             witness_map: &'a HashMap<Arc<str>, Arc<Value>>,
             disconnect_map: &'a HashMap<Arc<str>, Arc<NamedCommitNode<J>>>,
+            inference_context: types::Context,
             phantom: PhantomData<J>,
         }
@@ -153,17 +154,16 @@ impl<J: Jet> NamedCommitNode<J> {
                 // Like witness nodes (see above), disconnect nodes may be pruned later.
                 // The finalization will detect missing branches and throw an error.
                 let maybe_commit = self.disconnect_map.get(hole_name);
-                // FIXME: Recursive call of to_witness_node
-                // We cannot introduce a stack
-                // because we are implementing methods of the trait Converter
-                // which are used Marker::convert().
+                // FIXME: recursive call to convert
+                // We cannot introduce a stack because we are implementing the Converter
+                // trait and do not have access to the actual algorithm used for conversion
+                // in order to save its state.
                 //
                 // OTOH, if a user writes a program with so many disconnected expressions
                 // that there is a stack overflow, it's his own fault :)
-                // This would fail in a fuzz test.
-                let witness = maybe_commit.map(|commit| {
-                    commit.to_witness_node(self.witness_map, self.disconnect_map)
-                });
+                // This may fail in a fuzz test.
+                let witness = maybe_commit
+                    .map(|commit| commit.convert::<InternalSharing, _, _>(self).unwrap());
                 Ok(witness)
             }
         }
@@ -181,13 +181,15 @@
                 let inner = inner
                     .map(|node| node.cached_data())
                     .map_witness(|maybe_value| maybe_value.clone());
-                Ok(WitnessData::from_inner(inner).expect("types are already finalized"))
+                Ok(WitnessData::from_inner(&self.inference_context, inner)
+                    .expect("types are already finalized"))
             }
         }

         self.convert::<InternalSharing, _, _>(&mut Populator {
             witness_map: witness,
             disconnect_map: disconnect,
+            inference_context: types::Context::new(),
             phantom: PhantomData,
         })
         .unwrap()
@@ -245,13 +247,15 @@ pub struct NamedConstructData<J> {
 impl<J: Jet> NamedConstructNode<J> {
     /// Construct a named construct node from parts.
     pub fn new(
+        inference_context: &types::Context,
         name: Arc<str>,
         position: Position,
         user_source_types: Arc<[types::Type]>,
         user_target_types: Arc<[types::Type]>,
         inner: node::Inner<Arc<Self>, J, Arc<Self>, WitnessOrHole>,
     ) -> Result<Self, types::Error> {
         let construct_data = ConstructData::from_inner(
+            inference_context,
             inner
                 .as_ref()
                 .map(|data| &data.cached_data().internal)
@@ -295,6 +299,11 @@ impl<J: Jet> NamedConstructNode<J> {
         self.cached_data().internal.arrow()
     }

+    /// Accessor for the node's type inference context.
+    pub fn inference_context(&self) -> &types::Context {
+        self.cached_data().internal.inference_context()
+    }
+
     /// Finalizes the types of the underlying [`crate::ConstructNode`].
     pub fn finalize_types_main(&self) -> Result<Arc<NamedCommitNode<J>>, ErrorSet> {
         self.finalize_types_inner(true)
@@ -386,17 +395,23 @@ impl<J: Jet> NamedConstructNode<J> {
                 .map_disconnect(|_| &NoDisconnect)
                 .copy_witness();

+            let ctx = data.node.inference_context();
+
             if !self.for_main {
                 // For non-`main` fragments, treat the ascriptions as normative, and apply them
                 // before finalizing the type.
                 let arrow = data.node.arrow();
                 for ty in data.node.cached_data().user_source_types.as_ref() {
-                    if let Err(e) = arrow.source.unify(ty, "binding source type annotation") {
+                    if let Err(e) =
+                        ctx.unify(&arrow.source, ty, "binding source type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
                 for ty in data.node.cached_data().user_target_types.as_ref() {
-                    if let Err(e) = arrow.target.unify(ty, "binding target type annotation") {
+                    if let Err(e) =
+                        ctx.unify(&arrow.target, ty, "binding target type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
@@ -413,15 +428,19 @@
             if self.for_main {
                 // For `main`, only apply type ascriptions *after* inference has completely
                 // determined the type.
-                let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source));
+                let source_ty =
+                    types::Type::complete(ctx, Arc::clone(&commit_data.arrow().source));
                 for ty in data.node.cached_data().user_source_types.as_ref() {
-                    if let Err(e) = source_ty.unify(ty, "binding source type annotation") {
+                    if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
-                let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target));
+                let target_ty =
+                    types::Type::complete(ctx, Arc::clone(&commit_data.arrow().target));
                 for ty in data.node.cached_data().user_target_types.as_ref() {
-                    if let Err(e) = target_ty.unify(ty, "binding target type annotation") {
+                    if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation")
+                    {
                         self.errors.add(data.node.position(), e);
                     }
                 }
@@ -442,22 +461,23 @@
         };

        if for_main {
-            let unit_ty = types::Type::unit();
+            let ctx = self.inference_context();
+            let unit_ty = types::Type::unit(ctx);
             if self.cached_data().user_source_types.is_empty() {
-                if let Err(e) = self
-                    .arrow()
-                    .source
-                    .unify(&unit_ty, "setting root source to unit")
-                {
+                if let Err(e) = ctx.unify(
+                    &self.arrow().source,
+                    &unit_ty,
+                    "setting root source to unit",
+                ) {
                     finalizer.errors.add(self.position(), e);
                 }
             }
             if self.cached_data().user_target_types.is_empty() {
-                if let Err(e) = self
-                    .arrow()
-                    .target
-                    .unify(&unit_ty, "setting root source to unit")
-                {
+                if let Err(e) = ctx.unify(
+                    &self.arrow().target,
+                    &unit_ty,
+                    "setting root target to unit",
+                ) {
                     finalizer.errors.add(self.position(), e);
                 }
             }

25 changes: 18 additions & 7 deletions src/human_encoding/parse/ast.rs

@@ -82,14 +82,25 @@ pub enum Type {

 impl Type {
     /// Convert to a Simplicity type
-    pub fn reify(self) -> types::Type {
+    pub fn reify(self, ctx: &types::Context) -> types::Type {
         match self {
-            Type::Name(s) => types::Type::free(s),
-            Type::One => types::Type::unit(),
-            Type::Two => types::Type::sum(types::Type::unit(), types::Type::unit()),
-            Type::Product(left, right) => types::Type::product(left.reify(), right.reify()),
-            Type::Sum(left, right) => types::Type::sum(left.reify(), right.reify()),
-            Type::TwoTwoN(n) => types::Type::two_two_n(n as usize), // cast OK as we are only using tiny numbers
+            Type::Name(s) => types::Type::free(ctx, s),
+            Type::One => types::Type::unit(ctx),
+            Type::Two => {
+                let unit_ty = types::Type::unit(ctx);
+                types::Type::sum(ctx, unit_ty.shallow_clone(), unit_ty)
+            }
+            Type::Product(left, right) => {
+                let left = left.reify(ctx);
+                let right = right.reify(ctx);
+                types::Type::product(ctx, left, right)
+            }
+            Type::Sum(left, right) => {
+                let left = left.reify(ctx);
+                let right = right.reify(ctx);
+                types::Type::sum(ctx, left, right)
+            }
+            Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers
         }
     }
 }

8 changes: 5 additions & 3 deletions src/human_encoding/parse/mod.rs

@@ -7,7 +7,7 @@ mod ast;
 use crate::dag::{Dag, DagLike, InternalSharing};
 use crate::jet::Jet;
 use crate::node;
-use crate::types::Type;
+use crate::types::{self, Type};
 use std::collections::HashMap;
 use std::mem;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -181,6 +181,7 @@ pub fn parse<J: Jet + 'static>(
     program: &str,
 ) -> Result<HashMap<Arc<str>, Arc<NamedCommitNode<J>>>, ErrorSet> {
     let mut errors = ErrorSet::new();
+    let inference_context = types::Context::new();
     // **
     // Step 1: Read expressions into HashMap, checking for dupes and illegal names.
     // **
@@ -205,10 +206,10 @@
             }
         }
         if let Some(ty) = line.arrow.0 {
-            entry.add_source_type(ty.reify());
+            entry.add_source_type(ty.reify(&inference_context));
         }
         if let Some(ty) = line.arrow.1 {
-            entry.add_target_type(ty.reify());
+            entry.add_target_type(ty.reify(&inference_context));
         }
     }
@@ -485,6 +486,7 @@
             .unwrap_or_else(|| Arc::from(namer.assign_name(inner.as_ref()).as_str()));

         let node = NamedConstructNode::new(
+            &inference_context,
             Arc::clone(&name),
             data.node.position,
             Arc::clone(&data.node.user_source_types),

3 changes: 2 additions & 1 deletion src/jet/elements/tests.rs

@@ -5,6 +5,7 @@ use std::sync::Arc;
 use crate::jet::elements::{ElementsEnv, ElementsUtxo};
 use crate::jet::Elements;
 use crate::node::{ConstructNode, JetConstructible};
+use crate::types;
 use crate::{BitMachine, Cmr, Value};
 use elements::secp256k1_zkp::Tweak;
 use elements::taproot::ControlBlock;
@@ -99,7 +100,7 @@ fn test_ffi_env() {
         BlockHash::all_zeros(),
     );

-    let prog = Arc::<ConstructNode<_>>::jet(Elements::LockTime);
+    let prog = Arc::<ConstructNode<_>>::jet(&types::Context::new(), Elements::LockTime);
     assert_eq!(
         BitMachine::test_exec(prog, &env).expect("executing"),
         Value::u32(100),