Skip to content

Commit

Permalink
Merge pull request #211 from Plutonomicon/las/optimisations
Browse files Browse the repository at this point in the history
Optimisations
  • Loading branch information
L-as authored Jan 27, 2022
2 parents d753dc3 + 636d12b commit 7c64217
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 38 deletions.
26 changes: 25 additions & 1 deletion Plutarch/DataRepr/Internal/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,34 @@ instance {-# OVERLAPPABLE #-} (BindFields ps bs) => BindFields ((l ':= p) ': ps)
xs <- bindFields @ps @bs (pdropDataRecord (Proxy @1) t')
pure $ HCons (Labeled $ pindexDataRecord (Proxy @0) t') xs

instance (BindFields ps bs) => BindFields (p ': ps) ( 'Skip ': bs) where
instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': ps) ( 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @1) t

instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': p2 ': ps) ( 'Skip ': 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @2) t

instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': p2 ': p3 ': ps) ( 'Skip ': 'Skip ': 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @3) t

instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': p2 ': p3 ': p4 ': ps) ( 'Skip ': 'Skip ': 'Skip ': 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @4) t

instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': p2 ': p3 ': p4 ': p5 ': ps) ( 'Skip ': 'Skip ': 'Skip ': 'Skip ': 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @5) t

instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': p2 ': p3 ': p4 ': p5 ': p6 ': ps) ( 'Skip ': 'Skip ': 'Skip ': 'Skip ': 'Skip ': 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @6) t

instance {-# OVERLAPPING #-} (BindFields ps bs) => BindFields (p1 ': p2 ': p3 ': p4 ': p5 ': p6 ': p7 ': ps) ( 'Skip ': 'Skip ': 'Skip ': 'Skip ': 'Skip ': 'Skip ': 'Skip ': bs) where
bindFields t = do
bindFields @ps @bs $ pdropDataRecord (Proxy @7) t

--------------------------------------------------------------------------------

{- |
Expand Down
41 changes: 34 additions & 7 deletions Plutarch/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import GHC.Stack (HasCallStack)
import Numeric.Natural (Natural)
import Plutarch.Evaluate (evaluateScript)
import Plutus.V1.Ledger.Scripts (Script (Script))
import PlutusCore (Some, ValueOf)
import PlutusCore (Some (Some), ValueOf (ValueOf))
import qualified PlutusCore as PLC
import PlutusCore.DeBruijn (DeBruijn (DeBruijn), Index (Index))
import qualified UntypedPlutusCore as UPLC
Expand Down Expand Up @@ -303,22 +303,41 @@ punsafeConstant :: Some (ValueOf PLC.DefaultUni) -> Term s a
punsafeConstant = punsafeConstantInternal

punsafeConstantInternal :: Some (ValueOf PLC.DefaultUni) -> Term s a
punsafeConstantInternal c = Term $ \_ -> mkTermRes $ RConstant c
punsafeConstantInternal c = Term $ \_ ->
case c of
-- These constants are smaller than variable references.
Some (ValueOf PLC.DefaultUniBool _) -> mkTermRes $ RConstant c
Some (ValueOf PLC.DefaultUniUnit _) -> mkTermRes $ RConstant c
Some (ValueOf PLC.DefaultUniInteger n) | n < 256 -> mkTermRes $ RConstant c
_ ->
let hoisted = HoistedTerm (hashRawTerm $ RConstant c) (RConstant c)
in TermResult (RHoisted hoisted) [hoisted]

asClosedRawTerm :: ClosedTerm a -> TermResult
asClosedRawTerm t = asRawTerm t 0

-- FIXME: Give proper error message when mutually recursive.
phoistAcyclic :: HasCallStack => ClosedTerm a -> Term s a
phoistAcyclic t = Term $ \_ -> case asRawTerm t 0 of
-- FIXME: is this worth it?
t'@(getTerm -> RBuiltin _) -> t'
phoistAcyclic t = case asRawTerm t 0 of
-- Built-ins are smaller than variable references
t'@(getTerm -> RBuiltin _) -> Term $ \_ -> t'
t' -> case evaluateScript . Script $ UPLC.Program () (PLC.defaultVersion ()) (compile' t') of
Right _ ->
let hoisted = HoistedTerm (hashRawTerm . getTerm $ t') (getTerm t')
in TermResult (RHoisted hoisted) (hoisted : getDeps t')
in Term $ \_ -> TermResult (RHoisted hoisted) (hoisted : getDeps t')
Left e -> error $ "Hoisted term errs! " <> show e

-- Couldn't find a definition for this in plutus-core
subst :: Natural -> (Natural -> UPLC.Term DeBruijn UPLC.DefaultUni UPLC.DefaultFun ()) -> UPLC.Term DeBruijn UPLC.DefaultUni UPLC.DefaultFun () -> UPLC.Term DeBruijn UPLC.DefaultUni UPLC.DefaultFun ()
subst idx x (UPLC.Apply () yx yy) = UPLC.Apply () (subst idx x yx) (subst idx x yy)
subst idx x (UPLC.LamAbs () name y) = UPLC.LamAbs () name (subst (idx + 1) x y)
subst idx x (UPLC.Delay () y) = UPLC.Delay () (subst idx x y)
subst idx x (UPLC.Force () y) = UPLC.Force () (subst idx x y)
subst idx x (UPLC.Var () (DeBruijn (Index idx'))) | idx == idx' = x idx
subst idx _ y@(UPLC.Var () (DeBruijn (Index idx'))) | idx > idx' = y
subst idx _ (UPLC.Var () (DeBruijn (Index idx'))) | idx < idx' = UPLC.Var () (DeBruijn . Index $ idx' - 1)
subst _ _ y = y

rawTermToUPLC ::
(HoistedTerm -> Natural -> UPLC.Term DeBruijn UPLC.DefaultUni UPLC.DefaultFun ()) ->
Natural ->
Expand All @@ -332,7 +351,15 @@ rawTermToUPLC m l (RLamAbs n t) =
(replicate (fromIntegral $ n + 1) $ UPLC.LamAbs () (DeBruijn . Index $ 0))
$ (rawTermToUPLC m (l + n + 1) t)
rawTermToUPLC m l (RApply x y) =
foldr (.) id ((\y' t -> UPLC.Apply () t (rawTermToUPLC m l y')) <$> y) $ (rawTermToUPLC m l x)
let f y t@(UPLC.LamAbs () _ body) =
case rawTermToUPLC m l y of
-- Inline unconditionally if it's a variable or built-in.
-- These terms are very small and are always WHNF.
UPLC.Var () (DeBruijn (Index idx)) -> subst 1 (\lvl -> UPLC.Var () (DeBruijn . Index $ idx + lvl - 1)) body
arg@UPLC.Builtin {} -> subst 1 (\_ -> arg) body
arg -> UPLC.Apply () t arg
f y t = UPLC.Apply () t (rawTermToUPLC m l y)
in foldr (.) id (f <$> y) $ (rawTermToUPLC m l x)
rawTermToUPLC m l (RDelay t) = UPLC.Delay () (rawTermToUPLC m l t)
rawTermToUPLC m l (RForce t) = UPLC.Force () (rawTermToUPLC m l t)
rawTermToUPLC _ _ (RBuiltin f) = UPLC.Builtin () f
Expand Down
9 changes: 6 additions & 3 deletions Plutarch/List.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import Numeric.Natural (Natural)
import qualified GHC.Generics as GHC
import Generics.SOP (Generic, I (I))
import Plutarch (
ClosedTerm,
PDelayed,
PType,
PlutusType,
Expand Down Expand Up @@ -183,10 +184,12 @@ ptryIndex n xs = phead # (pdrop n xs)
efficient in many circumstances.
-}
pdrop :: (PIsListLike list a) => Natural -> Term s (list a) -> Term s (list a)
pdrop n xs = (phoistAcyclic $ plam $ \x -> pdrop' n x) # xs
pdrop n xs = pdrop' n # xs
where
pdrop' 0 xs' = xs'
pdrop' n' xs' = pdrop' (n' - 1) (ptail # xs')
pdrop' :: (PIsListLike list a) => Natural -> ClosedTerm (list a :--> list a)
pdrop' 0 = plam $ \x -> x
pdrop' 1 = ptail
pdrop' n' = phoistAcyclic $ plam $ \x -> ptail #$ pdrop' (n' - 1) # x

--------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion Plutarch/Rational.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Plutarch.Rational (
PRational,
PRational (..),
preduce,
pnumerator,
pdenominator,
Expand Down
12 changes: 6 additions & 6 deletions examples/Examples/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -286,22 +286,22 @@ tests =

tripSumComp :: String
tripSumComp =
"(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> (\\i0 -> addInteger (addInteger (unIData (i4 i2)) (unIData (i4 i1))) (unIData (i4 (i5 i1)))) (i4 i1)) ((\\i0 -> force (force sndPair) (unConstrData i1)) i1)) (force headList)) i1) (force tailList)))"
"(program 1.0.0 ((\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> (\\i0 -> addInteger (addInteger (unIData (i4 i2)) (unIData (i4 i1))) (unIData (i4 (i5 i1)))) (i4 i1)) (force (force sndPair) (unConstrData i1))) (force headList)) (force tailList)))"

nFieldsComp :: String
nFieldsComp = "(program 1.0.0 ((\\i0 -> \\i0 -> addInteger (unIData (i2 i1)) (unIData (i2 (force tailList i1)))) (force headList)))"

dropFieldsComp :: String
dropFieldsComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> addInteger (unIData (i3 i1)) (unIData (i3 (i4 i1)))) (i3 (i3 (i3 (i3 (i3 (i3 (i3 (i3 i1))))))))) (force headList)) i1) (force tailList)))"
dropFieldsComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> addInteger (unIData (i3 i1)) (unIData (i3 (i4 i1)))) (i3 (i3 (i3 (i3 (i3 (i3 (i3 (i3 i1))))))))) (force headList)) (force tailList)))"

rangeFieldsComp :: String
rangeFieldsComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> addInteger (unIData (i3 i1)) (unIData (i3 (i4 i1)))) (i3 (i3 (i3 (i3 (i3 i1)))))) (force headList)) i1) (force tailList)))"
rangeFieldsComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> addInteger (unIData (i3 i1)) (unIData (i3 (i4 i1)))) (i3 (i3 (i3 (i3 (i3 i1)))))) (force headList)) (force tailList)))"

