Skip to content

Commit

Permalink
PLT-7745: Basic rewrite rules for builtins
Browse files Browse the repository at this point in the history
Make CommuteFnWithConst transformation a RewriteRule
  • Loading branch information
bezirg committed Oct 27, 2023
1 parent b751871 commit ac26414
Show file tree
Hide file tree
Showing 37 changed files with 275 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@
- True
- Run a simplification pass that removes dead bindings

* - ``simplifier-rewrite``
- Bool
- True
- Run a pass that performs some pre-defined rewrite rules on builtins (similar to GHC's RULES)

* - ``simplifier-unwrap-cancel``
- Bool
Expand Down
5 changes: 5 additions & 0 deletions plutus-core/changelog.d/20231027_133344_bezirg_pir_rewrite.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
### Added

- A new pass in the simplifier that rewrites PIR terms given user-provided rules.
It behaves similar to GHC's RULES, but for the PIR language.
By default, a pre-defined set of rules are applied when the PIR simplifier is enabled.
7 changes: 5 additions & 2 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ library plutus-ir
PlutusIR.Transform.Beta
PlutusIR.Transform.CaseOfCase
PlutusIR.Transform.CaseReduce
PlutusIR.Transform.CommuteFnWithConst
PlutusIR.Transform.DeadCode
PlutusIR.Transform.EvaluateBuiltins
PlutusIR.Transform.Inline.CallSiteInline
Expand All @@ -516,6 +515,8 @@ library plutus-ir
PlutusIR.Transform.NonStrict
PlutusIR.Transform.RecSplit
PlutusIR.Transform.Rename
PlutusIR.Transform.RewriteRules
PlutusIR.Transform.RewriteRules.CommuteFnWithConst
PlutusIR.Transform.StrictifyBindings
PlutusIR.Transform.Substitute
PlutusIR.Transform.ThunkRecursions
Expand All @@ -532,6 +533,7 @@ library plutus-ir
PlutusIR.Compiler.Lower
PlutusIR.Compiler.Recursion
PlutusIR.Normalize
PlutusIR.Transform.RewriteRules.DecodeEncodeUtf8

build-depends:
, algebraic-graphs >=0.7
Expand Down Expand Up @@ -588,7 +590,6 @@ test-suite plutus-ir-test
PlutusIR.Transform.Beta.Tests
PlutusIR.Transform.CaseOfCase.Tests
PlutusIR.Transform.CaseReduce.Tests
PlutusIR.Transform.CommuteFnWithConst.Tests
PlutusIR.Transform.DeadCode.Tests
PlutusIR.Transform.EvaluateBuiltins.Tests
PlutusIR.Transform.Inline.Tests
Expand All @@ -598,6 +599,7 @@ test-suite plutus-ir-test
PlutusIR.Transform.NonStrict.Tests
PlutusIR.Transform.RecSplit.Tests
PlutusIR.Transform.Rename.Tests
PlutusIR.Transform.RewriteRules.Tests
PlutusIR.Transform.StrictifyBindings.Tests
PlutusIR.Transform.ThunkRecursions.Tests
PlutusIR.Transform.Unwrap.Tests
Expand All @@ -607,6 +609,7 @@ test-suite plutus-ir-test
build-depends:
, base >=4.9 && <5
, containers
, data-default-class
, flat ^>=0.6
, hashable
, hedgehog
Expand Down
46 changes: 24 additions & 22 deletions plutus-core/plutus-ir/src/PlutusIR/Analysis/Builtins.hs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
module PlutusIR.Analysis.Builtins (
BuiltinMatcherLike (..),
bmlSplitMatchContext,
bmlBranchArities,
defaultUniMatcherLike,

BuiltinsInfo (..),
biSemanticsVariant,
biMatcherLike,

builtinArityInfo,

asBuiltinDatatypeMatch,
builtinDatatypeMatchBranchArities) where
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
module PlutusIR.Analysis.Builtins
( BuiltinMatcherLike (..)
, bmlSplitMatchContext
, bmlBranchArities
, defaultUniMatcherLike
, BuiltinsInfo (..)
, biSemanticsVariant
, biMatcherLike
, builtinArityInfo
, asBuiltinDatatypeMatch
, builtinDatatypeMatchBranchArities
) where

import Control.Lens hiding (parts)
import Data.Functor (void)
Expand Down Expand Up @@ -44,11 +44,13 @@ data BuiltinsInfo (uni :: Type -> Type) fun = BuiltinsInfo
{ _biSemanticsVariant :: PLC.BuiltinSemanticsVariant fun
, _biMatcherLike :: Map.Map fun (BuiltinMatcherLike uni fun)
}

makeLenses ''BuiltinsInfo

instance (Ord fun, Default (BuiltinSemanticsVariant fun)) => Default (BuiltinsInfo uni fun) where
def = BuiltinsInfo def mempty
instance Default (BuiltinsInfo DefaultUni DefaultFun) where
def = BuiltinsInfo
{ _biSemanticsVariant = def
, _biMatcherLike = defaultUniMatcherLike
}

-- | Get the arity of a builtin function from the 'PLC.BuiltinInfo'.
builtinArityInfo
Expand Down
10 changes: 7 additions & 3 deletions plutus-core/plutus-ir/src/PlutusIR/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ module PlutusIR.Compiler (
coDoSimplifierInline,
coDoSimplifierEvaluateBuiltins,
coDoSimplifierStrictifyBindings,
coDoSimplifierRewrite,
coInlineHints,
coProfile,
coRelaxedFloatin,
Expand Down Expand Up @@ -58,7 +59,6 @@ import PlutusIR.Error
import PlutusIR.Transform.Beta qualified as Beta
import PlutusIR.Transform.CaseOfCase qualified as CaseOfCase
import PlutusIR.Transform.CaseReduce qualified as CaseReduce
import PlutusIR.Transform.CommuteFnWithConst qualified as CommuteFnWithConst
import PlutusIR.Transform.DeadCode qualified as DeadCode
import PlutusIR.Transform.EvaluateBuiltins qualified as EvaluateBuiltins
import PlutusIR.Transform.Inline.Inline qualified as Inline
Expand All @@ -69,6 +69,7 @@ import PlutusIR.Transform.LetMerge qualified as LetMerge
import PlutusIR.Transform.NonStrict qualified as NonStrict
import PlutusIR.Transform.RecSplit qualified as RecSplit
import PlutusIR.Transform.Rename ()
import PlutusIR.Transform.RewriteRules qualified as RewriteRules
import PlutusIR.Transform.StrictifyBindings qualified as StrictifyBindings
import PlutusIR.Transform.ThunkRecursions qualified as ThunkRec
import PlutusIR.Transform.Unwrap qualified as Unwrap
Expand Down Expand Up @@ -140,7 +141,9 @@ availablePasses =
binfo <- view ccBuiltinsInfo
Inline.inline hints binfo t
)
, Pass "commuteFnWithConst" (onOption coDoSimplifiercommuteFnWithConst) (pure . CommuteFnWithConst.commuteFnWithConst)
, Pass "rewrite rules" (onOption coDoSimplifierRewrite) (\ t -> do
rules <- view ccRewriteRules
RewriteRules.userRewrite rules t)
]

