diff --git a/Cargo.lock b/Cargo.lock index 8796c7c..5d504cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2537,6 +2537,8 @@ dependencies = [ "iroh", "iroh-base", "iroh-blobs", + "p3-field", + "p3-goldilocks", "proptest", "rand 0.9.1", "rayon", @@ -3112,6 +3114,122 @@ dependencies = [ "sha2", ] +[[package]] +name = "p3-dft" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3b2764a3982d22d62aa933c8de6f9d71d8a474c9110b69e675dea1887bdeffc" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "tracing", +] + +[[package]] +name = "p3-field" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc13a73509fe09c67b339951ca8d4cc6e61c9bf08c130dbc90dda52452918cc2" +dependencies = [ + "itertools 0.14.0", + "num-bigint", + "p3-maybe-rayon", + "p3-util", + "paste", + "rand 0.9.1", + "serde", + "tracing", +] + +[[package]] +name = "p3-goldilocks" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "552849f6309ffde34af0d31aa9a2d0a549cb0ec138d9792bfbf4a17800742362" +dependencies = [ + "num-bigint", + "p3-dft", + "p3-field", + "p3-mds", + "p3-poseidon2", + "p3-symmetric", + "p3-util", + "paste", + "rand 0.9.1", + "serde", +] + +[[package]] +name = "p3-matrix" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8e1e9f69c2fe15768b3ceb2915edb88c47398aa22c485d8163deab2a47fe194" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-maybe-rayon", + "p3-util", + "rand 0.9.1", + "serde", + "tracing", + "transpose", +] + +[[package]] +name = "p3-maybe-rayon" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33f765046b763d046728b3246b690f81dfa7ccd7523b7a1582c74f616fbce6a0" + +[[package]] +name = "p3-mds" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6c90541c6056712daf2ee69ec328db8b5605ae8dbafe60226c8eb75eaac0e1f9" +dependencies = [ + "p3-dft", + "p3-field", + "p3-symmetric", + "p3-util", + "rand 0.9.1", +] + +[[package]] +name = "p3-poseidon2" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88e9f053f120a78ad27e9c1991a0ea547777328ca24025c42364d6ee2667d59a" +dependencies = [ + "p3-field", + "p3-mds", + "p3-symmetric", + "p3-util", + "rand 0.9.1", +] + +[[package]] +name = "p3-symmetric" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d5db8f05a26d706dfd8aaf7aa4272ca4f3e7a075db897ec7108f24fad78759" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "serde", +] + +[[package]] +name = "p3-util" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dfee67245d9ce78a15176728da2280032f0a84b5819a39a953e7ec03cfd9bd7" +dependencies = [ + "serde", +] + [[package]] name = "p384" version = "0.13.1" diff --git a/Cargo.toml b/Cargo.toml index 56dc2b3..a51be32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,8 @@ binius_utils = { git = "https://github.com/IrreducibleOSS/binius.git", rev = "23 bumpalo = "3" groestl_crypto = { package = "groestl", version = "0.10.1" } proptest = "1" +p3-goldilocks = "0.3.0" +p3-field = "0.3.0" rayon = "1" rand = "0.9.1" rustc-hash = "2" diff --git a/Ix/Aiur/Simple.lean b/Ix/Aiur/Simple.lean index 1d4e10a..3a623ef 100644 --- a/Ix/Aiur/Simple.lean +++ b/Ix/Aiur/Simple.lean @@ -44,7 +44,6 @@ def checkAndSimplifyToplevel (toplevel : Toplevel) : Except CheckError TypedDecl | .constructor d c => pure $ typedDecls.insert name (.constructor d c) | .dataType d => pure $ typedDecls.insert name (.dataType d) | .function f => do - let _ ← (checkFunction f) (getFunctionContext f decls) let f ← (checkFunction f) (getFunctionContext f decls) pure $ typedDecls.insert name (.function f) | .gadget g => pure $ typedDecls.insert name (.gadget g) diff --git 
a/Ix/Aiur2/Bytecode.lean b/Ix/Aiur2/Bytecode.lean new file mode 100644 index 0000000..0fd97c3 --- /dev/null +++ b/Ix/Aiur2/Bytecode.lean @@ -0,0 +1,69 @@ +import Ix.Aiur2.Goldilocks + +namespace Aiur + +namespace Bytecode + +abbrev FunIdx := Nat +abbrev ValIdx := Nat +abbrev SelIdx := Nat + +inductive Op + | const : G → Op + | add : ValIdx → ValIdx → Op + | sub : ValIdx → ValIdx → Op + | mul : ValIdx → ValIdx → Op + | call : FunIdx → Array ValIdx → Op + | store : Array ValIdx → Op + | load : (width : Nat) → ValIdx → Op + deriving Repr + +mutual + inductive Ctrl where + | match : ValIdx → Array (G × Block) → Option Block → Ctrl + | return : SelIdx → Array ValIdx → Ctrl + deriving Inhabited, Repr + + structure Block where + ops : Array Op + ctrl : Ctrl + minSelIncluded: SelIdx + maxSelExcluded: SelIdx + deriving Inhabited, Repr +end + +/-- The circuit layout of a function -/ +structure CircuitLayout where + /-- Bit values that identify which path the computation took. + Exactly one selector must be set. -/ + selectors : Nat + /-- Represent registers that hold temporary values and can be shared by + different circuit paths, since they never overlap. -/ + auxiliaries : Nat + /-- Constraint slots that can be shared in different paths of the circuit. -/ + sharedConstraints : Nat + deriving Inhabited, Repr + +structure Function where + inputSize : Nat + outputSize : Nat + body : Block + circuitLayout: CircuitLayout + deriving Inhabited, Repr + +structure Toplevel where + functions : Array Function + memoryWidths : Array Nat + deriving Repr + +@[extern "c_rs_toplevel_execute_test"] +private opaque Toplevel.executeTest' : + @& Toplevel → @& FunIdx → @& Array G → USize → Array G + +def Toplevel.executeTest (toplevel : Toplevel) (funIdx : FunIdx) (args : Array G) : Array G := + let function := toplevel.functions[funIdx]! 
+ toplevel.executeTest' funIdx args function.outputSize.toUSize + +end Bytecode + +end Aiur diff --git a/Ix/Aiur2/Check.lean b/Ix/Aiur2/Check.lean new file mode 100644 index 0000000..d9b09f5 --- /dev/null +++ b/Ix/Aiur2/Check.lean @@ -0,0 +1,368 @@ +import Ix.Aiur2.Term +import Std.Data.HashSet + +namespace Aiur + +inductive CheckError + | duplicatedDefinition : Global → CheckError + | undefinedGlobal : Global → CheckError + | unboundVariable : Global → CheckError + | notAConstructor : Global → CheckError + | notAValue : Global → CheckError + | notAFunction : Global → CheckError + | notAGadget : Global → CheckError + | cannotApply : Global → CheckError + | notADataType : Global → CheckError + | typeMismatch : Typ → Typ → CheckError + | illegalReturn : CheckError + | nonNumeric : Typ → CheckError + | notAField : Typ → CheckError + | wrongNumArgs : Global → Nat → Nat → CheckError + | notATuple : Typ → CheckError + | indexOoB : Nat → CheckError + | negativeRange : Nat → Nat → CheckError + | rangeOoB : Nat → Nat → CheckError + | incompatiblePattern : Pattern → Typ → CheckError + | differentBindings : List (Local × Typ) → List (Local × Typ) → CheckError + | emptyMatch + | branchMismatch : Typ → Typ → CheckError + | notAPointer : Typ → CheckError + | duplicatedBind : Pattern → CheckError + deriving Repr + +instance : ToString CheckError where + toString e := repr e |>.pretty + +/-- +Constructs a map of declarations from a toplevel, ensuring that there are no duplicate names +for functions and datatypes. 
+-/ +def Toplevel.mkDecls (toplevel : Toplevel) : Except CheckError Decls := do + let map ← toplevel.functions.foldlM (init := default) + fun acc function => addDecl acc Function.name .function function + toplevel.dataTypes.foldlM (init := map) addDataType +where + ensureUnique name (map : IndexMap Global _) := do + if map.containsKey name then throw $ .duplicatedDefinition name + addDecl {α : Type} map (nameFn : α → Global) (wrapper : α → Declaration) (inner : α) := do + ensureUnique (nameFn inner) map + pure $ map.insert (nameFn inner) (wrapper inner) + addDataType map dataType := do + let dataTypeName := dataType.name + ensureUnique dataTypeName map + let map' := map.insert dataTypeName (.dataType dataType) + dataType.constructors.foldlM (init := map') fun acc (constructor : Constructor) => + addDecl acc (dataTypeName.pushNamespace ∘ Constructor.nameHead) (.constructor dataType) constructor + +structure CheckContext where + decls : Decls + varTypes : Std.HashMap Local Typ + returnType : Typ + +abbrev CheckM := ReaderT CheckContext (Except CheckError) + +/-- Retrieves the type of a global reference. -/ +def refLookup (global : Global) : CheckM Typ := do + let ctx ← read + match ctx.decls.getByKey global with + | some (.function function) => + pure $ .function (function.inputs.map Prod.snd) function.output + | some (.constructor dataType constructor) => + let args := constructor.argTypes + unless args.isEmpty do (throw $ .wrongNumArgs global args.length 0) + pure $ .dataType $ dataType.name + | some _ => throw $ .notAValue global + | none => throw $ .unboundVariable global + +/-- Extend context with locally bound variables. -/ +def bindIdents (bindings : List (Local × Typ)) (ctx : CheckContext) : CheckContext := + { ctx with varTypes := ctx.varTypes.insertMany bindings } + +mutual +partial def inferTerm : Term → CheckM TypedTerm + | .var x => do + -- Retrieves and returns the variable type from the context. + let ctx ← read + match ctx.varTypes[x]? 
with + | some t => pure $ .mk (.evaluates t) (.var x) + | none => + let Local.str localName := x | unreachable! + let typ := .evaluates (← refLookup (Global.init localName)) + pure $ .mk typ (.var x) + | .ref x => do + let typ := .evaluates (← refLookup x) + pure $ .mk typ (.ref x) + | .ret term => do + -- Ensures that the type of the returned term matches the expected return type. + -- The term is not allowed to have a (nested) return. + -- Returning the type of the term is not necessary because it's already in the context. + let ctx ← read + let inner ← checkNoEscape term ctx.returnType + pure $ .mk .escapes inner + | .data data => do + let (typ, inner) ← inferData data + pure $ .mk (.evaluates typ) inner + | .let pat expr body => do + -- Returns the type of the body, inferred in the context extended with the bound variable type. + -- The bound variable is ensured not to escape. + let (exprTyp, exprInner) ← inferNoEscape expr + let expr' := .mk (.evaluates exprTyp) exprInner + let bindings ← checkPattern pat exprTyp + let body' ← withReader (bindIdents bindings) (inferTerm body) + pure $ .mk body'.typ (.let pat expr' body') + | .match term branches => inferMatch term branches + | .app func@(⟨.str .anonymous unqualifiedFunc⟩) args => do + -- Ensures the function exists in the context and that the arguments, which aren't allowed to + -- escape, match the function's input types. Returns the function's output type. + let ctx ← read + match ctx.varTypes[Local.str unqualifiedFunc]? 
with + | some (.function inputs output) => do + let args ← checkArgsAndInputs func args inputs + pure $ .mk (.evaluates output) (.app func args) + | some _ => throw $ .notAFunction func + | none => match ctx.decls.getByKey func with + | some (.function function) => do + let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) + pure $ .mk (.evaluates function.output) (.app func args) + | some (.constructor dataType constr) => do + let args ← checkArgsAndInputs func args constr.argTypes + pure $ .mk (.evaluates (.dataType dataType.name)) (.app func args) + | _ => throw $ .cannotApply func + | .app func args => do + -- Only checks global map if it is not unqualified + let ctx ← read + match ctx.decls.getByKey func with + | some (.function function) => + let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) + pure $ .mk (.evaluates function.output) (.app func args) + | some (.constructor dataType constr) => + let args ← checkArgsAndInputs func args constr.argTypes + pure $ .mk (.evaluates (.dataType dataType.name)) (.app func args) + | _ => throw $ .cannotApply func + | .add a b => do + let (ctxTyp, a, b) ← checkArith a b + pure $ .mk ctxTyp (.add a b) + | .sub a b => do + let (ctxTyp, a, b) ← checkArith a b + pure $ .mk ctxTyp (.sub a b) + | .mul a b => do + let (ctxTyp, a, b) ← checkArith a b + pure $ .mk ctxTyp (.mul a b) + | .get tup i => do + let (typs, tupInner) ← inferTuple tup + if i < typs.size then + let typ := typs[i]! + let tup := .mk (.evaluates (.tuple typs)) tupInner + pure $ .mk (.evaluates typ) (.get tup i) + else + throw $ .indexOoB i + | .slice tup i j => + -- Retrieves the types of elements in a tuple by a range (checked to be non-negative) and + -- returns them encoded in a `Typ.tuple`. Errors if the index is out of bounds. 
+ if j < i then throw $ .negativeRange i j else do + let (typs, tupInner) ← inferTuple tup + if i < typs.size then + let slice := typs.drop i |>.take (j - i) + let tup := .mk (.evaluates (.tuple typs)) tupInner + pure $ .mk (.evaluates (.tuple slice)) (.slice tup i j) + else + throw $ .rangeOoB i j + | .store term => do + -- Infers the type of the term and returns it, wrapped by a pointer type. + -- The term is not allowed to early return. + let (typ, inner) ← inferNoEscape term + let store := .store (.mk (.evaluates typ) inner) + pure $ .mk (.evaluates (.pointer typ)) store + | .load term => do + -- Ensures that the type of the term is a pointer type and returns the unwrapped type. + -- The term is not allowed to early return. + let (typ, inner) ← inferNoEscape term + match typ with + | .pointer innerTyp => + let load := .load (.mk (.evaluates typ) inner) + pure $ .mk (.evaluates innerTyp) load + | _ => throw $ .notAPointer typ + | .ptrVal term => do + -- Infers the type of the term, which must be a pointer, but returns `.field`, as in a cast. + -- The term is not allowed to early return. + let (typ, inner) ← inferNoEscape term + match typ with + | .pointer _ => + let asU64 := .ptrVal (.mk (.evaluates typ) inner) + pure $ .mk (.evaluates .field) asU64 + | _ => throw $ .notAPointer typ + | .ann typ term => do + let inner ← checkNoEscape term typ + pure $ .mk (.evaluates typ) inner +where + /-- + Ensures that there are as many arguments as expected types and that + the types of the arguments are precisely those expected. 
+ -/ + checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do + let lenArgs := args.length + let lenInputs := inputs.length + unless lenArgs == lenInputs do throw $ .wrongNumArgs func lenArgs lenInputs + let pass := fun (arg, input) => do + let inner ← checkNoEscape arg input + pure $ .mk (.evaluates input) inner + args.zip inputs |>.mapM pass + checkArith a b := do + let (typ, aInner) ← inferNoEscape a + unless (typ == .field) do throw $ .notAField typ + let bInner ← checkNoEscape b typ + let ctxTyp := .evaluates typ + let a := .mk ctxTyp aInner + let b := .mk ctxTyp bInner + pure (ctxTyp, a, b) + +partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := do + let (typ', inner) ← inferNoEscape term + unless typ == typ' do throw $ .typeMismatch typ typ' + pure inner + +partial def inferNoEscape (term : Term) : CheckM (Typ × TypedTermInner) := do + let typedTerm ← inferTerm term + match typedTerm.typ with + | .escapes => throw .illegalReturn + | .evaluates type => pure (type, typedTerm.inner) + +partial def inferData : Data → CheckM (Typ × TypedTermInner) + | .field g => pure (.field, .data (.field g)) + | .tuple terms => do + let typsAndInners ← terms.mapM inferNoEscape + let typs := typsAndInners.map Prod.fst + let terms := typsAndInners.map fun (typ, inner) => TypedTerm.mk (.evaluates typ) inner + pure (.tuple typs, .data (.tuple terms)) + +/-- Infers the type of a 'match' expression and ensures its patterns and branches are valid. 
-/ +partial def inferMatch (term : Term) (branches : List (Pattern × Term)) : CheckM TypedTerm := do + if branches.isEmpty then throw .emptyMatch + let (termTyp, termInner) ← inferNoEscape term + let term := .mk (.evaluates termTyp) termInner + let init := ([], .escapes) + let (branches, typ) ← branches.foldrM (init := init) (checkBranch termTyp) + pure $ .mk typ (.match term branches) +where + checkBranch patTyp branchData acc := do + let (pat, branch) := branchData + let (typedBranches, currentTyp) := acc + let bindings ← checkPattern pat patTyp + withReader (bindIdents bindings) (match currentTyp with + | .escapes => do + let typedBranch ← inferTerm branch + pure (typedBranches.cons (pat, typedBranch), typedBranch.typ) + | .evaluates matchTyp => do + -- Some branch didn't escape, so if this branch doesn't escape it must have the same type + -- as the previous non-escaping branch. + let typedBranch ← inferTerm branch + let typedBranches := typedBranches.cons (pat, typedBranch) + match typedBranch.typ with + | .escapes => pure (typedBranches, currentTyp) + | .evaluates branchTyp => + -- This branch doesn't escape so its type must match the type of the previous non-escaping branch. + unless (matchTyp == branchTyp) do throw $ .branchMismatch matchTyp branchTyp + pure (typedBranches, currentTyp)) + +/-- Checks that a pattern matches a given type and collects its bindings. 
-/ +partial def checkPattern (pat : Pattern) (typ : Typ) : CheckM $ List (Local × Typ) := do + let binds ← aux pat typ + let locals := binds.map Prod.fst + unless (locals == locals.eraseDups) do throw $ .duplicatedBind pat + pure binds +where + aux pat typ := match (pat, typ) with + | (.var var, _) => pure [(var, typ)] + | (.wildcard, _) + | (.field _, .field) => pure [] + | (.tuple pats, .tuple typs) => do + unless pats.size == typs.size do throw $ .incompatiblePattern pat typ + pats.zip typs |>.foldlM (init := []) fun acc (pat, typ) => acc.append <$> aux pat typ + | (.ref funcName [], typ@(.function ..)) => do + let ctx ← read + let some (.function function) := ctx.decls.getByKey funcName | throw $ .incompatiblePattern pat typ + let typ' := .function (function.inputs.map Prod.snd) function.output + unless typ == typ' do throw $ .typeMismatch typ typ' + pure [] + | (.ref constrRef pats, .dataType dataTypeRef) => do + let ctx ← read + let some (.dataType dataType) := ctx.decls.getByKey dataTypeRef | unreachable! 
+ let some (.constructor dataType' constr) := ctx.decls.getByKey constrRef | throw $ .notAConstructor constrRef + unless dataType == dataType' do throw $ .incompatiblePattern pat typ + let typs := constr.argTypes + let lenPats := pats.length + let lenTyps := typs.length + unless lenPats == lenTyps do throw $ .wrongNumArgs constrRef lenPats lenTyps + pats.zip typs |>.foldlM (init := []) fun acc (pat, typ) => acc.append <$> aux pat typ + | (.or pat pat', _) => do + let bind ← aux pat typ + let bind' ← aux pat' typ + if bind != bind' then throw $ .differentBindings bind bind' else pure bind + | _ => throw $ .incompatiblePattern pat typ + +-- partial def inferNumber (term : Term) : CheckM (Typ × TypedTermInner) := do +-- let (typ, inner) ← inferNoEscape term +-- match typ with +-- | .primitive .u1 +-- | .primitive .u8 +-- | .primitive .u16 +-- | .primitive .u32 +-- | .primitive .u64 => pure (typ, inner) +-- | _ => throw $ .nonNumeric typ + +partial def inferTuple (term : Term) : CheckM (Array Typ × TypedTermInner) := do + let (typ, inner) ← inferNoEscape term + match typ with + | .tuple typs => pure (typs, inner) + | _ => throw $ .notATuple typ +end + +def getFunctionContext (function : Function) (decls : Decls) : CheckContext := + { + decls, + varTypes := .ofList function.inputs + returnType := function.output + } + +/-- +Ensures that all declarations are wellformed by checking that every datatype reference +points to an actual datatype in the toplevel. + +Note: it's assumed that all constructor declarations are properly extracted from the +original datatypes. 
+-/ +partial def wellFormedDecls (decls : Decls) : Except CheckError Unit := do + let mut visited := default + for (_, decl) in decls.pairs do + match EStateM.run (wellFormedDecl decl) visited with + | .error e _ => throw e + | .ok () visited' => visited := visited' +where + wellFormedDecl : Declaration → EStateM CheckError (Std.HashSet Global) Unit + | .dataType dataType => do + let map ← get + if !map.contains dataType.name then + set $ map.insert dataType.name + dataType.constructors.flatMap (·.argTypes) |>.forM wellFormedType + | .function function => do + wellFormedType function.output + function.inputs.forM fun (_, typ) => wellFormedType typ + -- No need to check constructors because they come from datatype declarations. + | .constructor .. => pure () + wellFormedType : Typ → EStateM CheckError (Std.HashSet Global) Unit + | .tuple typs => typs.forM wellFormedType + | .pointer pointerTyp => wellFormedType pointerTyp + | .dataType dataTypeRef => match decls.getByKey dataTypeRef with + | some (.dataType _) => pure () + | some _ => throw $ .notADataType dataTypeRef + | none => throw $ .undefinedGlobal dataTypeRef + | _ => pure () + +/-- Checks a function to ensure its body's type matches its declared output type. 
-/ +def checkFunction (function : Function) : CheckM TypedFunction := do + let body ← inferTerm function.body + if let .evaluates typ := body.typ then + unless typ == function.output do throw $ .typeMismatch typ function.output + pure ⟨function.name, function.inputs, function.output, body⟩ + +end Aiur diff --git a/Ix/Aiur2/Compile.lean b/Ix/Aiur2/Compile.lean new file mode 100644 index 0000000..8706cff --- /dev/null +++ b/Ix/Aiur2/Compile.lean @@ -0,0 +1,463 @@ +import Std.Data.HashMap +import Ix.Aiur2.Term +import Ix.Aiur2.Bytecode + +namespace Aiur + +namespace Bytecode + +structure SharedData where + auxiliaries : Nat + constraints : Nat + +def SharedData.maximals (a b : SharedData) : SharedData := { + auxiliaries := a.auxiliaries.max b.auxiliaries + constraints := a.constraints.max b.constraints +} + +structure LayoutMState where + circuitLayout : CircuitLayout + memWidths : Array Nat + deriving Inhabited + +abbrev LayoutM := StateM LayoutMState + +@[inline] def bumpSelectors : LayoutM Unit := + modify fun stt => { stt with + circuitLayout := { stt.circuitLayout with selectors := stt.circuitLayout.selectors + 1 } } + +@[inline] def bumpSharedConstraints (n : Nat) : LayoutM Unit := + modify fun stt => { stt with + circuitLayout := { stt.circuitLayout with sharedConstraints := stt.circuitLayout.sharedConstraints + n } } + +@[inline] def bumpAuxiliaries (n : Nat) : LayoutM Unit := + modify fun stt => { stt with + circuitLayout := { stt.circuitLayout with auxiliaries := stt.circuitLayout.auxiliaries + n } } + +@[inline] def addMemWidth (memWidth : Nat) : LayoutM Unit := + modify fun stt => + let memWidths := if stt.memWidths.contains memWidth + then stt.memWidths + else stt.memWidths.push memWidth + { stt with memWidths } + +def getSharedData : LayoutM SharedData := do + let stt ← get + pure { + auxiliaries := stt.circuitLayout.auxiliaries + constraints := stt.circuitLayout.sharedConstraints + } + +def setSharedData (sharedData : SharedData) : LayoutM Unit := + 
modify fun stt => { stt with circuitLayout := { stt.circuitLayout with + auxiliaries := sharedData.auxiliaries + sharedConstraints := sharedData.constraints } } + +def opLayout : Bytecode.Op → LayoutM Unit + | .store values => addMemWidth values.size + | .load width _ => addMemWidth width + | _ => pure () -- TODO + +partial def blockLayout (block : Bytecode.Block) : LayoutM Unit := do -- TODO + block.ops.forM opLayout + match block.ctrl with + | .match _ branches defaultBranch => + let initSharedData ← getSharedData + let mut maximalSharedData := initSharedData + for (_, block) in branches do + setSharedData initSharedData + -- This auxiliary is for proving inequality + bumpAuxiliaries 1 + blockLayout block + let blockSharedData ← getSharedData + maximalSharedData := maximalSharedData.maximals blockSharedData + if let some defaultBlock := defaultBranch then + setSharedData initSharedData + blockLayout defaultBlock + let defaultBlockSharedData ← getSharedData + maximalSharedData := maximalSharedData.maximals defaultBlockSharedData + setSharedData maximalSharedData + | .return _ out => + -- One selector per return + bumpSelectors + -- Each output must equal its respective return variable, + -- thus one constraint per return variable + bumpSharedConstraints out.size + +end Bytecode + +structure DataTypeLayout where + size: Nat + deriving Inhabited + +structure FunctionLayout where + index: Nat + inputSize : Nat + outputSize : Nat + offsets: Array Nat + deriving Inhabited + +structure ConstructorLayout where + index: Nat + size: Nat + offsets: Array Nat + deriving Inhabited + +structure GadgetLayout where + index : Nat + outputSize : Nat + deriving Inhabited + +inductive Layout + | dataType : DataTypeLayout → Layout + | function : FunctionLayout → Layout + | constructor : ConstructorLayout → Layout + | gadget : GadgetLayout → Layout + deriving Inhabited + +abbrev LayoutMap := Std.HashMap Global Layout + +def TypedDecls.layoutMap (decls : TypedDecls) : LayoutMap := + 
let pass := fun (layoutMap, funcIdx, gadgetIdx) (_, v) => match v with + | .dataType dataType => + let dataTypeSize := dataType.size decls + let layoutMap := layoutMap.insert dataType.name (.dataType { size := dataTypeSize }) + let pass := fun (acc, index) constructor => + let offsets := constructor.argTypes.foldl (init := #[0]) + (fun offsets typ => offsets.push (offsets[offsets.size - 1]! + typ.size decls)) + let decl := .constructor { size := dataTypeSize, offsets, index } + let name := dataType.name.pushNamespace constructor.nameHead + (acc.insert name decl, index + 1) + let (layoutMap, _) := dataType.constructors.foldl (init := (layoutMap, 0)) pass + (layoutMap, funcIdx, gadgetIdx) + | .function function => + let inputSize := function.inputs.foldl (init := 0) (fun acc (_, typ) => acc + typ.size decls) + let outputSize := function.output.size decls + let offsets := function.inputs.foldl (init := #[0]) + (fun offsets (_, typ) => offsets.push (offsets[offsets.size - 1]! + typ.size decls)) + let layoutMap := layoutMap.insert function.name (.function { index := funcIdx, inputSize, outputSize, offsets }) + (layoutMap, funcIdx + 1, gadgetIdx) + | .constructor .. => (layoutMap, funcIdx, gadgetIdx) + let (layoutMap, _) := decls.foldl pass ({}, 0, 0) + layoutMap + +structure CompiledFunction where + inputSize : Nat + outputSize : Nat + body : Bytecode.Block + +def typSize (layoutMap : LayoutMap) : Typ → Nat +| Typ.field .. => 1 +| Typ.pointer .. => 1 +| Typ.function .. => 1 +| Typ.tuple ts => ts.foldl (init := 0) (fun acc t => acc + typSize layoutMap t) +| Typ.dataType g => match layoutMap[g]? with + | some (.dataType layout) => layout.size + | _ => unreachable! 
+ +structure CompilerState where + index : Bytecode.ValIdx + ops : Array Bytecode.Op + returnIdent : Bytecode.SelIdx + deriving Inhabited + +def pushOp (op : Bytecode.Op) (size : Nat := 1) : StateM CompilerState (Array Bytecode.ValIdx) := + modifyGet (fun s => + let index := s.index + let ops := s.ops + (Array.range' index size, { s with index := index + size, ops := ops.push op})) + +def extractOps : StateM CompilerState (Array Bytecode.Op) := + modifyGet (fun s => (s.ops, {s with ops := #[]})) + +partial def toIndex + (layoutMap : LayoutMap) + (bindings : Std.HashMap Local (Array Bytecode.ValIdx)) + (term : TypedTerm) : StateM CompilerState (Array Bytecode.ValIdx) := + let typ := term.typ.unwrap + match term.inner with + | .ret .. => panic! "should not happen after typechecking" + | .match .. => panic! "non-tail `match` not yet implemented" + | .var name => do + pure (bindings[name]!) + | .ref name => match layoutMap[name]! with + | .function layout => do + pushOp (.const (.ofNat' gSize layout.index)) + | .constructor layout => do + let size := layout.size + let paddingOp := .const (.ofNat' gSize layout.index) + let index ← pushOp paddingOp + if index.size < size then + let padding := (← pushOp paddingOp)[0]! + pure $ index ++ Array.mkArray (size - index.size) padding + else + pure index + | _ => panic! "should not happen after typechecking" + | .data (.field g) => pushOp (Bytecode.Op.const g) + | .data (.tuple args) => + -- TODO use `buildArgs` + let append arg acc := do + pure (acc.append (← toIndex layoutMap bindings arg)) + args.foldrM (init := #[]) append + | .let (.var var) val bod => do + let val ← toIndex layoutMap bindings val + toIndex layoutMap (bindings.insert var val) bod + | .let .. => panic! "should not happen after simplifying" + -- | .xor a b => do + -- let a ← toIndex layoutMap bindings a + -- assert! (a.size == 1) + -- let b ← toIndex layoutMap bindings b + -- assert! 
(b.size == 1) + -- pushOp (.xor (a.get' 0) (b.get' 0)) + | .add a b => do + let a ← toIndex layoutMap bindings a + assert! (a.size == 1) + let b ← toIndex layoutMap bindings b + assert! (b.size == 1) + pushOp (.add (a[0]!) (b[0]!)) + | .sub a b => do + let a ← toIndex layoutMap bindings a + assert! (a.size == 1) + let b ← toIndex layoutMap bindings b + assert! (b.size == 1) + pushOp (.sub (a[0]!) (b[0]!)) + | .mul a b => do + let a ← toIndex layoutMap bindings a + assert! (a.size == 1) + let b ← toIndex layoutMap bindings b + assert! (b.size == 1) + pushOp (.mul (a[0]!) (b[0]!)) + -- | .and a b => do + -- let a ← toIndex layoutMap bindings a + -- assert! (a.size == 1) + -- let b ← toIndex layoutMap bindings b + -- assert! (b.size == 1) + -- pushOp (.and (a.get' 0) (b.get' 0)) + | .app name@(⟨.str .anonymous unqualifiedName⟩) args => + match bindings.get? (.str unqualifiedName) with + | some _ => panic! "dynamic calls not yet implemented" + | none => match layoutMap[name]! with + | .function layout => do + let args ← buildArgs args + pushOp (Bytecode.Op.call layout.index args) layout.outputSize + | .constructor layout => do + let size := layout.size + let index ← pushOp (.const (.ofNat' gSize layout.index)) + let index ← buildArgs args index + if index.size < size then + let padding := (← pushOp (.const (.ofNat' gSize 0)))[0]! + pure $ index ++ Array.mkArray (size - index.size) padding + else + pure index + | _ => panic! "should not happen after typechecking" + | .app name args => match layoutMap[name]! with + | .function layout => do + let args ← buildArgs args + pushOp (Bytecode.Op.call layout.index args) layout.outputSize + | .constructor layout => do + let size := layout.size + let index ← pushOp (.const (.ofNat' gSize layout.index)) + let index ← buildArgs args index + if index.size < size then + let padding := (← pushOp (.const (.ofNat' gSize 0)))[0]! + pure $ index ++ Array.mkArray (size - index.size) padding + else + pure index + | _ => panic! 
"should not happen after typechecking" + -- | .preimg name@(⟨.str .anonymous unqualifiedName⟩) out => + -- match bindings.get? (.str unqualifiedName) with + -- | some _ => panic! "dynamic preimage not yet implemented" + -- | none => match layoutMap.get' name with + -- | .function layout => do + -- let out ← toIndex layoutMap bindings out + -- pushOp (Bytecode.Op.preimg layout.index out layout.inputSize) layout.inputSize + -- | _ => panic! "should not happen after typechecking" + -- | .preimg name out => match layoutMap.get' name with + -- | .function layout => do + -- let out ← toIndex layoutMap bindings out + -- pushOp (Bytecode.Op.preimg layout.index out layout.inputSize) layout.inputSize + -- | _ => panic! "should not happen after typechecking" + -- | .ffi name args => match layoutMap.get' name with + -- | .gadget layout => do + -- let args ← buildArgs args + -- pushOp (Bytecode.Op.ffi layout.index args layout.outputSize) layout.outputSize + -- | _ => panic! "should not happen after typechecking" + | .get arg i => do + let typs := (match arg.typ with + | .evaluates (.tuple typs) => typs + | _ => panic! "should not happen after typechecking") + let offset := (typs.extract 0 i).foldl (init := 0) + fun acc typ => typSize layoutMap typ + acc + let arg ← toIndex layoutMap bindings arg + let length := typSize layoutMap typ + pure $ arg.extract offset (offset + length) + | .slice arg i j => do + let typs := (match arg.typ with + | .evaluates (.tuple typs) => typs + | _ => panic! 
"should not happen after typechecking") + let offset := (typs.extract 0 i).foldl (init := 0) + fun acc typ => typSize layoutMap typ + acc + let length := (typs.extract i j).foldl (init := 0) + fun acc typ => typSize layoutMap typ + acc + let arg ← toIndex layoutMap bindings arg + pure $ arg.extract offset (offset + length) + | .store arg => do + let args ← toIndex layoutMap bindings arg + pushOp (Bytecode.Op.store args) + | .load ptr => do + let size := match ptr.typ.unwrap with + | .pointer typ => typSize layoutMap typ + | _ => unreachable! + let ptr ← toIndex layoutMap bindings ptr + assert! (ptr.size == 1) + pushOp (Bytecode.Op.load size ptr[0]!) size + | .ptrVal ptr => toIndex layoutMap bindings ptr + -- | .trace str expr => do + -- let arr ← toIndex layoutMap bindings expr + -- let op := .trace str arr + -- modify (fun state => { state with ops := state.ops.push op}) + -- pure arr + where + buildArgs (args : List TypedTerm) (init := #[]) := + let append acc arg := do + pure (acc.append (← toIndex layoutMap bindings arg)) + args.foldlM (init := init) append + +mutual + +partial def TypedTerm.compile + (term : TypedTerm) + (returnTyp : Typ) + (layoutMap : LayoutMap) + (bindings : Std.HashMap Local (Array Bytecode.ValIdx)) +: StateM CompilerState Bytecode.Block := match term.inner with + | .let (.var var) val bod => do + let val ← toIndex layoutMap bindings val + bod.compile returnTyp layoutMap (bindings.insert var val) + | .let .. => panic! "should not happen after simplifying" + | .match term cases => + match term.typ.unwrapOr returnTyp with + -- Also do this for tuple-like (one constructor only) datatypes + | .tuple typs => match cases with + | [(.tuple vars, branch)] => do + let bindArgs bindings pats typs idxs := + let n := pats.size + let init := (bindings, 0) + let (bindings, _) := (List.range n).foldl (init := init) fun (bindings, offset) i => + match pats[i]! with + | .var var => + let len := typSize layoutMap typs[i]! 
+ let new_offset := offset + len + (bindings.insert var (idxs.extract offset new_offset), new_offset) + | _ => panic! "should not happen after simplification" + bindings + let idxs ← toIndex layoutMap bindings term + let bindings := bindArgs bindings vars typs idxs + branch.compile returnTyp layoutMap bindings + | _ => unreachable! + | _ => do + let idxs ← toIndex layoutMap bindings term + let ops ← extractOps + let minSelIncluded := (← get).returnIdent + let (cases, default) ← cases.foldlM (init := default) + (addCase layoutMap bindings returnTyp idxs) + let maxSelExcluded := (← get).returnIdent + let ctrl := .match (idxs[0]!) cases default + pure { ops, ctrl, minSelIncluded, maxSelExcluded } + | .ret term => do + let idxs ← toIndex layoutMap bindings term + let state ← get + let state := { state with returnIdent := state.returnIdent + 1 } + set state + let ops := state.ops + let id := state.returnIdent + pure { ops, ctrl := .return (id - 1) idxs, minSelIncluded := id - 1, maxSelExcluded := id } + | _ => do + let idxs ← toIndex layoutMap bindings term + let state ← get + let state := { state with returnIdent := state.returnIdent + 1 } + set state + let ops := state.ops + let id := state.returnIdent + pure { ops, ctrl := .return (id - 1) idxs, minSelIncluded := id - 1, maxSelExcluded := id } + +partial def addCase + (layoutMap : LayoutMap) + (bindings : Std.HashMap Local (Array Bytecode.ValIdx)) + (returnTyp : Typ) + (idxs : Array Bytecode.ValIdx) +: (Array (G × Bytecode.Block) × Option Bytecode.Block) → + (Pattern × TypedTerm) → + StateM CompilerState (Array (G × Bytecode.Block) × Option Bytecode.Block) := fun (cases, default) (pat, term) => + -- If simplified, only one default will exist, and it will appear at the end of the match + assert! 
default.isNone + match pat with + | .field g => do + let initState ← get + let term ← term.compile returnTyp layoutMap bindings + set { initState with returnIdent := (← get).returnIdent } + let cases' := cases.push (g, term) + pure (cases', default) + | .ref global pats => do + let layout := layoutMap[global]! + let (index, offsets) := match layout with + | .function layout => (layout.index, layout.offsets) + | .constructor layout => (layout.index, layout.offsets) + | .dataType _ + | .gadget _ => panic! "impossible after typechecking" + let bindArgs bindings pats offsets idxs := + let n := pats.length + let bindings := (List.range n).foldl (init := bindings) fun bindings i => + let pat := (pats[i]!) + -- the `+ 1` is to account for the tag + let offset := (offsets[i]!) + 1 + let next_offset := (offsets[(i + 1)]!) + 1 + match pat with + | .var var => + bindings.insert var (idxs.extract offset next_offset) + | _ => panic! "should not happen after simplification" + bindings + let bindings := bindArgs bindings pats offsets idxs + let initState ← get + let term ← term.compile returnTyp layoutMap bindings + set { initState with returnIdent := (← get).returnIdent } + let cases' := cases.push (.ofNat' gSize index, term) + pure (cases', default) + | .wildcard => do + let initState ← get + let term ← term.compile returnTyp layoutMap bindings + set { initState with returnIdent := (← get).returnIdent } + pure (cases, .some term) + | _ => unreachable! + +end + +def TypedFunction.compile (layoutMap : LayoutMap) (f : TypedFunction) : + CompiledFunction × Bytecode.LayoutMState := + let (inputSize, outputSize) := match layoutMap[f.name]? with + | some (.function layout) => (layout.inputSize, layout.outputSize) + | _ => panic! 
s!"`{f.name}` should be a function" + let (index, bindings) := f.inputs.foldl (init := (0, default)) + fun (index, bindings) (arg, typ) => + let len := typSize layoutMap typ + let indices := Array.range' index len + (index + len, bindings.insert arg indices) + let state := { index, returnIdent := 0, ops := #[] } + let body := f.body.compile f.output layoutMap bindings |>.run' state + let (_, layoutMState) := Bytecode.blockLayout body |>.run default + ({ inputSize, outputSize, body }, layoutMState) + +def TypedDecls.compile (decls : TypedDecls) : Bytecode.Toplevel := + let layout := decls.layoutMap + let (functions, memWidths) := decls.foldl (init := (#[], #[])) + fun acc@(functions, memWidths) (_, decl) => match decl with + | .function function => + let (compiledFunction, layoutMState) := function.compile layout + let function := { compiledFunction with circuitLayout := layoutMState.circuitLayout } + let memWidths := layoutMState.memWidths.foldl (init := memWidths) fun memWidths memWidth => + if memWidths.contains memWidth then memWidths else memWidths.push memWidth + (functions.push function, memWidths) + | _ => acc + ⟨functions, memWidths.qsort⟩ + +end Aiur diff --git a/Ix/Aiur2/Goldilocks.lean b/Ix/Aiur2/Goldilocks.lean new file mode 100644 index 0000000..6565181 --- /dev/null +++ b/Ix/Aiur2/Goldilocks.lean @@ -0,0 +1,6 @@ +namespace Aiur + +abbrev gSize := 0xFFFFFFFF00000001 +abbrev G := Fin gSize + +end Aiur diff --git a/Ix/Aiur2/Match.lean b/Ix/Aiur2/Match.lean new file mode 100644 index 0000000..f3f7d82 --- /dev/null +++ b/Ix/Aiur2/Match.lean @@ -0,0 +1,218 @@ +import Ix.Aiur2.Term +import Ix.SmallMap + +namespace Aiur + +abbrev TermId := Nat +abbrev UniqTerm := TermId × Term + +inductive SPattern + | field : G → SPattern + | ref : Global → Array Local → SPattern + | tuple : Array Local → SPattern + deriving BEq, Hashable, Inhabited + +structure Clause where + pat : SPattern + guards : Array (Pattern × UniqTerm) + body : UniqTerm + deriving Inhabited + 
+structure ExtTerm where  -- A branch body together with the bindings collected while matching.
+  renames : Array (Local × Term)  -- `(x, t)` pairs; `decisionToTerm` wraps the body in `let x = t`
+  value : Term  -- the branch body itself
+  deriving Inhabited
+-- A match row: the clauses still to test, the branch body, and the id counter saved at creation.
+structure Row where
+  clauses : Array Clause
+  body : ExtTerm
+  uniqId : TermId  -- `processRow` restores the counter to this value before expanding guards
+  deriving Inhabited
+-- Diagnostics gathered by `compileRows` while building the decision tree.
+structure Diagnostics where
+  missing : Bool  -- set to `true` when some path is left with no row at all
+  reachable : List Term  -- bodies of the rows that end up as `.success` leaves
+-- Compiler state: the fresh-id counter, the declarations in scope and the diagnostics.
+structure Compiler where
+  uniqId : TermId
+  decls : Decls
+  diagnostics : Diagnostics
+-- State monad in which the match compiler runs.
+abbrev CompilerM := StateM Compiler
+-- Overwrites the fresh-id counter with `id`.
+def setId (id : TermId) : CompilerM Unit :=
+  modify fun stt => { stt with uniqId := id }
+-- Returns the current counter value and then increments the counter.
+def newId : CompilerM TermId := do
+  let varId ← Compiler.uniqId <$> get
+  modify fun stt => { stt with uniqId := stt.uniqId + 1 }
+  pure varId
+-- Expands pattern/term pairs into rows, producing one row per combination of `or` alternatives.
+def dnfProd (branches: List $ Pattern × UniqTerm) (body : ExtTerm) : CompilerM (Array Row) := do
+  let initId ← Compiler.uniqId <$> get
+  let rec aux := fun renames clauses branches body => match branches with
+    | [] => do  -- every pattern was consumed: emit a single row
+      let id ← Compiler.uniqId <$> get
+      let row := ⟨clauses, { body with renames := body.renames ++ renames }, id⟩
+      setId initId  -- rewind the counter to its value on entry to `dnfProd`
+      pure #[row]
+    | (.or patL patR, term) :: rest => do  -- each alternative is expanded on its own
+      let rowsL ← aux renames clauses ((patL, term) :: rest) body
+      let rowsR ← aux renames clauses ((patR, term) :: rest) body
+      pure $ rowsL ++ rowsR
+    | (.wildcard, _) :: rest => aux renames clauses rest body  -- no test, no binding
+    | (.var var, (_, term)) :: rest => aux (renames.push (var, term)) clauses rest body
+    | (.field g, term) :: rest => aux renames (clauses.push ⟨.field g, #[], term⟩) rest body
+    | (.tuple args, term) :: rest => do  -- sub-patterns become guards on fresh variables
+      let (vars, guards) ← flattenArgs args
+      let clause := ⟨.tuple vars, guards, term⟩
+      aux renames (clauses.push clause) rest body
+    | (.ref global args, term) :: rest => do  -- same as tuples, for constructor arguments
+      let (vars, guards) ← flattenArgs args.toArray
+      let clause := ⟨.ref global vars, guards, term⟩
+      aux renames (clauses.push clause) rest body
+  aux Array.empty Array.empty branches body
+where
+  flattenArgs args := do  -- one fresh `.idx` variable per sub-pattern, paired with that sub-pattern
+    let varIds ← args.mapM fun _ => newId
+    let guards := args.zip varIds |>.map fun (arg, id) => (arg, (id, .var (.idx id)))
+    pure (varIds.map .idx, guards)
+-- Decision tree produced by the match compiler.
+inductive Decision
+  | success : ExtTerm → Decision  -- leaf: run this body
+  | failure  -- leaf: no row applies
+  | switch : Local → List (SPattern × Decision) → Decision → Decision  -- cases, then fallback
+  | let : Local → Term → Decision → Decision  -- bind a variable, then continue
+  deriving Inhabited
+-- Applies `f` to the diagnostics stored in the compiler state.
+def modifyDiagnostics (f : Diagnostics → Diagnostics) : CompilerM Unit :=
+  modify fun stt => { stt with diagnostics := f stt.diagnostics }
+-- Number of distinct patterns of the same kind as the given one.
+def patTypeLength (decls : Decls) : SPattern → Nat
+  | .field _ => gSize
+  | .tuple _ => 1
+  | .ref global _ => typeLookup global |>.constructors.length
+where
+  typeLookup (global : Global) :=  -- looks up the datatype named by `global` minus its last limb
+    match global.popNamespace with
+    | some (_, enum) => match decls.getByKey enum with
+      | some (.dataType typ) => typ
+      | _ => unreachable!
+    | none => unreachable!
+-- Collects the patterns that `rows` test against `term` (compared by id), each mapped to `#[]`.
+def extractSPatterns (rows : Array Row) (term : UniqTerm) : SmallMap SPattern (Array Row) :=
+  rows.foldl (init := default) processRow
+where
+  processRow map row := row.clauses.foldl (init := map) processClause
+  processClause map clause :=
+    if term.fst == clause.body.fst then map.insert clause.pat #[] else map
+-- Splits off the first element satisfying `find`; returns `none` when there is no such element.
+def removeFirstInstance (find : α → Bool) (vec : Array α) : Option (α × Array α) :=
+  match h : vec.findIdx? find with
+  | none => none
+  | some i =>
+    let filtered := vec.extract 0 i ++ vec.extract (i + 1) vec.size
+    have := Array.findIdx?_eq_some_iff_findIdx_eq.mp h |>.left  -- presumably `i < vec.size`, for `vec[i]`
+    some (vec[i], filtered)
+-- Removes from `row` the first clause whose scrutinee has the same id as `term`.
+def rowRemoveClause (row : Row) (term : UniqTerm) : Option (Clause × Row) :=
+  let id := term.fst
+  match removeFirstInstance (·.body.fst == id) row.clauses with
+  | some (deletedClause, restClauses) => some (deletedClause, { row with clauses := restClauses })
+  | none => none
+-- Builds a switch on the given term, first binding it to `.idx id` when it is not a variable.
+def switch (cases : List (SPattern × Decision)) (fallback : Decision) : UniqTerm → Decision
+  | (_, .var var) => .switch var cases fallback
+  | (id, term) => let var := .idx id; .let var term (.switch var cases fallback)
+
+mutual
+-- Groups `rows` by the pattern each uses for `term` and compiles every group into a switch case.
+partial def compileSwitch (rows : Array Row) (term : UniqTerm) : CompilerM Decision := do
+  let stt ← get
+  let numCases := patTypeLength stt.decls rows[0]!.clauses[0]!.pat
+  let spatMap := extractSPatterns rows term
+  let size := spatMap.size
+  assert! size <= numCases
+  if size == numCases then  -- as many cases as possible patterns: fallback rows are dropped
+    let (rowMap, _) ← rows.foldlM (init := (spatMap, #[])) processRow
+    let cases ← rowMap.toList.mapM fun (pat, rows) => Prod.mk pat <$> compileRows rows
+    setId stt.uniqId  -- restore the counter saved on entry
+    pure $ switch cases .failure term
+  else  -- otherwise rows without a clause on `term` are compiled as the fallback
+    let (rowMap, fallbackRows) ← rows.foldlM (init := (spatMap, #[])) processRow
+    let cases ← rowMap.toList.mapM fun (pat, rows) => Prod.mk pat <$> compileRows rows
+    let fallback ← compileRows fallbackRows
+    setId stt.uniqId  -- restore the counter saved on entry
+    pure $ switch cases fallback term
+where
+  processRow pair row :=
+    let (rowMap, fallbackRows) := pair
+    match rowRemoveClause row term with
+    | some (clause, row') => do  -- the row tests `term`: expand its guards into new rows
+      setId row.uniqId
+      let newRows ← dnfProd clause.guards.toList row'.body
+      let newRows := newRows.map (fun r => { r with clauses := r.clauses ++ row'.clauses })
+      let updatedMap := rowMap.update clause.pat (· ++ newRows)
+      pure (updatedMap, fallbackRows)
+    | none => pure (rowMap.map (·.push row), fallbackRows.push row)  -- presumably added to every case
+-- Compiles rows in order: the first clause of the first row selects the term to switch on.
+partial def compileRows (rows : Array Row) : CompilerM Decision :=
+  match rows[0]? with
+  | some row => match row.clauses[0]? with
+    | some clause => compileSwitch rows clause.body
+    | none => do  -- the first row has no clause left: it succeeds and is recorded as reachable
+      modifyDiagnostics fun d => { d with reachable := row.body.value :: d.reachable }
+      pure $ .success row.body
+  | none => do  -- no row left: record the gap and fail
+    modifyDiagnostics fun d => { d with missing := true }
+    pure .failure
+
+end
+-- Builds the initial rows from `rules`, all testing `term` under one fresh id, and compiles them.
+def compile (term : Term) (rules : List (Pattern × Term)) : CompilerM (Decision × Diagnostics) := do
+  let id ← newId
+  let rows ← rules.foldlM (init := #[]) fun acc rule => do
+    let newRows ← fromRule id rule
+    pure $ acc ++ newRows
+  let tree ← compileRows rows
+  let diagnostics ← Compiler.diagnostics <$> get
+  pure (tree, diagnostics)
+where
+  fromRule id rule :=
+    let (pat, bod) := rule
+    dnfProd [(pat, (id, term))] ⟨#[], bod⟩
+-- Runs `f` from a fresh state: counter `0`, `missing = false` and an empty `reachable` list.
+def runWithNewCompiler (typs : Decls) (f : CompilerM α) : α :=
+  StateT.run' f ⟨0, typs, ⟨false, []⟩⟩
+-- Entry point: compiles the match of `term` against `rules` starting from a fresh compiler state.
+def runMatchCompiler (typs : Decls) (term : Term) (rules : List (Pattern × Term)) :
+    Decision × Diagnostics :=
+  runWithNewCompiler typs (compile term rules)
+-- Converts a simple pattern back into a `Pattern` whose arguments are all variables.
+def spatternToPattern : SPattern → Pattern
+  | .field g => .field g
+  | .ref global vars => .ref global (vars.map .var).toList
+  | .tuple vars => .tuple (vars.map .var)
+
+mutual
+-- Turns switch cases into match branches, plus a wildcard branch when the fallback yields a term.
+partial def branchesToTerm (branches : List (SPattern × Decision)) (dec : Decision) :
+    List (Pattern × Term) :=
+  let toTerm := fun (spat, tree) =>
+    decisionToTerm tree >>= fun term => some (spatternToPattern spat, term)
+  match decisionToTerm dec with  -- cases whose decision yields `none` are dropped by `filterMap`
+  | some defTerm => branches.filterMap toTerm ++ [(Pattern.wildcard, defTerm)]
+  | none => branches.filterMap toTerm
+-- Converts a decision tree back into a `Term`; a `.failure` leaf yields `none`.
+partial def decisionToTerm : Decision → Option Term
+  | .success ⟨renames, value⟩ =>
+    some $ renames.foldr (init := value) fun (x, y) acc => .let (.var x) y acc
+  | .switch var branches dec => some $ .match (.var var) (branchesToTerm branches dec)
+  | .let var term body => do
+    let body' ← decisionToTerm body
+    some $ .let (.var var) term body'
+  | .failure => none
+
+end
+
+end Aiur diff --git a/Ix/Aiur2/Meta.lean b/Ix/Aiur2/Meta.lean new file mode 100644 index 0000000..deac0b5 --- /dev/null +++ b/Ix/Aiur2/Meta.lean @@ -0,0 +1,235 @@ +import Lean +import Ix.Aiur2.Term + +namespace Aiur + +open Lean Elab Meta + +abbrev ElabStxCat name := TSyntax name → TermElabM Expr + +declare_syntax_cat pattern +syntax ("." noWs)? ident : pattern +syntax "_" : pattern +syntax ident "(" pattern (", " pattern)* ")" : pattern +syntax num : pattern +syntax "(" pattern (", " pattern)* ")" : pattern +syntax pattern "|" pattern : pattern + +def elabListCore (head : α) (tail : Array α) (elabFn : α → TermElabM Expr) + (listEltType : Expr) (isArray := false) : TermElabM Expr := do + let mut elaborated := Array.mkEmpty (tail.size + 1) + elaborated := elaborated.push $ ← elabFn head + for elt in tail do + elaborated := elaborated.push $ ← elabFn elt + if isArray + then mkArrayLit listEltType elaborated.toList + else mkListLit listEltType elaborated.toList + +def elabList (head : α) (tail : Array α) (elabFn : α → TermElabM Expr) + (listEltTypeName : Name) (isArray := false) : TermElabM Expr := + elabListCore head tail elabFn (mkConst listEltTypeName) isArray + +def elabEmptyList (listEltTypeName : Name) : TermElabM Expr := + mkListLit (mkConst listEltTypeName) [] + +def elabG (n : TSyntax `num) : TermElabM Expr := + mkAppM ``Fin.ofNat' #[mkConst ``gSize, mkNatLit n.getNat] + +partial def elabPattern : ElabStxCat `pattern + | `(pattern| $v:ident($p:pattern $[, $ps:pattern]*)) => do + let g ← mkAppM ``Global.mk #[toExpr v.getId] + mkAppM ``Pattern.ref #[g, ← elabList p ps elabPattern ``Pattern] + | `(pattern| .$i:ident) => do + let g ← mkAppM ``Global.mk #[toExpr i.getId] + mkAppM ``Pattern.ref #[g, ← elabEmptyList ``Pattern] + | `(pattern| $i:ident) => match i.getId with + | .str .anonymous name => do + mkAppM ``Pattern.var #[← mkAppM ``Local.str #[toExpr name]] + | name@(.str _ _) => do + let g ← mkAppM ``Global.mk #[toExpr name] + mkAppM ``Pattern.ref #[g, ← 
elabEmptyList ``Pattern] + | _ => throw $ .error i "Illegal pattern name" + | `(pattern| _) => pure $ mkConst ``Pattern.wildcard + | `(pattern| $n:num) => do mkAppM ``Pattern.field #[← elabG n] + | `(pattern| ($p:pattern $[, $ps:pattern]*)) => do + mkAppM ``Pattern.tuple #[← elabList p ps elabPattern ``Pattern true] + | `(pattern| $p₁:pattern | $p₂:pattern) => do + mkAppM ``Pattern.or #[← elabPattern p₁, ← elabPattern p₂] + | stx => throw $ .error stx "Invalid syntax for pattern" + +declare_syntax_cat typ +syntax "G" : typ +syntax "(" typ (", " typ)* ")" : typ +syntax "&" typ : typ +syntax ("." noWs)? ident : typ +syntax "fn" "(" ")" " -> " typ : typ +syntax "fn" "(" typ (", " typ)* ")" " -> " typ : typ + +partial def elabTyp : ElabStxCat `typ + | `(typ| G) => pure $ mkConst ``Typ.field + | `(typ| ($t:typ $[, $ts:typ]*)) => do + mkAppM ``Typ.tuple #[← elabList t ts elabTyp ``Typ true] + | `(typ| &$t:typ) => do + mkAppM ``Typ.pointer #[← elabTyp t] + | `(typ| $[.]?$i:ident) => do + let g ← mkAppM ``Global.mk #[toExpr i.getId] + mkAppM ``Typ.dataType #[g] + | `(typ| fn() -> $t:typ) => do + mkAppM ``Typ.function #[← elabEmptyList ``Typ, ← elabTyp t] + | `(typ| fn($t$[, $ts:typ]*) -> $t':typ) => do + mkAppM ``Typ.function #[← elabList t ts elabTyp ``Typ, ← elabTyp t'] + | stx => throw $ .error stx "Invalid syntax for type" + +declare_syntax_cat trm +syntax ("." noWs)? ident : trm +syntax num : trm +syntax "(" trm (", " trm)* ")" : trm +syntax "return " trm : trm +syntax "let " pattern " = " trm "; " trm : trm +syntax "match " trm " { " (pattern " => " trm ", ")+ " }" : trm +syntax ("." noWs)? ident "(" ")" : trm +syntax ("." noWs)? 
ident "(" trm (", " trm)* ")" : trm +syntax "add" "(" trm ", " trm ")" : trm +syntax "sub" "(" trm ", " trm ")" : trm +syntax "mul" "(" trm ", " trm ")" : trm +syntax "get" "(" trm ", " num ")" : trm +syntax "slice" "(" trm ", " num ", " num ")" : trm +syntax "store" "(" trm ")" : trm +syntax "load" "(" trm ")" : trm +syntax "ptr_val" "(" trm ")" : trm +syntax trm ": " typ : trm + +partial def elabTrm : ElabStxCat `trm + | `(trm| .$i:ident) => do + mkAppM ``Term.ref #[← mkAppM ``Global.mk #[toExpr i.getId]] + | `(trm| $i:ident) => match i.getId with + | .str .anonymous name => do + mkAppM ``Term.var #[← mkAppM ``Local.str #[toExpr name]] + | name@(.str _ _) => do + mkAppM ``Term.ref #[← mkAppM ``Global.mk #[toExpr name]] + | _ => throw $ .error i "Illegal name" + | `(trm| $n:num) => do + let data ← mkAppM ``Data.field #[← elabG n] + mkAppM ``Term.data #[data] + | `(trm| ($t:trm $[, $ts:trm]*)) => do + let data ← mkAppM ``Data.tuple #[← elabList t ts elabTrm ``Term true] + mkAppM ``Term.data #[data] + | `(trm| return $t:trm) => do + mkAppM ``Term.ret #[← elabTrm t] + | `(trm| let $p:pattern = $t:trm; $t':trm) => do + mkAppM ``Term.let #[← elabPattern p, ← elabTrm t, ← elabTrm t'] + | `(trm| match $t:trm {$[$ps:pattern => $ts:trm,]*}) => do + let mut prods := Array.mkEmpty (ps.size + 1) + for (p, t) in ps.zip ts do + prods := prods.push $ ← mkAppM ``Prod.mk #[← elabPattern p, ← elabTrm t] + let prodType ← mkAppM ``Prod #[mkConst ``Pattern, mkConst ``Term] + mkAppM ``Term.match #[← elabTrm t, ← mkListLit prodType prods.toList] + | `(trm| $[.]?$f:ident ()) => do + let g ← mkAppM ``Global.mk #[toExpr f.getId] + mkAppM ``Term.app #[g, ← elabEmptyList ``Term] + | `(trm| $[.]?$f:ident ($a:trm $[, $as:trm]*)) => do + let g ← mkAppM ``Global.mk #[toExpr f.getId] + mkAppM ``Term.app #[g, ← elabList a as elabTrm ``Term] + | `(trm| add($a:trm, $b:trm)) => do + mkAppM ``Term.add #[← elabTrm a, ← elabTrm b] + | `(trm| sub($a:trm, $b:trm)) => do + mkAppM ``Term.sub #[← elabTrm a, 
← elabTrm b] + | `(trm| mul($a:trm, $b:trm)) => do + mkAppM ``Term.mul #[← elabTrm a, ← elabTrm b] + | `(trm| get($a:trm, $i:num)) => do + mkAppM ``Term.get #[← elabTrm a, toExpr i.getNat] + | `(trm| slice($a:trm, $i:num, $j:num)) => do + mkAppM ``Term.slice #[← elabTrm a, toExpr i.getNat, toExpr j.getNat] + | `(trm| store($a:trm)) => do + mkAppM ``Term.store #[← elabTrm a] + | `(trm| load($a:trm)) => do + mkAppM ``Term.load #[← elabTrm a] + | `(trm| ptr_val($a:trm)) => do + mkAppM ``Term.ptrVal #[← elabTrm a] + | `(trm| $v:trm : $t:typ) => do + mkAppM ``Term.ann #[← elabTyp t, ← elabTrm v] + | stx => throw $ .error stx "Invalid syntax for term" + +declare_syntax_cat constructor +syntax ident : constructor +syntax ident "(" typ (", " typ)* ")" : constructor + +def elabConstructor : ElabStxCat `constructor + | `(constructor| $i:ident) => match i.getId with + | .str .anonymous name => do + mkAppM ``Constructor.mk #[toExpr name, ← elabEmptyList ``Typ] + | _ => throw $ .error i "Illegal constructor name" + | `(constructor| $i:ident($t:typ$[, $ts:typ]*)) => match i.getId with + | .str .anonymous name => do + mkAppM ``Constructor.mk #[toExpr name, ← elabList t ts elabTyp ``Typ] + | _ => throw $ .error i "Illegal constructor name" + | stx => throw $ .error stx "Invalid syntax for constructor" + +declare_syntax_cat data_type +syntax "enum " ident : data_type +syntax "enum " ident "{" constructor (", " constructor)* "}" : data_type + +def elabDataType : ElabStxCat `data_type + | `(data_type| enum $n:ident) => do + let g ← mkAppM ``Global.mk #[toExpr n.getId] + mkAppM ``DataType.mk #[g, ← elabEmptyList ``Constructor] + | `(data_type| enum $n:ident {$c:constructor $[, $cs:constructor]*}) => do + let g ← mkAppM ``Global.mk #[toExpr n.getId] + mkAppM ``DataType.mk #[g, ← elabList c cs elabConstructor ``Constructor] + | stx => throw $ .error stx "Invalid syntax for data type" + +declare_syntax_cat bind +syntax ident ": " typ : bind + +def elabBind : ElabStxCat `bind + | `(bind| 
$i:ident: $t:typ) => match i.getId with + | .str .anonymous name => do + mkAppM ``Prod.mk #[← mkAppM ``Local.str #[toExpr name], ← elabTyp t] + | _ => throw $ .error i "Illegal variable name" + | stx => throw $ .error stx "Invalid syntax for binding" + +declare_syntax_cat function +syntax "fn " ident "(" ")" " -> " typ "{" trm "}" : function +syntax "fn " ident "(" bind (", " bind)* ")" " -> " typ "{" trm "}" : function + +def elabFunction : ElabStxCat `function + | `(function| fn $i:ident() -> $ty:typ {$t:trm}) => do + let g ← mkAppM ``Global.mk #[toExpr i.getId] + let bindType ← mkAppM ``Prod #[mkConst ``Local, mkConst ``Typ] + mkAppM ``Function.mk #[g, ← mkListLit bindType [], ← elabTyp ty, ← elabTrm t] + | `(function| fn $i:ident($b:bind $[, $bs:bind]*) -> $ty:typ {$t:trm}) => do + let g ← mkAppM ``Global.mk #[toExpr i.getId] + let bindType ← mkAppM ``Prod #[mkConst ``Local, mkConst ``Typ] + mkAppM ``Function.mk + #[g, ← elabListCore b bs elabBind bindType, ← elabTyp ty, ← elabTrm t] + | stx => throw $ .error stx "Invalid syntax for function" + +declare_syntax_cat declaration +syntax function : declaration +syntax data_type : declaration + +def accElabDeclarations (declarations : (Array Expr × Array Expr)) + (stx : TSyntax `declaration) : TermElabM (Array Expr × Array Expr) := + let (dataTypes, functions) := declarations + match stx with + | `(declaration| $f:function) => do + pure (dataTypes, functions.push $ ← elabFunction f) + | `(declaration| $d:data_type) => do + pure (dataTypes.push $ ← elabDataType d, functions) + | stx => throw $ .error stx "Invalid syntax for declaration" + +declare_syntax_cat toplevel +syntax declaration* : toplevel + +def elabToplevel : ElabStxCat `toplevel + | `(toplevel| $[$ds:declaration]*) => do + let (dataTypes, functions) ← ds.foldlM (init := default) accElabDeclarations + mkAppM ``Toplevel.mk #[ + ← mkListLit (mkConst ``DataType) dataTypes.toList, + ← mkListLit (mkConst ``Function) functions.toList, + ] + | stx => throw $ 
.error stx "Invalid syntax for toplevel" + +elab "⟦" t:toplevel "⟧" : term => elabToplevel t + +end Aiur diff --git a/Ix/Aiur2/Simple.lean b/Ix/Aiur2/Simple.lean new file mode 100644 index 0000000..c0a5cf5 --- /dev/null +++ b/Ix/Aiur2/Simple.lean @@ -0,0 +1,45 @@ +import Ix.Aiur2.Match +import Ix.Aiur2.Check + +namespace Aiur + +/-- This temporary variable can only be used when it does not shadow any other internal. -/ +private abbrev tmpVar := Local.idx 0 + +partial def simplifyTerm (decls : Decls) : Term → Term + | .let var@(.var _) val body => .let var (recr val) (recr body) + -- NOTE: This would not be safe in case Aiur allows side-effects. + -- A sequencing operation would be needed. + | .let .wildcard _ body => recr body + | .let pat val body => + let mtch := .match (.var tmpVar) [(pat, body)] + .let (.var tmpVar) (recr val) (recr mtch) + | .match term branches => + let (tree, _diag) := runMatchCompiler decls term branches + match decisionToTerm tree with + | some term => term + | none => unreachable! + | .ret r => .ret (recr r) + | .app global args => .app global (args.map recr) + | .data (.tuple args) => .data (.tuple (args.map recr)) + | t => t +where + recr := simplifyTerm decls + +def Toplevel.checkAndSimplify (toplevel : Toplevel) : Except CheckError TypedDecls := do + let decls ← toplevel.mkDecls + wellFormedDecls decls + -- TODO: do not duplicate type inference. I.e. 
do simplification on typed expressions + toplevel.functions.forM fun function => do + let _ ← (checkFunction function) (getFunctionContext function decls) + let decls := decls.map fun decl => match decl with + | .function f => .function { f with body := simplifyTerm decls f.body } + | _ => decl + decls.foldlM (init := default) fun typedDecls (name, decl) => match decl with + | .constructor d c => pure $ typedDecls.insert name (.constructor d c) + | .dataType d => pure $ typedDecls.insert name (.dataType d) + | .function f => do + let f ← (checkFunction f) (getFunctionContext f decls) + pure $ typedDecls.insert name (.function f) + +end Aiur diff --git a/Ix/Aiur2/Term.lean b/Ix/Aiur2/Term.lean new file mode 100644 index 0000000..42e83c7 --- /dev/null +++ b/Ix/Aiur2/Term.lean @@ -0,0 +1,204 @@ +import Std.Data.HashSet.Basic +import Ix.Aiur2.Goldilocks +import Ix.IndexMap + +namespace Aiur + +inductive Local + | str : String → Local + | idx : Nat → Local + deriving Repr, BEq, Hashable + +structure Global where + toName : Lean.Name + deriving Repr, BEq, Inhabited + +instance : EquivBEq Global where + symm {_ _} h := by rw [BEq.beq] at h ⊢; exact BEq.symm h + trans {_ _ _} h₁ h₂ := by rw [BEq.beq] at h₁ h₂ ⊢; exact BEq.trans h₁ h₂ + refl {_} := by rw [BEq.beq]; apply BEq.refl + +instance : Hashable Global where + hash a := hash a.toName + +instance : LawfulHashable Global where + hash_eq a b h := LawfulHashable.hash_eq a.toName b.toName h + +instance : ToString Global where + toString g := g.toName.toString + +def Global.init (limb : String) : Global := + ⟨.mkSimple limb⟩ + +def Global.pushNamespace (global : Global) (limb : String) : Global := + ⟨global.toName.mkStr limb⟩ + +def Global.popNamespace (global : Global) : Option (String × Global) := + match global.toName with + | .str tail head => some (head, ⟨tail⟩) + | _ => none + +inductive Pattern + | var : Local → Pattern + | wildcard : Pattern + | ref : Global → List Pattern → Pattern + | field : G → Pattern + | 
tuple : Array Pattern → Pattern + | or : Pattern → Pattern → Pattern + deriving Repr, BEq, Hashable, Inhabited + +inductive Typ where + | field + | tuple : Array Typ → Typ + | pointer : Typ → Typ + | dataType : Global → Typ + | function : List Typ → Typ → Typ + deriving Repr, BEq, Hashable, Inhabited + +mutual + +inductive Term + | var : Local → Term + | ref : Global → Term + | data : Data → Term + | ret : Term → Term + | let : Pattern → Term → Term → Term + | match : Term → List (Pattern × Term) → Term + | app : Global → List Term → Term + | add : Term → Term → Term + | sub : Term → Term → Term + | mul : Term → Term → Term + | get : Term → Nat → Term + | slice : Term → Nat → Nat → Term + | store : Term → Term + | load : Term → Term + | ptrVal : Term → Term + | ann : Typ → Term → Term + deriving Repr, BEq, Hashable, Inhabited + +inductive Data + | field : G → Data + | tuple : Array Term → Data + deriving Repr + +end + +inductive ContextualType + | evaluates : Typ → ContextualType + | escapes : ContextualType + deriving Repr, BEq, Inhabited + +def ContextualType.unwrap : ContextualType → Typ +| .escapes => panic! 
"term should not escape" +| .evaluates typ => typ + +def ContextualType.unwrapOr : ContextualType → Typ → Typ +| .escapes => fun typ => typ +| .evaluates typ => fun _ => typ + +mutual +inductive TypedTermInner + | var : Local → TypedTermInner + | ref : Global → TypedTermInner + | data : TypedData → TypedTermInner + | ret : TypedTerm → TypedTermInner + | let : Pattern → TypedTerm → TypedTerm → TypedTermInner + | match : TypedTerm → List (Pattern × TypedTerm) → TypedTermInner + | app : Global → List TypedTerm → TypedTermInner + | add : TypedTerm → TypedTerm → TypedTermInner + | sub : TypedTerm → TypedTerm → TypedTermInner + | mul : TypedTerm → TypedTerm → TypedTermInner + | get : TypedTerm → Nat → TypedTermInner + | slice : TypedTerm → Nat → Nat → TypedTermInner + | store : TypedTerm → TypedTermInner + | load : TypedTerm → TypedTermInner + | ptrVal : TypedTerm → TypedTermInner + deriving Repr, Inhabited + +structure TypedTerm where + typ : ContextualType + inner : TypedTermInner + deriving Repr, Inhabited + +inductive TypedData + | field : G → TypedData + | tuple : Array TypedTerm → TypedData + deriving Repr + +end + +structure Constructor where + nameHead : String + argTypes : List Typ + deriving Repr, BEq, Inhabited + +structure DataType where + name : Global + constructors : List Constructor + deriving Repr, BEq, Inhabited + +structure Function where + name : Global + inputs : List (Local × Typ) + output : Typ + body : Term + deriving Repr + +structure Toplevel where + dataTypes : List DataType + functions : List Function + deriving Repr + +def Toplevel.getFuncIdx (toplevel : Toplevel) (funcName : Lean.Name) : Option Nat := do + toplevel.functions.findIdx? 
fun function => function.name.toName == funcName + +inductive Declaration + | function : Function → Declaration + | dataType : DataType → Declaration + | constructor : DataType → Constructor → Declaration + deriving Repr, Inhabited + +abbrev Decls := IndexMap Global Declaration + +structure TypedFunction where + name : Global + inputs : List (Local × Typ) + output : Typ + body : TypedTerm + deriving Repr + +inductive TypedDeclaration + | function : TypedFunction → TypedDeclaration + | dataType : DataType → TypedDeclaration + | constructor : DataType → Constructor → TypedDeclaration + deriving Repr, Inhabited + +abbrev TypedDecls := IndexMap Global TypedDeclaration + +mutual + +open Std (HashSet) + +partial def Typ.size (decls : TypedDecls) (visited : HashSet Global := {}) : Typ → Nat + | Typ.field .. => 1 + | Typ.pointer .. => 1 + | Typ.function .. => 1 + | Typ.tuple ts => ts.foldl (init := 0) (fun acc t => acc + t.size decls visited) + | Typ.dataType g => match decls.getByKey g with + | some (.dataType data) => data.size decls visited + | _ => panic! "impossible case" + +partial def Constructor.size (decls : TypedDecls) (visited : HashSet Global := {}) (c : Constructor) : Nat := + c.argTypes.foldl (λ acc t => acc + t.size decls visited) 0 + +partial def DataType.size (dt : DataType) (decls : TypedDecls) (visited : HashSet Global := {}) : Nat := + if visited.contains dt.name then + panic! 
s!"cycle detected at datatype `{dt.name}`" + else + let visited := visited.insert dt.name + let ctorSizes := dt.constructors.map (Constructor.size decls visited) + let maxFields := ctorSizes.foldl max 0 + maxFields + 1 +end + +end Aiur diff --git a/Tests/Aiur2.lean b/Tests/Aiur2.lean new file mode 100644 index 0000000..eeb76d8 --- /dev/null +++ b/Tests/Aiur2.lean @@ -0,0 +1,154 @@ +import LSpec +import Ix.Aiur2.Meta +import Ix.Aiur2.Simple +import Ix.Aiur2.Compile + +open LSpec + +def toplevel := ⟦ + fn id(n: G) -> G { + n + } + + fn proj1(a: G, _b: G) -> G { + a + } + + fn sum(x: G, y: G) -> G { + add(x, y) + } + + fn prod(x: G, y: G) -> G { + mul(x, y) + } + + fn store_and_load(x: G) -> G { + load(store(x)) + } + + enum Nat { + Zero, + Succ(&Nat) + } + + fn even(m: Nat) -> G { + match m { + Nat.Zero => 1, + Nat.Succ(m) => odd(load(m)), + } + } + + fn odd(m: Nat) -> G { + match m { + Nat.Zero => 0, + Nat.Succ(m) => even(load(m)), + } + } + + fn is_0_even() -> G { + even(Nat.Zero) + } + + fn is_1_even() -> G { + even(Nat.Succ(store(Nat.Zero))) + } + + fn is_2_even() -> G { + even(Nat.Succ(store(Nat.Succ(store(Nat.Zero))))) + } + + fn is_3_even() -> G { + even(Nat.Succ(store(Nat.Succ(store(Nat.Succ(store(Nat.Zero))))))) + } + + fn is_4_even() -> G { + even(Nat.Succ(store(Nat.Succ(store(Nat.Succ(store(Nat.Succ(store(Nat.Zero))))))))) + } + + fn is_0_odd() -> G { + odd(Nat.Zero) + } + + fn is_1_odd() -> G { + odd(Nat.Succ(store(Nat.Zero))) + } + + fn is_2_odd() -> G { + odd(Nat.Succ(store(Nat.Succ(store(Nat.Zero))))) + } + + fn is_3_odd() -> G { + odd(Nat.Succ(store(Nat.Succ(store(Nat.Succ(store(Nat.Zero))))))) + } + + fn is_4_odd() -> G { + odd(Nat.Succ(store(Nat.Succ(store(Nat.Succ(store(Nat.Succ(store(Nat.Zero))))))))) + } + + fn factorial(n: G) -> G { + match n { + 0 => 1, + _ => mul(n, factorial(sub(n, 1))), + } + } + + fn fibonacci(n: G) -> G { + match n { + 0 => 1, + _ => + let n_minus_1 = sub(n, 1); + match n_minus_1 { + 0 => 1, + _ => + let n_minus_2 = 
sub(n_minus_1, 1); + add(fibonacci(n_minus_1), fibonacci(n_minus_2)), + }, + } + } + + fn slice_and_get(as: (G, G, G, G)) -> G { + get(slice(as, 1, 4), 2) + } +⟧ + +structure TestCase where + functionName : Lean.Name + input : Array Aiur.G + expectedOutput : Array Aiur.G + +def testCases : List TestCase := [ + ⟨`id, #[42], #[42]⟩, + ⟨`proj1, #[42, 64], #[42]⟩, + ⟨`sum, #[3, 5], #[8]⟩, + ⟨`prod, #[3, 5], #[15]⟩, + ⟨`store_and_load, #[42], #[42]⟩, + ⟨`is_0_even, #[], #[1]⟩, + ⟨`is_1_even, #[], #[0]⟩, + ⟨`is_2_even, #[], #[1]⟩, + ⟨`is_3_even, #[], #[0]⟩, + ⟨`is_4_even, #[], #[1]⟩, + ⟨`is_0_odd, #[], #[0]⟩, + ⟨`is_1_odd, #[], #[1]⟩, + ⟨`is_2_odd, #[], #[0]⟩, + ⟨`is_3_odd, #[], #[1]⟩, + ⟨`is_4_odd, #[], #[0]⟩, + ⟨`factorial, #[5], #[120]⟩, + ⟨`fibonacci, #[0], #[1]⟩, + ⟨`fibonacci, #[1], #[1]⟩, + ⟨`fibonacci, #[6], #[13]⟩, + ⟨`slice_and_get, #[1, 2, 3, 4], #[4]⟩, + ] + +def aiurTest : TestSeq := + withExceptOk "Check and simplification works" toplevel.checkAndSimplify fun decls => + let bytecodeToplevel := decls.compile + let runTestCase := fun testCase => + let functionName := testCase.functionName + let funIdx := toplevel.getFuncIdx functionName |>.get! 
+ let output := bytecodeToplevel.executeTest funIdx testCase.input + test s!"Result of {functionName} with arguments {testCase.input} is correct" + (output == testCase.expectedOutput) + testCases.foldl (init := .done) fun tSeq testCase => + tSeq ++ runTestCase testCase + +def Tests.Aiur.suite := [aiurTest] diff --git a/Tests/Main.lean b/Tests/Main.lean index 46bf275..a939bc9 100644 --- a/Tests/Main.lean +++ b/Tests/Main.lean @@ -1,4 +1,4 @@ -import Tests.Aiur +import Tests.Aiur2 import Tests.Archon import Tests.FFIConsistency import Tests.ByteArray diff --git a/c/aiur.c b/c/aiur.c new file mode 100644 index 0000000..00ed809 --- /dev/null +++ b/c/aiur.c @@ -0,0 +1,13 @@ +#include "lean/lean.h" +#include "rust.h" + +extern lean_obj_res c_rs_toplevel_execute_test( + b_lean_obj_arg toplevel, + b_lean_obj_arg fun_idx, + b_lean_obj_arg args, + size_t output_size +) { + lean_obj_res output = lean_alloc_array(output_size, output_size); + rs_toplevel_execute_test(toplevel, fun_idx, args, output); + return output; +} diff --git a/c/rust.h b/c/rust.h index e649e60..1796e52 100644 --- a/c/rust.h +++ b/c/rust.h @@ -2,6 +2,10 @@ #include "lean/lean.h" +void rs_toplevel_execute_test( + b_lean_obj_arg, b_lean_obj_arg, b_lean_obj_arg, lean_obj_arg +); + typedef struct { bool is_ok; void *data; diff --git a/src/aiur2/bytecode.rs b/src/aiur2/bytecode.rs new file mode 100644 index 0000000..e075047 --- /dev/null +++ b/src/aiur2/bytecode.rs @@ -0,0 +1,52 @@ +// TODO: remove +#![allow(dead_code)] + +use indexmap::IndexMap; +use p3_goldilocks::Goldilocks as G; +use rustc_hash::FxBuildHasher; + +pub struct Toplevel { + pub(crate) functions: Vec, + pub(crate) memory_widths: Vec, +} + +pub struct Function { + pub(crate) input_size: usize, + pub(crate) output_size: usize, + pub(crate) body: Block, + pub(crate) circuit_layout: CircuitLayout, +} + +pub struct CircuitLayout { + pub(crate) selectors: usize, + pub(crate) auxiliaries: usize, + pub(crate) shared_constraints: usize, +} + +pub type 
FxIndexMap = IndexMap; + +pub struct Block { + pub(crate) ops: Vec, + pub(crate) ctrl: Ctrl, + pub(crate) min_sel_included: SelIdx, + pub(crate) max_sel_excluded: SelIdx, +} + +pub enum Op { + Const(G), + Add(ValIdx, ValIdx), + Sub(ValIdx, ValIdx), + Mul(ValIdx, ValIdx), + Call(FunIdx, Vec), + Store(Vec), + Load(usize, ValIdx), +} + +pub enum Ctrl { + Match(ValIdx, FxIndexMap, Option>), + Return(SelIdx, Vec), +} + +pub type SelIdx = usize; +pub type ValIdx = usize; +pub type FunIdx = usize; diff --git a/src/aiur2/execute.rs b/src/aiur2/execute.rs new file mode 100644 index 0000000..2f67159 --- /dev/null +++ b/src/aiur2/execute.rs @@ -0,0 +1,180 @@ +use p3_field::{PrimeCharacteristicRing, PrimeField64}; +use p3_goldilocks::Goldilocks as G; + +use crate::aiur2::bytecode::{Ctrl, FunIdx, Function, FxIndexMap, Op, Toplevel}; + +pub struct QueryResult { + pub(crate) output: Vec, + pub(crate) multiplicity: G, +} + +pub type QueryMap = FxIndexMap, QueryResult>; + +pub struct QueryRecord { + pub(crate) function_queries: Vec, + pub(crate) memory_queries: FxIndexMap, +} + +impl QueryRecord { + fn new(toplevel: &Toplevel) -> Self { + let function_queries = toplevel + .functions + .iter() + .map(|_| QueryMap::default()) + .collect(); + let memory_queries = toplevel + .memory_widths + .iter() + .map(|width| (*width, QueryMap::default())) + .collect(); + Self { + function_queries, + memory_queries, + } + } +} + +impl Toplevel { + pub fn execute(&self, fun_idx: FunIdx, args: Vec) -> QueryRecord { + let mut record = QueryRecord::new(self); + let function = &self.functions[fun_idx]; + function.execute(fun_idx, args, self, &mut record); + record + } +} + +enum ExecEntry<'a> { + Op(&'a Op), + Ctrl(&'a Ctrl), +} + +struct CallerState { + fun_idx: FunIdx, + map: Vec, +} + +impl Function { + fn execute( + &self, + mut fun_idx: FunIdx, + mut map: Vec, + toplevel: &Toplevel, + record: &mut QueryRecord, + ) { + let mut exec_entries_stack = vec![]; + let mut callers_states_stack = vec![]; + 
macro_rules! push_block_exec_entries { + ($block:expr) => { + exec_entries_stack.push(ExecEntry::Ctrl(&$block.ctrl)); + exec_entries_stack.extend($block.ops.iter().rev().map(ExecEntry::Op)); + }; + } + push_block_exec_entries!(&self.body); + while let Some(exec_entry) = exec_entries_stack.pop() { + match exec_entry { + ExecEntry::Op(Op::Const(c)) => map.push(*c), + ExecEntry::Op(Op::Add(a, b)) => { + let a = map[*a]; + let b = map[*b]; + map.push(a + b); + } + ExecEntry::Op(Op::Sub(a, b)) => { + let a = map[*a]; + let b = map[*b]; + map.push(a - b); + } + ExecEntry::Op(Op::Mul(a, b)) => { + let a = map[*a]; + let b = map[*b]; + map.push(a * b); + } + ExecEntry::Op(Op::Call(callee_idx, args)) => { + let args = args.iter().map(|i| map[*i]).collect(); + if let Some(result) = record.function_queries[*callee_idx].get_mut(&args) { + result.multiplicity += G::ONE; + map.extend(result.output.clone()); + } else { + let saved_map = std::mem::replace(&mut map, args); + // Save the current caller state. + callers_states_stack.push(CallerState { + fun_idx, + map: saved_map, + }); + // Prepare outer variables to go into the new func scope. 
+ fun_idx = *callee_idx; + let function = &toplevel.functions[fun_idx]; + push_block_exec_entries!(&function.body); + } + } + ExecEntry::Op(Op::Store(values)) => { + let values = values.iter().map(|v| map[*v]).collect::>(); + let width = values.len(); + let memory_queries = record + .memory_queries + .get_mut(&width) + .expect("Invalid memory width"); + if let Some(result) = memory_queries.get_mut(&values) { + result.multiplicity += G::ONE; + map.extend(&result.output); + } else { + let ptr = G::from_usize(memory_queries.len()); + let result = QueryResult { + output: vec![ptr], + multiplicity: G::ONE, + }; + memory_queries.insert(values, result); + map.push(ptr); + } + } + ExecEntry::Op(Op::Load(width, ptr)) => { + let memory_queries = record + .memory_queries + .get_mut(width) + .expect("Invalid memory width"); + let ptr = &map[*ptr]; + let ptr_u64 = ptr.as_canonical_u64(); + let ptr_usize = usize::try_from(ptr_u64).expect("Pointer is too big"); + let (args, result) = memory_queries + .get_index_mut(ptr_usize) + .expect("Unbound pointer"); + result.multiplicity += G::ONE; + map.extend(args); + } + ExecEntry::Ctrl(Ctrl::Match(val_idx, cases, default)) => { + let val = &map[*val_idx]; + if let Some(block) = cases.get(val) { + push_block_exec_entries!(block); + } else { + let default = default.as_ref().expect("No match"); + push_block_exec_entries!(default); + } + } + ExecEntry::Ctrl(Ctrl::Return(_, output)) => { + // Register the query. + let input_size = toplevel.functions[fun_idx].input_size; + let args = map[..input_size].to_vec(); + let output = output.iter().map(|i| map[*i]).collect::>(); + let result = QueryResult { + output: output.clone(), + multiplicity: G::ONE, + }; + record.function_queries[fun_idx].insert(args, result); + if let Some(CallerState { + fun_idx: caller_idx, + map: caller_map, + }) = callers_states_stack.pop() + { + // Recover the state of the caller. 
+ fun_idx = caller_idx; + map = caller_map; + map.extend(output); + } else { + // No outer caller. About to exit. + assert!(exec_entries_stack.is_empty()); + break; + } + } + } + } + } +} diff --git a/src/aiur2/mod.rs b/src/aiur2/mod.rs new file mode 100644 index 0000000..72c93fe --- /dev/null +++ b/src/aiur2/mod.rs @@ -0,0 +1,2 @@ +pub mod bytecode; +pub mod execute; diff --git a/src/lean/array.rs b/src/lean/array.rs index 7d0b4b1..7b71ad4 100644 --- a/src/lean/array.rs +++ b/src/lean/array.rs @@ -28,4 +28,10 @@ impl LeanArrayObject { pub fn to_vec(&self, map_fn: fn(*const c_void) -> T) -> Vec { self.data().iter().map(|ptr| map_fn(*ptr)).collect() } + + pub fn set_data(&mut self, data: &[*const c_void]) { + assert!(self.m_capacity >= data.len()); + self.m_data.copy_from_slice(data); + self.m_size = data.len(); + } } diff --git a/src/lean/ffi/aiur2/mod.rs b/src/lean/ffi/aiur2/mod.rs new file mode 100644 index 0000000..05fb861 --- /dev/null +++ b/src/lean/ffi/aiur2/mod.rs @@ -0,0 +1 @@ +pub mod toplevel; diff --git a/src/lean/ffi/aiur2/toplevel.rs b/src/lean/ffi/aiur2/toplevel.rs new file mode 100644 index 0000000..8397544 --- /dev/null +++ b/src/lean/ffi/aiur2/toplevel.rs @@ -0,0 +1,204 @@ +// TODO: remove +#![allow(dead_code)] + +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks as G; +use std::{ffi::c_void, mem::transmute}; + +use crate::{ + aiur2::bytecode::{Block, CircuitLayout, Ctrl, Function, FxIndexMap, Op, Toplevel, ValIdx}, + lean::{ + array::LeanArrayObject, + ctor::LeanCtorObject, + ffi::{as_ref_unsafe, lean_is_scalar}, + }, + lean_unbox, +}; + +fn lean_unbox_nat_as_usize(ptr: *const c_void) -> usize { + assert!(lean_is_scalar(ptr)); + lean_unbox!(usize, ptr) +} + +fn lean_unbox_nat_as_g(ptr: *const c_void) -> G { + assert!(lean_is_scalar(ptr)); + G::from_usize(lean_unbox!(usize, ptr)) +} + +fn lean_ptr_to_vec_val_idx(ptr: *const c_void) -> Vec { + let array: &LeanArrayObject = as_ref_unsafe(ptr.cast()); + 
array.to_vec(lean_unbox_nat_as_usize) +} + +fn lean_ptr_to_op(ptr: *const c_void) -> Op { + let ctor: &LeanCtorObject = as_ref_unsafe(ptr.cast()); + match ctor.tag() { + 0 => { + let [const_val_ptr] = ctor.objs(); + Op::Const(lean_unbox_nat_as_g(const_val_ptr)) + } + 1 => { + let [a_ptr, b_ptr] = ctor.objs(); + Op::Add( + lean_unbox_nat_as_usize(a_ptr), + lean_unbox_nat_as_usize(b_ptr), + ) + } + 2 => { + let [a_ptr, b_ptr] = ctor.objs(); + Op::Sub( + lean_unbox_nat_as_usize(a_ptr), + lean_unbox_nat_as_usize(b_ptr), + ) + } + 3 => { + let [a_ptr, b_ptr] = ctor.objs(); + Op::Mul( + lean_unbox_nat_as_usize(a_ptr), + lean_unbox_nat_as_usize(b_ptr), + ) + } + 4 => { + let [fun_idx_ptr, val_idxs_ptr] = ctor.objs(); + let fun_idx = lean_unbox_nat_as_usize(fun_idx_ptr); + let val_idxs = lean_ptr_to_vec_val_idx(val_idxs_ptr); + Op::Call(fun_idx, val_idxs) + } + 5 => { + let [val_idxs_ptr] = ctor.objs(); + Op::Store(lean_ptr_to_vec_val_idx(val_idxs_ptr)) + } + 6 => { + let [width_ptr, val_idx_ptr] = ctor.objs(); + Op::Load( + lean_unbox_nat_as_usize(width_ptr), + lean_unbox_nat_as_usize(val_idx_ptr), + ) + } + _ => unreachable!(), + } +} + +fn lean_ptr_to_g_block_pair(ptr: *const c_void) -> (G, Block) { + let ctor: &LeanCtorObject = as_ref_unsafe(ptr.cast()); + let [g_ptr, block_ptr] = ctor.objs(); + let g = lean_unbox_nat_as_g(g_ptr); + let block = lean_ctor_to_block(as_ref_unsafe(block_ptr.cast())); + (g, block) +} + +fn lean_ctor_to_ctrl(ctor: &LeanCtorObject) -> Ctrl { + match ctor.tag() { + 0 => { + let [val_idx_ptr, cases_ptr, default_ptr] = ctor.objs(); + let val_idx = lean_unbox_nat_as_usize(val_idx_ptr); + let cases_array: &LeanArrayObject = as_ref_unsafe(cases_ptr.cast()); + let vec_cases = cases_array.to_vec(lean_ptr_to_g_block_pair); + let cases = FxIndexMap::from_iter(vec_cases); + let default = if lean_is_scalar(default_ptr) { + None + } else { + let default_ctor: &LeanCtorObject = as_ref_unsafe(default_ptr.cast()); + let [block_ptr] = default_ctor.objs(); + 
let block = lean_ctor_to_block(as_ref_unsafe(block_ptr.cast())); + Some(Box::new(block)) + }; + Ctrl::Match(val_idx, cases, default) + } + 1 => { + let [sel_idx_ptr, val_idxs_ptr] = ctor.objs(); + let sel_idx = lean_unbox_nat_as_usize(sel_idx_ptr); + let val_idxs = lean_ptr_to_vec_val_idx(val_idxs_ptr); + Ctrl::Return(sel_idx, val_idxs) + } + _ => unreachable!(), + } +} + +fn lean_ctor_to_block(ctor: &LeanCtorObject) -> Block { + let [ + ops_ptr, + ctrl_ptr, + min_sel_included_ptr, + max_sel_excluded_ptr, + ] = ctor.objs(); + let ops_array: &LeanArrayObject = as_ref_unsafe(ops_ptr.cast()); + let ops = ops_array.to_vec(lean_ptr_to_op); + let ctrl = lean_ctor_to_ctrl(as_ref_unsafe(ctrl_ptr.cast())); + let min_sel_included = lean_unbox_nat_as_usize(min_sel_included_ptr); + let max_sel_excluded = lean_unbox_nat_as_usize(max_sel_excluded_ptr); + Block { + ops, + ctrl, + min_sel_included, + max_sel_excluded, + } +} + +fn lean_ctor_to_circuit_layout(ctor: &LeanCtorObject) -> CircuitLayout { + let [selectors_ptr, auxiliaries_ptr, shared_constraints_ptr] = ctor.objs(); + CircuitLayout { + selectors: lean_unbox_nat_as_usize(selectors_ptr), + auxiliaries: lean_unbox_nat_as_usize(auxiliaries_ptr), + shared_constraints: lean_unbox_nat_as_usize(shared_constraints_ptr), + } +} + +fn lean_ptr_to_function(ptr: *const c_void) -> Function { + let ctor: &LeanCtorObject = as_ref_unsafe(ptr.cast()); + let [ + input_size_ptr, + output_size_ptr, + body_ptr, + circuit_layout_ptr, + ] = ctor.objs(); + let input_size = lean_unbox_nat_as_usize(input_size_ptr); + let output_size = lean_unbox_nat_as_usize(output_size_ptr); + let body = lean_ctor_to_block(as_ref_unsafe(body_ptr.cast())); + let circuit_layout = lean_ctor_to_circuit_layout(as_ref_unsafe(circuit_layout_ptr.cast())); + Function { + input_size, + output_size, + body, + circuit_layout, + } +} + +fn lean_ctor_to_toplevel(ctor: &LeanCtorObject) -> Toplevel { + let [functions_ptr, memory_widths_ptr] = ctor.objs(); + let functions_array: 
&LeanArrayObject = as_ref_unsafe(functions_ptr.cast()); + let functions = functions_array.to_vec(lean_ptr_to_function); + let memory_widths_array: &LeanArrayObject = as_ref_unsafe(memory_widths_ptr.cast()); + let memory_widths = memory_widths_array.to_vec(lean_unbox_nat_as_usize); + Toplevel { + functions, + memory_widths, + } +} + +#[unsafe(no_mangle)] +extern "C" fn rs_toplevel_execute_test( + toplevel: &LeanCtorObject, + fun_idx: *const c_void, + args: &LeanArrayObject, + output: &mut LeanArrayObject, +) { + let fun_idx = lean_unbox_nat_as_usize(fun_idx); + let toplevel = lean_ctor_to_toplevel(toplevel); + let args = args.to_vec(lean_unbox_nat_as_g); + let record = toplevel.execute(fun_idx, args.clone()); + let output_values = record.function_queries[fun_idx] + .get(&args) + .map(|res| &res.output) + .unwrap(); + let boxed_values = output_values + .iter() + .map(|g| { + let g_u64 = unsafe { transmute::(*g) }; + let g_usize = usize::try_from(g_u64).unwrap(); + let g_boxed = (g_usize << 1) | 1; + g_boxed as *const c_void + }) + .collect::>(); + output.set_data(&boxed_values); +} diff --git a/src/lean/ffi/mod.rs b/src/lean/ffi/mod.rs index 545a9ac..aa75935 100644 --- a/src/lean/ffi/mod.rs +++ b/src/lean/ffi/mod.rs @@ -1,3 +1,4 @@ +pub mod aiur2; pub mod archon; pub mod binius; pub mod byte_array; diff --git a/src/lib.rs b/src/lib.rs index cf6e8f5..c4b7c92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod aiur; +pub mod aiur2; pub mod archon; pub mod lean;