Skip to content

[UPLC] [Optimization] Improve case-of-case #7210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion plutus-core/plutus-core/src/PlutusCore/Compiler/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ type Compiling m uni fun name a =
( ToBuiltinMeaning uni fun
, MonadQuote m
, CaseBuiltin uni
, GEq uni
, Closed uni
, GEq uni
, Everywhere uni Eq
, Everywhere uni Hashable
, HasUnique name TermUnique
, Ord name
, Typeable name
Expand Down
2 changes: 1 addition & 1 deletion plutus-core/testlib/PlutusCore/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ instance
( TPLC.Typecheckable uni fun
, CaseBuiltin uni
, Hashable fun
, TPLC.GEq uni, TPLC.Closed uni, TPLC.Everywhere uni Eq
, TPLC.GEq uni, TPLC.Closed uni, TPLC.Everywhere uni Eq, TPLC.Everywhere uni Hashable
)
=> ToUPlc (TPLC.Program TPLC.TyName UPLC.Name uni fun ()) uni fun where
toUPlc =
Expand Down
3 changes: 2 additions & 1 deletion plutus-core/testlib/PlutusIR/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ instance

instance
( PLC.GEq uni
, uni `PLC.Everywhere` Eq
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, PLC.Typecheckable uni fun
, PLC.CaseBuiltin uni
, PLC.PrettyUni uni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,26 @@ into

This is always an improvement.
-}
module UntypedPlutusCore.Transform.CaseOfCase (caseOfCase) where
module UntypedPlutusCore.Transform.CaseOfCase (caseOfCase, processTerm, annotateIsDuplicatedOn) where

import PlutusPrelude

import PlutusCore qualified as PLC
import PlutusCore.Builtin (CaseBuiltin (..))
import PlutusCore.MkPlc (mkIterApp)
import UntypedPlutusCore.Core
import UntypedPlutusCore.Size (termSize)
import UntypedPlutusCore.Transform.CaseReduce qualified as CaseReduce
import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CaseOfCase), SimplifierT,
recordSimplification)

import Control.Lens
import Data.List (nub)
import Data.Hashable (Hashable)
import Data.HashMap.Strict qualified as HashMap

caseOfCase
:: ( fun ~ PLC.DefaultFun, Monad m, CaseBuiltin uni
, PLC.GEq uni, PLC.Closed uni, uni `PLC.Everywhere` Eq
, PLC.Closed uni, PLC.GEq uni, uni `PLC.Everywhere` Eq, uni `PLC.Everywhere` Hashable
)
=> Term name uni fun a
-> SimplifierT name uni fun a m (Term name uni fun a)
Expand All @@ -60,9 +62,15 @@ caseOfCase term = do
recordSimplification term CaseOfCase result
return result

-- >>> annotateIsDuplicatedOn id "abacdeffdbdg"
-- [('a',True),('b',True),('a',True),('c',False),('d',True),('e',False),('f',True),('f',True),('d',True),('b',True),('d',True),('g',False)]
annotateIsDuplicatedOn :: (Functor f, Foldable f, Hashable b) => (a -> b) -> f a -> f (a, Bool)
annotateIsDuplicatedOn f xs = fmap (\x -> (x, duplMap HashMap.! f x)) xs where
duplMap = HashMap.fromListWith (\_ _ -> True) . map (\x -> (f x, False)) $ toList xs

processTerm
:: ( fun ~ PLC.DefaultFun, CaseBuiltin uni
, PLC.GEq uni, PLC.Closed uni, uni `PLC.Everywhere` Eq
, PLC.Closed uni, PLC.GEq uni, uni `PLC.Everywhere` Eq, uni `PLC.Everywhere` Hashable
)
=> Term name uni fun a -> Term name uni fun a
processTerm = \case
Expand All @@ -88,11 +96,13 @@ processTerm = \case
original
(Case annInner scrut)
(do
constrs <- for altsInner $ \case
constrsDupl <- fmap (annotateIsDuplicatedOn fst) . for altsInner $ \case
c@(Constr _ i _) -> Just (Left i, c)
c@(Constant _ val) -> Just (Right val, c)
_ -> Nothing
-- See Note [Case-of-case and duplicating code].
guard $ length (nub . toList $ fmap fst constrs) == length constrs
pure $ constrs <&> \(_, c) -> CaseReduce.processTerm $ Case annOuter c altsOuter)
for constrsDupl $ \((_, c), dupl) -> do
let alt = CaseReduce.processTerm $ Case annOuter c altsOuter
guard $ not dupl || termSize alt <= 3
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried making it 300, no effect either.

pure alt)
other -> other
16 changes: 16 additions & 0 deletions plutus-tx/src/PlutusTx/Lift.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ safeLiftWith
. ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand Down Expand Up @@ -119,6 +120,7 @@ safeLift
. ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand All @@ -143,6 +145,7 @@ safeLiftUnopt
. ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand Down Expand Up @@ -171,6 +174,7 @@ safeLiftProgram
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand All @@ -194,6 +198,7 @@ safeLiftProgramUnopt
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand All @@ -214,6 +219,7 @@ safeLiftCode
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand All @@ -240,6 +246,7 @@ safeLiftCodeUnopt
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, MonadError (PIR.Error uni fun (Provenance ())) m
, MonadQuote m
, PLC.Typecheckable uni fun
Expand Down Expand Up @@ -278,6 +285,7 @@ lift
, PLC.Typecheckable uni fun
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, PLC.CaseBuiltin uni
, Default (PLC.CostingPart uni fun)
, Default (PIR.BuiltinsInfo uni fun)
Expand All @@ -298,6 +306,7 @@ liftUnopt
, PLC.Typecheckable uni fun
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, PLC.CaseBuiltin uni
, Default (PLC.CostingPart uni fun)
, Default (PIR.BuiltinsInfo uni fun)
Expand All @@ -316,6 +325,7 @@ liftProgram
, PLC.Typecheckable uni fun
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, PLC.CaseBuiltin uni
, Default (PLC.CostingPart uni fun)
, Default (PIR.BuiltinsInfo uni fun)
Expand All @@ -336,6 +346,7 @@ liftProgramUnopt
, PLC.Typecheckable uni fun
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, PLC.CaseBuiltin uni
, Default (PLC.CostingPart uni fun)
, Default (PIR.BuiltinsInfo uni fun)
Expand Down Expand Up @@ -372,6 +383,7 @@ liftCode
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, ThrowableBuiltins uni fun
, PLC.Typecheckable uni fun
, PLC.CaseBuiltin uni
Expand All @@ -390,6 +402,7 @@ liftCodeUnopt
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, ThrowableBuiltins uni fun
, PLC.Typecheckable uni fun
, PLC.CaseBuiltin uni
Expand All @@ -406,6 +419,7 @@ liftCodeDef
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, ThrowableBuiltins uni fun
, PLC.Typecheckable uni fun
, PLC.CaseBuiltin uni
Expand All @@ -424,6 +438,7 @@ liftCodeDefUnopt
:: ( Lift.Lift uni a
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, ThrowableBuiltins uni fun
, PLC.Typecheckable uni fun
, PLC.CaseBuiltin uni
Expand Down Expand Up @@ -495,6 +510,7 @@ typeCode
, MonadQuote m
, PLC.GEq uni
, PLC.Everywhere uni Eq
, PLC.Everywhere uni Hashable
, PLC.Typecheckable uni fun
, PLC.CaseBuiltin uni
, PrettyUni uni
Expand Down
Loading