{-# LANGUAGE PatternSynonyms #-}

module Plutarch.Pretty.Internal.TermUtils (
  unwrapLamAbs,
  unwrapBindings,
  unwrapApply,
  incrVar,
  pattern PFixAst,
  pattern ComposeAST,
  pattern IfThenElseLikeAST,
) where

import UntypedPlutusCore (
  DeBruijn (DeBruijn),
  Index,
  Term (Apply, Delay, Force, LamAbs, Var),
 )

{- | Peel consecutive outer 'LamAbs' nodes off a term, counting them.

The first argument is the running count to start from. The result pairs that
count, increased by one for every lambda removed, with the first sub-term
that is not a 'LamAbs'. A term that is not a lambda at all is returned
unchanged together with the count that was passed in.
-}
unwrapLamAbs :: Index -> Term name uni fun ann -> (Index, Term name uni fun ann)
unwrapLamAbs d (LamAbs _ _ t) = unwrapLamAbs (d + 1) t
unwrapLamAbs d a = (d, a)

{- | Collect the arguments of nested, immediately applied lambdas.

While the term has the shape of a lambda applied directly to an argument
('Apply' of a 'LamAbs'), the argument is consed onto the accumulator and the
walk continues into the lambda body. The result is the accumulated arguments
and the first body that no longer has that shape.

Because arguments are consed, the innermost one comes first in the result,
followed by the outer ones and finally whatever the accumulator held
initially.
-}
unwrapBindings ::
  [Term name uni fun ann] ->
  Term name uni fun ann ->
  ([Term name uni fun ann], Term name uni fun ann)
unwrapBindings l (Apply _ (LamAbs _ _ t) arg) = unwrapBindings (arg : l) t
unwrapBindings l a = (l, a)

{- | Split a chain of applications into its arguments and its head term.

Each 'Apply' contributes its argument to the front of the accumulator and the
walk continues into the function position. For a term of the shape
@f a b@ and an empty accumulator the result is @([a, b], f)@: the arguments
in application order, and the first term in function position that is not an
'Apply'.
-}
unwrapApply ::
  [Term name uni fun ann] ->
  Term name uni fun ann ->
  ([Term name uni fun ann], Term name uni fun ann)
unwrapApply l (Apply _ t arg) = unwrapApply (arg : l) t
unwrapApply l arg = (l, arg)

-- AST resulting from `pfix`. This is always constant.
--
-- Match-only pattern: a lambda whose body applies one lambda to a second
-- lambda of the same shape. Only the shape of the tree is checked; binder
-- names and variable indices are wildcards, and every annotation must be `()`.
pattern PFixAst :: Term name uni fun ()
pattern PFixAst <-
  LamAbs
    ()
    _
    ( Apply
        ()
        ( LamAbs
            ()
            _
            ( Apply
                ()
                (Var () _)
                ( LamAbs
                    ()
                    _
                    ( Apply
                        ()
                        ( Apply
                            ()
                            (Var () _)
                            (Var () _)
                          )
                        (Var () _)
                      )
                  )
              )
          )
        ( LamAbs
            ()
            _
            ( Apply
                ()
                (Var () _)
                ( LamAbs
                    ()
                    _
                    ( Apply
                        ()
                        ( Apply
                            ()
                            (Var () _)
                            (Var () _)
                          )
                        (Var () _)
                      )
                  )
              )
          )
      )

-- Match-only pattern for a lambda whose body is `f` applied to (`g` applied to
-- the lambda's own argument, i.e. the variable with de Bruijn index 1).
--
-- If `f` and `g` are Var references, their indices are incremented once since
-- they are within a lambda. Both sub-terms are therefore passed through
-- `incrVar` before being bound, which lowers a `Var` index by one and leaves
-- any other node as it is.
pattern ComposeAST :: Term DeBruijn uni fun () -> Term DeBruijn uni fun () -> Term DeBruijn uni fun ()
pattern ComposeAST f g <- LamAbs () _ (Apply () (incrVar -> f) (Apply () (incrVar -> g) (Var () (DeBruijn 1))))

{- This AST represents a typical if/then/else usage if and only if 'ifThenElseMaybe' is either the
builtin IfThenElse (forced once), or a reference to such.

Match-only pattern: a 'Force' around three nested applications, where the last two arguments are
each wrapped in 'Delay'. It binds the head term, the condition and the two delayed bodies. The head
term itself is not inspected here, so the caller has to check it.
-}
pattern IfThenElseLikeAST ::
  Term name uni fun () ->
  Term name uni fun () ->
  Term name uni fun () ->
  Term name uni fun () ->
  Term name uni fun ()
pattern IfThenElseLikeAST ifThenElseMaybe cond trueBranch falseBranch <-
  Force
    ()
    ( Apply
        ()
        ( Apply
            ()
            ( Apply
                ()
                ifThenElseMaybe
                cond
              )
            (Delay () trueBranch)
          )
        (Delay () falseBranch)
      )

{- | Lower the de Bruijn index of a 'Var' by one, leave any other AST node unchanged.

Despite its name, this function subtracts one from the index. 'ComposeAST'
uses it to undo the increment that a variable reference picks up from sitting
under one extra lambda.

NOTE(review): no check is made that the index is at least one before
subtracting; confirm how 'Index' behaves for an index of zero.
-}
incrVar :: Term DeBruijn uni fun () -> Term DeBruijn uni fun ()
incrVar (Var () (DeBruijn n)) = Var () . DeBruijn $ n - 1
incrVar n = n