getYComp :: String
getYComp = "(program 1.0.0 (\\i0 -> force headList (force tailList ((\\i0 -> force (force sndPair) (unConstrData i1)) i1))))"
getYComp = "(program 1.0.0 (\\i0 -> force headList (force tailList (force (force sndPair) (unConstrData i1)))))"

tripYZComp :: String
tripYZComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> addInteger (unIData (i3 i1)) (unIData (i3 (i4 i1)))) (i3 ((\\i0 -> force (force sndPair) (unConstrData i1)) i1))) (force headList)) i1) (force tailList)))"
tripYZComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> addInteger (unIData (i3 i1)) (unIData (i3 (i4 i1)))) (i3 (force (force sndPair) (unConstrData i1)))) (force headList)) (force tailList)))"

letSomeFieldsComp :: String
letSomeFieldsComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> (\\i0 -> addInteger (addInteger (unIData (i4 i2)) (unIData (i4 i1))) (unIData (i4 (i5 (i5 (i5 i1)))))) (i4 i1)) (i3 (i3 (i3 i1)))) (force headList)) i1) (force tailList)))"
letSomeFieldsComp = "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> (\\i0 -> addInteger (addInteger (unIData (i4 i2)) (unIData (i4 i1))) (unIData (i4 (i5 (i6 i1))))) (i5 i1)) (i4 (i3 i1))) (force headList)) (\\i0 -> i2 (i2 i1))) (force tailList)))"
Loading

0 comments on commit 7c64217

Please sign in to comment.