-- | Actual simplifier
Expand All @@ -150,11 +153,12 @@ simplify
simplify = foldl' (>=>) pure (map applyPass availablePasses)

-- | Perform some simplification of a 'Term'.
--
-- NOTE: simplifyTerm requires at least 1 prior dead code elimination pass
simplifyTerm
:: forall m e uni fun a b. (Compiling m e uni fun a, b ~ Provenance a)
=> Term TyName Name uni fun b -> m (Term TyName Name uni fun b)
simplifyTerm = runIfOpts simplify'
-- NOTE: we need at least one pass of dead code elimination
where
simplify' :: Term TyName Name uni fun b -> m (Term TyName Name uni fun b)
simplify' t = do
Expand Down
47 changes: 25 additions & 22 deletions plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import PlutusCore.Quote
import PlutusCore.StdLib.Type qualified as Types
import PlutusCore.TypeCheck.Internal qualified as PLC
import PlutusCore.Version qualified as PLC
import PlutusIR.Transform.RewriteRules
import PlutusPrelude

import Control.Monad.Error.Lens (throwing)
Expand Down Expand Up @@ -74,29 +75,29 @@ defaultDatatypeCompilationOpts :: DatatypeCompilationOpts
defaultDatatypeCompilationOpts = DatatypeCompilationOpts SumsOfProducts

data CompilationOpts a = CompilationOpts {
_coOptimize :: Bool
, _coPedantic :: Bool
, _coVerbose :: Bool
, _coDebug :: Bool
, _coDatatypes :: DatatypeCompilationOpts
_coOptimize :: Bool
, _coPedantic :: Bool
, _coVerbose :: Bool
, _coDebug :: Bool
, _coDatatypes :: DatatypeCompilationOpts
-- Simplifier passes
, _coMaxSimplifierIterations :: Int
, _coDoSimplifierUnwrapCancel :: Bool
, _coDoSimplifierCaseReduce :: Bool
, _coDoSimplifiercommuteFnWithConst :: Bool
, _coDoSimplifierBeta :: Bool
, _coDoSimplifierInline :: Bool
, _coDoSimplifierKnownCon :: Bool
, _coDoSimplifierCaseOfCase :: Bool
, _coDoSimplifierEvaluateBuiltins :: Bool
, _coDoSimplifierStrictifyBindings :: Bool
, _coInlineHints :: InlineHints PLC.Name (Provenance a)
, _coMaxSimplifierIterations :: Int
, _coDoSimplifierUnwrapCancel :: Bool
, _coDoSimplifierCaseReduce :: Bool
, _coDoSimplifierRewrite :: Bool
, _coDoSimplifierBeta :: Bool
, _coDoSimplifierInline :: Bool
, _coDoSimplifierKnownCon :: Bool
, _coDoSimplifierCaseOfCase :: Bool
, _coDoSimplifierEvaluateBuiltins :: Bool
, _coDoSimplifierStrictifyBindings :: Bool
, _coInlineHints :: InlineHints PLC.Name (Provenance a)
-- Profiling
, _coProfile :: Bool
, _coRelaxedFloatin :: Bool
, _coCaseOfCaseConservative :: Bool
, _coProfile :: Bool
, _coRelaxedFloatin :: Bool
, _coCaseOfCaseConservative :: Bool
-- | Whether to try and preserve the logging beahviour of the program.
, _coPreserveLogging :: Bool
, _coPreserveLogging :: Bool
} deriving stock (Show)

