{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}

module Plutarch.Internal.PLam (
  PLamN,
  plam,
  pinl,
) where

import Data.Kind (Type)
import Data.Text qualified as Text
import GHC.Stack (HasCallStack, callStack, withFrozenCallStack)
import Plutarch.Builtin.String (ptraceInfo)
import Plutarch.Internal.PrettyStack (prettyStack)
import Plutarch.Internal.Term (
  Config (Tracing),
  PType,
  S,
  Term,
  pgetConfig,
  plam',
  punsafeConstantInternal,
  (:-->),
  pattern DoTracingAndBinds,
 )
import PlutusCore qualified as PLC

{- $plam
 Lambda abstraction.

 The 'PLamN' constraint lets 'plam' accept a Haskell function of any number
 of 'Term' arguments and produce the corresponding curried Plutarch lambda.

 > id :: Term s (a :--> a)
 > id = plam (\x -> x)

 > const :: Term s (a :--> b :--> a)
 > const = plam (\x y -> x)
-}

-- | Embed a 'Text.Text' value as a Plutus Core string constant.
--
-- The result type @a@ is left fully polymorphic and is not checked; callers
-- are responsible for only using the result at a string type (in this module
-- it is used as a @'Term' s 'PString'@).
mkstring :: Text.Text -> Term s a
mkstring = punsafeConstantInternal . PLC.someValue @Text.Text @PLC.DefaultUni

-- | Haskell functions over 'Term's that 'plam' can turn into a curried
-- Plutarch lambda.
--
-- @a@ is what the Haskell function returns after receiving its first 'Term'
-- argument; @b@ is the Plutarch type produced after that first argument.
-- The functional dependencies let GHC infer @b@ from @a@ and, conversely,
-- @a@ from @s@ together with @b@.
class PLamN (a :: Type) (b :: PType) (s :: S) | a -> b, s b -> a where
  -- | Build a Plutarch lambda from a Haskell function over 'Term's.
  plam ::
    forall (c :: PType).
    HasCallStack =>
    (Term s c -> a) ->
    Term s (c :--> b)

-- | Base case: the Haskell function already returns a single 'Term', so a
-- single Plutarch lambda is built around it.
--
-- When the configuration is 'Tracing' with mode @DoTracingAndBinds@
-- (regardless of its log-level field), the lambda body is wrapped in
-- 'ptraceInfo' with the 'callStack' available to 'plam' (rendered by
-- 'prettyStack' with the prefix @"L"@). Otherwise the body is used unchanged.
instance {-# OVERLAPPABLE #-} a' ~ Term s a => PLamN a' a s where
  plam f =
    -- Captured once, from the 'HasCallStack' constraint on 'plam'.
    let cs = callStack
        -- Wrap a term with a 'ptraceInfo' of the rendered call stack.
        traceCallSite = ptraceInfo (mkstring (prettyStack "L" cs))
        body x cfg = case cfg of
          -- Note: This works because at the moment, DoTracingAndBinds is the
          -- most general tracing mode.
          Tracing _ DoTracingAndBinds -> traceCallSite (f x)
          _ -> f x
     in plam' (\x -> pgetConfig (body x))

-- | Recursive case: the Haskell function takes further arguments. Build one
-- Plutarch lambda for the first argument and let 'plam' handle the rest.
--
-- The call stack is frozen around the nested 'plam', so it sees the call stack
-- this 'plam' received rather than a new frame pointing into this module.
instance (a' ~ Term s a, PLamN b' b s) => PLamN (a' -> b') (a :--> b) s where
  plam f = withFrozenCallStack (plam' (plam . f))

-- | @pinl v f@ is simply @f v@.
--
-- No binding construct is introduced: @v@ is handed to @f@ as-is.
-- NOTE(review): presumably this means @v@ is duplicated at each use inside
-- @f@ (hence the name, "inline"); confirm against 'plet'-style helpers.
pinl :: Term s a -> (Term s a -> Term s b) -> Term s b
pinl v = ($ v)