makeLenses ''CompilationOpts
Expand All @@ -111,7 +112,7 @@ defaultCompilationOpts = CompilationOpts
, _coMaxSimplifierIterations = 12
, _coDoSimplifierUnwrapCancel = True
, _coDoSimplifierCaseReduce = True
, _coDoSimplifiercommuteFnWithConst = True
, _coDoSimplifierRewrite = True
, _coDoSimplifierKnownCon = True
, _coDoSimplifierCaseOfCase = True
, _coDoSimplifierBeta = True
Expand All @@ -132,19 +133,21 @@ data CompilationCtx uni fun a = CompilationCtx {
, _ccTypeCheckConfig :: Maybe (PirTCConfig uni fun)
, _ccBuiltinsInfo :: BuiltinsInfo uni fun
, _ccBuiltinCostModel :: PLC.CostingPart uni fun
, _ccRewriteRules :: RewriteRules uni fun
}

makeLenses ''CompilationCtx

toDefaultCompilationCtx
:: (Ord fun, Default (PLC.BuiltinSemanticsVariant fun), Default (PLC.CostingPart uni fun))
:: (Default (BuiltinsInfo uni fun), Default (PLC.CostingPart uni fun), Default (RewriteRules uni fun))
=> PLC.TypeCheckConfig uni fun
-> CompilationCtx uni fun a
toDefaultCompilationCtx configPlc =
CompilationCtx defaultCompilationOpts noProvenance
(Just $ PirTCConfig configPlc YesEscape)
def
def
def

validateOpts :: Compiling m e uni fun a => PLC.Version -> m ()
validateOpts v = do
Expand Down
1 change: 0 additions & 1 deletion plutus-core/plutus-ir/src/PlutusIR/Subst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ funRes f = \case
TyFun a dom cod -> TyFun a dom <$> funRes f cod
t -> f t

-- TODO: these could be Traversals
-- | Get all the term variables in a term.
vTerm :: Fold (Term tyname name uni fun ann) name
vTerm = termSubtermsDeep . termVars
Expand Down
59 changes: 59 additions & 0 deletions plutus-core/plutus-ir/src/PlutusIR/Transform/RewriteRules.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module PlutusIR.Transform.RewriteRules
( userRewrite
, RewriteRules (..)
, defaultUniRewriteRules
) where

import PlutusCore.Default
import PlutusCore.Name
import PlutusCore.Quote
import PlutusIR as PIR
import PlutusIR.Analysis.VarInfo
import PlutusIR.Transform.RewriteRules.CommuteFnWithConst
import PlutusPrelude

import Control.Lens


-- | Rewrite a `Term` using the given `RewriteRules` (similar to functions of Term -> Term)
-- Normally the rewrite rules are configured at entrypoint time of the compiler.
userRewrite :: ( Semigroup a, t ~ Term tyname name uni fun a
, HasUnique name TermUnique
, HasUnique tyname TypeUnique
, MonadQuote m
)
=> RewriteRules uni fun
-> t
-> m t
userRewrite (RewriteRules rules) t =
-- We collect `VarsInfo` on the whole program term and pass it on as arg to each RewriteRule.
-- This has the limitation that any variables newly-introduced by the rules would
-- not be accounted in `VarsInfo`. This is currently fine, because we only rely on VarsInfo
-- for isPure; isPure is safe w.r.t "open" terms.
let vinfo = termVarInfo t
in transformMOf termSubterms (rules vinfo) t

-- | A bundle of composed `RewriteRules`, to be passed at entrypoint of the compiler.
newtype RewriteRules uni fun = RewriteRules {
unRewriteRules :: forall tyname name m a
. (MonadQuote m, Semigroup a)
=> VarsInfo tyname name uni a
-> PIR.Term tyname name uni fun a
-> m (PIR.Term tyname name uni fun a)
}

-- | The rules for the Default Universe/Builtin.
defaultUniRewriteRules :: RewriteRules DefaultUni DefaultFun
defaultUniRewriteRules = RewriteRules $ \ _vinfo ->
-- The rules are composed from left to right.
pure . commuteFnWithConst
-- FIXME: Unfortunately, unicode text is currently broken (at least on plutus-tx level), so
-- we disable this rewrite until fix is in and further tested. See PLT-8314
-- >=> pure . decodeEncodeUtf8

instance Default (RewriteRules DefaultUni DefaultFun) where
def = defaultUniRewriteRules
Loading

0 comments on commit ac26414

Please sign in to comment.