{-# LANGUAGE LambdaCase, TupleSections #-}
module ISQ.Lang.RAIICheck where
import Control.Monad.State
import ISQ.Lang.ISQv2Grammar
import ISQ.Lang.ISQv2Tokenizer (annotation, Annotated)
import Control.Monad.Except
import Data.Bifunctor
import Debug.Trace (trace)
import ISQ.Lang.ISQv2Grammar (AST(NProcedureInstantiated))

{-
Module for transforming the AST into RAII form by eliminating intermediate return/continue/break statements.
Instead, intermediate control-flow statements are replaced by flag registers and conditional jumps to the end of the region. Return values are likewise written to a return-value register.
-}

-- | A control-flow region on the RAII region stack, holding the ids of the
-- temporary registers allocated for it.
--
-- * 'Func': @value@ is the return-value register, @flag@ the function flag.
-- * 'While': @headFlag@ is the flag used as the loop guard, @flag@ the body flag.
-- * 'For', 'Block': @flag@ is the region's own flag.
data Region
    = Func {value :: Int, flag :: Int}
    | While {headFlag :: Int, flag :: Int}
    | For {flag :: Int}
    | Block {flag :: Int}
    deriving (Show, Eq)

-- | Coarse classification of a region, used to locate the target of a jump.
data RegionType
    = RFunc
    | RLoop
    | RBlock
    deriving (Show, Eq)

-- | Classify a region; both while- and for-regions count as loops.
regionType :: Region->RegionType
regionType region = case region of
    Func{} -> RFunc
    While{} -> RLoop
    For{} -> RLoop
    Block{} -> RBlock

-- | State threaded through the RAII pass.
data RAIICheckEnv = RAIICheckEnv
    { flagCounter :: Int      -- ^ next id handed out by 'nextId'
    , regionFlags :: [Region] -- ^ stack of open regions, innermost first
    } deriving Show

-- | Initial pass state: ids start at 0 and no region is open.
newRAIIEnv :: RAIICheckEnv
newRAIIEnv = RAIICheckEnv { flagCounter = 0, regionFlags = [] }

-- Before entering a region, push its flag onto the stack. The flag means ``whether the parent region should stop right after this op.''
-- On a multi-region jump: scan through the skipped regions and set their flags to true; for a return, also store the returned value in the return-value register.

{-
A while-statement can be transformed from
while {cond} do {body}
to
loop { if(cond) {body} }
Thus, the body can be seen as a ``subregion'' of the head.
-}


-- Break writes both headFlag and bodyFlag
-- Continue writes bodyFlag only
-- Thus, (Continue, bodyFlag):(Break, headFlag):(Return, funcFlag)
-- | Enter a region by putting it on top of the region stack.
pushRegion :: Region->RAIICheck ()
pushRegion newRegion = modify' enter
    where enter env = env { regionFlags = newRegion : regionFlags env }
-- | Leave the innermost region. Must follow a matching 'pushRegion';
-- popping an empty stack crashes ('tail' of an empty list).
popRegion :: RAIICheck ()
popRegion = modify' leave
    where leave env = env { regionFlags = tail (regionFlags env) }

-- | Regions crossed by a jump to the innermost open region of type @ty@,
-- innermost first and ending with that target region.
-- Throws 'UnmatchedScopeError' at @pos@ when no such region is open.
skippedRegions :: Pos->RegionType->RAIICheck [Region]
skippedRegions pos ty = do
    openRegions <- gets regionFlags
    case break ((== ty) . regionType) openRegions of
        (crossed, target : _) -> return (crossed ++ [target])
        (_, []) -> throwError $ UnmatchedScopeError pos ty


-- | Errors reported by the RAII pass.
data RAIIError =
      -- | A break/continue/return found no enclosing region of the wanted
      -- type (raised by 'skippedRegions').
      UnmatchedScopeError {unmatchedPos :: Pos, wantedRegionType :: RegionType}
      -- | Not constructed anywhere in this module.
    | PlaceHolder
    deriving (Eq, Show)


-- | The RAII pass monad: may fail with 'RAIIError', threads 'RAIICheckEnv' state.
type RAIICheck = ExceptT RAIIError (State RAIICheckEnv)

-- | Hand out a fresh integer id and advance the counter. The ids are used
-- for temporary registers and for generated variable names.
nextId :: RAIICheck Int
nextId = state $ \env ->
    let current = flagCounter env
    in (current, env { flagCounter = current + 1 })


-- A for-loop is affine-safe if it keeps the affine property, i.e. it always runs all of its iterations (no break/continue/return).
-- Tag recording which jump (break/continue/return) a statement may perform;
-- a for-loop whose range is not recognised as affine is also tagged
-- ContainsBreak. Under '<>' the more severe tag wins:
-- ContainsReturn > ContainsBreak > ContainsContinue > Safe.
data AffineSafe = Safe | ContainsBreak | ContainsContinue | ContainsReturn deriving Show

-- | Keep the more severe tag, ordered
-- Safe < ContainsContinue < ContainsBreak < ContainsReturn.
-- 'ContainsReturn' on the left short-circuits without inspecting the right.
instance Semigroup AffineSafe where
    ContainsReturn <> _ = ContainsReturn
    lhs <> rhs
        | severity rhs > severity lhs = rhs
        | otherwise = lhs
      where
        severity :: AffineSafe -> Int
        severity Safe = 0
        severity ContainsContinue = 1
        severity ContainsBreak = 2
        severity ContainsReturn = 3
-- | 'Safe' is the identity: it never outranks another tag.
instance Monoid AffineSafe where
    mempty = Safe


-- | Tag an expression with its affine safety. Only range expressions can
-- be unsafe: a range is 'Safe' when its step is absent or a positive
-- integer literal, and 'ContainsBreak' otherwise. A step written as a
-- signed literal (@-k@ / @+k@) is first folded into a plain literal, and
-- the folded range is what gets returned.
isExprSafe :: Expr ann->Expr (AffineSafe, ann)
isExprSafe expr = case expr of
    ERange ann lo hi (Just (EUnary ann2 Neg (EIntLit _ v))) ->
        isExprSafe (ERange ann lo hi (Just (EIntLit ann2 (negate v))))
    ERange ann lo hi (Just (EUnary ann2 Positive (EIntLit _ v))) ->
        isExprSafe (ERange ann lo hi (Just (EIntLit ann2 v)))
    ERange _ _ _ step
        | isAffineStep step -> tagWith Safe
        | otherwise -> tagWith ContainsBreak
    _ -> tagWith Safe
  where
    tagWith s = fmap (s,) expr
    isAffineStep Nothing = True
    isAffineStep (Just (EIntLit _ v)) = v > 0
    isAffineStep _ = False
-- | Read the safety tag stored in a node's annotation.
checkSafe :: (Annotated p)=>p (AffineSafe, ann)->AffineSafe
checkSafe node = fst (annotation node)

-- | Safety of a statement as seen by the statement enclosing it.
--
-- A loop's own @break@ and @continue@ cannot escape the loop, and neither
-- can a non-affine range. So a for/while-loop tagged 'ContainsBreak' or
-- 'ContainsContinue' looks 'Safe' from outside. A loop containing a
-- @return@ is tagged 'ContainsReturn', which dominates, so returns still
-- propagate. Every other statement contributes its own tag unchanged.
bodyEffect :: AST (AffineSafe, b) -> AffineSafe
bodyEffect stmt = case stmt of
    NFor{annotationAST = (s, _)} | staysInLoop s -> Safe
    NWhile{annotationAST = (s, _)} | staysInLoop s -> Safe
    _ -> checkSafe stmt
    where
        -- Jumps that target the loop itself. Previously only ContainsBreak
        -- was absorbed, so an inner loop's continue leaked outwards and
        -- made enclosing affine for-loops look non-affine.
        staysInLoop ContainsBreak = True
        staysInLoop ContainsContinue = True
        staysInLoop _ = False

-- | One open-recursive step of the affine-safety analysis; @f@ is applied to
-- sub-statements (the recursion is tied in 'isStatementAffineSafe').
--
-- Blocks, ifs, for- and while-loops combine the 'bodyEffect' of their
-- children with '<>'. A for-loop additionally includes the safety of its
-- range ('isExprSafe'). Conditions are tagged but do not contribute.
-- break/continue/return are tagged with the matching constructor.
--
-- NOTE(review): only procedures whose derive field is 'Nothing' (and
-- instantiated ones with the 'False' flag) have their bodies analysed.
-- Any other procedure falls into the last equation, so every node inside
-- it is tagged 'Safe', even a for-loop containing break. Confirm this is
-- intended, since later passes rely on these tags.
isStatementAffineSafe' :: (AST ann->AST (AffineSafe, ann))->(AST ann->AST (AffineSafe, ann))
isStatementAffineSafe' f (NBlock ann lis) = let lis' = fmap f lis in NBlock (mconcat $ fmap bodyEffect lis', ann) lis'
isStatementAffineSafe' f (NIf ann cond b1 b2) = let
        b1' = fmap f b1;
        b2' = fmap f b2
    in NIf (mconcat $ fmap bodyEffect $ b1' ++ b2', ann) (isExprSafe cond) b1' b2'
isStatementAffineSafe' f (NFor ann v range body) = let b' = fmap f body; r' = isExprSafe range in NFor (mconcat $ checkSafe r':fmap bodyEffect b', ann) v r' b'
isStatementAffineSafe' f (NWhile ann cond lis) = let lis' = fmap f lis; cond' = isExprSafe cond in NWhile (mconcat $ fmap bodyEffect lis', ann) cond' lis'
isStatementAffineSafe' _ x@(NBreak _) = fmap (ContainsBreak,) x
isStatementAffineSafe' _ x@(NContinue _) = fmap (ContainsContinue,) x
isStatementAffineSafe' _ x@(NReturn _ _) = fmap (ContainsReturn,) x
-- Procedures: the body is analysed; the procedure node and its signature
-- parts are tagged Safe (the procedure itself is a jump boundary).
isStatementAffineSafe' f x@(NProcedureWithDerive ann rty name targs args body Nothing ret) =
    let body' = fmap f body 
    in NProcedureWithDerive (Safe, ann) (fmap (Safe,) rty) name (fmap (first (fmap (Safe,))) targs) (fmap (first (fmap (Safe,))) args) body' Nothing $ fmap (Safe,) ret
isStatementAffineSafe' f x@(NProcedureInstantiated ann rty name targs args body Nothing False ret) =
    let body' = fmap f body 
    in NProcedureInstantiated (Safe, ann) (fmap (Safe,) rty) name (fmap (fmap (Safe,)) targs) (fmap (first (fmap (Safe,))) args) body' Nothing False $ fmap (Safe,) ret
-- Everything else is tagged Safe throughout, without consulting f.
isStatementAffineSafe' _ x = fmap (Safe,) x

-- | Tag every node of a statement with its 'AffineSafe' summary, feeding
-- this function back into 'isStatementAffineSafe'' for sub-statements.
isStatementAffineSafe :: AST ann -> AST (AffineSafe, ann)
isStatementAffineSafe stmt = isStatementAffineSafe' isStatementAffineSafe stmt

-- | Drop the first component (the safety tag) of every annotation.
eraseSafe :: (Functor p)=>p (a, b) -> p b
eraseSafe tagged = snd <$> tagged
-- | One open-recursive step of the for-loop lowering pass; @f@ handles
-- sub-statements (the recursion is tied in 'eliminateNonAffineForStmts').
-- It runs on trees tagged by 'isStatementAffineSafe':
--
-- * @for v in arr@ (identifier): becomes a loop over a fresh index from 0
--   to the array length with step 1. The body first defines
--   @v = arr[index]@, and the result is processed again.
-- * @for v in [..]@ (list literal): the list is stored in a fresh array
--   variable and then handled as the identifier case, all inside one block.
-- * a range @lo:hi@ without step gets step 1.
-- * a 'Safe'-tagged for-loop is kept; only its body is rewritten.
-- * any other @lo:hi:step@ loop is lowered to a block. The block evaluates
--   lo/hi/step once into temporaries, defines the loop variable, and runs
--   an 'NWhile' that advances the variable at the end of the body and
--   before each @continue@ belonging to this loop.
-- * any other for-loop is rejected with 'error'.
--
-- Blocks, ifs, whiles and procedure bodies are traversed. Instantiated
-- procedures are traversed only when their flag is 'False'. Every other
-- statement is returned unchanged.
eliminateNonAffineForStmts' :: (AST (AffineSafe, ann) -> RAIICheck [AST (AffineSafe, ann)]) -> AST (AffineSafe, ann) -> RAIICheck [AST (AffineSafe, ann)]
eliminateNonAffineForStmts' f (NBlock a lis) = do
    lis' <- mapM f lis;
    return [NBlock a $ concat lis']
eliminateNonAffineForStmts' f (NIf a b c d) = do
    c' <- mapM f c;
    d' <- mapM f d;
    return [NIf a b (concat c') (concat d')]
eliminateNonAffineForStmts' f (NWhile a b c) = do
    c' <- mapM f c;
    return [NWhile a b (concat c')]
-- for v in arr: iterate a fresh index and bind v to arr[index] in the body.
eliminateNonAffineForStmts' f (NFor (s1, ann) v eident@(EIdent (s2, eann) ident) body) = do
    i <- nextId
    let var_name = show i
    let sub = ESubscript (Safe, eann) eident $ EIdent (s2, eann) var_name
    let def = NDefvar (Safe, ann) [(Type (Safe, ann) Int [], v, Just sub, Nothing)]
    let lo = EIntLit (Safe, ann) 0
    let hi = EArrayLen (Safe, ann) eident
    let inc = EIntLit (Safe, ann) 1
    let range = ERange (Safe, ann) (Just lo) (Just hi) (Just inc)
    f $ NFor (s1, ann) var_name range $ def : body
-- for v in [..]: materialise the list in a fresh array, then loop over it.
eliminateNonAffineForStmts' f (NFor ann v elist@(EList eann lis) body) = do
    i <- nextId
    let array_name = show i
    let len = length lis
    let def = NDefvar ann [(Type ann (Array len) [Type ann Int []], array_name, Just elist, Nothing)]
    for <- f $ NFor ann v (EIdent eann array_name) body
    return [NBlock ann $ def : for]
-- lo:hi without a step: default the step to 1.
eliminateNonAffineForStmts' f (NFor (s1, ann) v (ERange (s2, ann2) (Just a) (Just b) Nothing) body) = eliminateNonAffineForStmts' f (NFor (s1, ann) v (ERange (s2, ann2) (Just a) (Just b) (Just (EIntLit (Safe, ann) 1))) body)
-- Affine-safe loop: keep it as a for-loop.
eliminateNonAffineForStmts' f (NFor (Safe, ann) v expr body) = do
    b' <- mapM f body;
    return [NFor (Safe, ann) v expr (concat b')]
-- Non-affine lo:hi:step loop: lower it to a while-loop.
eliminateNonAffineForStmts' f (NFor sann vn (ERange sann2 (Just a) (Just b) (Just c)) body) = do
    idlo <- nextId;
    idhi <- nextId;
    idstep <- nextId;

    -- Add 'vn += step;' before 'continue;' so a continued iteration still
    -- advances the loop variable. Read the step temporary, which is
    -- evaluated once before the loop just like the increment at the end of
    -- the body. Do not re-evaluate the step expression 'c' on every continue.
    let incStep (NContinue ann) = let v = EIdent ann vn in [NAssign ann v (ETempVar ann idstep) AddEq, NContinue ann]
        incStep (NBlock ann lis) = [NBlock ann $ concatMap incStep lis]
        incStep (NIf ann cond ifLis elseLis) = [NIf ann cond (concatMap incStep ifLis) $ concatMap incStep elseLis]
        -- Continues inside nested loops belong to those loops; leave them alone.
        incStep x = [x]

    b' <- mapM f $ concatMap incStep body;
    let v = EIdent sann vn
        lo = ETempVar sann2 idlo
        hi = ETempVar sann2 idhi
        step = ETempVar sann2 idstep
        -- v*step < hi*step is v < hi for a positive step and v > hi for a
        -- negative one (and never holds for a zero step).
        left = EBinary sann2 Mul v step
        right = EBinary sann2 Mul hi step
        block = [
            NTempvar sann (intType (), idlo, Just a),
            NTempvar sann (intType (), idhi, Just b),
            NTempvar sann (intType (), idstep, Just c),
            NDefvar sann [(intType sann, vn, Just lo, Nothing)],
            NWhile sann (EBinary sann2 (Cmp Less) left right) (concat b' ++ [NAssign sann v (EBinary sann2 Add v step) AssignEq])]
        in return [NBlock sann block]
eliminateNonAffineForStmts' f NFor{} = error "For-statement with non-standard range indices not supported."
eliminateNonAffineForStmts' f (NProcedureWithDerive ann rty name targs args body derive ret) = do
    body' <- mapM f body
    return [NProcedureWithDerive ann rty name targs args (concat body') derive ret]
eliminateNonAffineForStmts' f (NProcedureInstantiated ann rty name targs args body derive False ret) = do
    body' <- mapM f body
    return [NProcedureInstantiated ann rty name targs args (concat body') derive False ret]

eliminateNonAffineForStmts' _ x = return [x]

-- | Lower every non-affine for-loop in a statement, feeding this function
-- back into 'eliminateNonAffineForStmts'' for sub-statements.
eliminateNonAffineForStmts :: AST (AffineSafe, ann) -> RAIICheck [AST (AffineSafe, ann)]
eliminateNonAffineForStmts stmt = eliminateNonAffineForStmts' eliminateNonAffineForStmts stmt


-- | Build an 'NJumpToEndOnFlag' at @pos@ testing the flag register of the
-- innermost open region. Crashes if no region is open ('head').
checkCurrentScope :: Pos->RAIICheck LAST
checkCurrentScope pos =
    NJumpToEndOnFlag pos . ETempVar pos . flag . head <$> gets regionFlags

-- | Declare boolean temporary register @i@, initialised to false.
tempBool :: ann->Int->AST ann
tempBool ann i = NTempvar ann (boolType (), i, Just initial)
    where initial = EBoolLit ann False

-- | One open-recursive step of the RAII transformation; @f@ handles
-- sub-statements (the recursion is tied in 'raiiTransform').
--
-- Each region (block, if, for, while, procedure) gets fresh register ids
-- from 'nextId' and is pushed on the region stack while its children are
-- transformed. break/continue/return become flag assignments followed by
-- 'NJumpToEnd'. A block/if/for/while whose safety tag is not 'Safe' is
-- followed by an 'NJumpToEndOnFlag' on the enclosing region's flag
-- (built by 'checkCurrentScope' after popping), so a jump keeps
-- propagating outwards. Safety tags are erased from the output.
raiiTransform' :: (AST (AffineSafe, Pos) -> RAIICheck [LAST]) -> (AST (AffineSafe, Pos) -> RAIICheck [LAST])
-- Block: its flag i is declared (false) just before the block.
raiiTransform' f (NBlock (s, ann) stmnLis) = do
    i <- nextId
    pushRegion (Block i)
    lis' <- mapM f stmnLis
    popRegion
    case s of
        Safe -> return [tempBool ann i, NBlock ann (concat lis')]
        _ -> do
            finalize <- checkCurrentScope ann
            return [tempBool ann i, NBlock ann (concat lis'), finalize]
-- If: both branches share a single Block region and flag.
raiiTransform' f (NIf (s, ann) cond t e) = do
    i<-nextId
    pushRegion (Block i)
    t' <- mapM f t
    e' <- mapM f e
    popRegion
    case s of
        Safe -> return [tempBool ann i, NIf ann (eraseSafe cond) (concat t') (concat e')]
        _ -> do
            finalize <- checkCurrentScope ann
            return [tempBool ann i, NIf ann (eraseSafe cond) (concat t') (concat e'), finalize]
-- For: the flag is declared at the start of the body. Normally only
-- Safe-tagged for-loops survive eliminateNonAffineForStmts', so break and
-- continue should not target a For region.
raiiTransform' f (NFor (s, ann) var range b) = do
    i<-nextId
    pushRegion (For i)
    b'<-mapM f b
    popRegion
    case s of
        Safe -> return [NFor ann var (eraseSafe range) $ tempBool ann i : concat b']
        _ -> do
            finalize <- checkCurrentScope ann
            return [NFor ann var (eraseSafe range) $ tempBool ann i : concat b', finalize]
-- While: ihead is declared before the loop and used as the guard of
-- NWhileWithGuard (break/return set it). ibody is the body flag, declared
-- at the start of the body.
raiiTransform' f (NWhile (s, ann) cond body) = do
    ihead<-nextId
    ibody<-nextId
    pushRegion (While ihead ibody)
    b' <- mapM f body
    popRegion
    case s of
        Safe -> return [tempBool ann ihead, NWhileWithGuard ann (eraseSafe cond) (tempBool ann ibody : concat b') $ ETempVar ann ihead]
        _ -> do
            finalize <- checkCurrentScope ann
            return [tempBool ann ihead, NWhileWithGuard ann (eraseSafe cond) (tempBool ann ibody : concat b') $ ETempVar ann ihead, finalize]
-- Procedure: oracle-derived procedures only lose their safety tags.
-- Otherwise a Func region is opened with a return-value register (procRet)
-- and a flag (procFlag, declared first in the body). The return slot is
-- replaced by the procRet register.
-- NOTE(review): no declaration for procRet is emitted here; presumably a
-- later pass handles it. Confirm.
raiiTransform' f ast@(NProcedureWithDerive (_, a) rty name targs args body derive ret) = do
    case derive of
        Just DeriveOracle -> return [eraseSafe ast]
        _ -> do
            procRet<-nextId
            procFlag<-nextId
            pushRegion (Func procRet procFlag)
            body' <- mapM f (tempBool (Safe, a) procFlag : body)
            popRegion
            -- no finalizer
            return [NProcedureWithDerive a (eraseSafe rty) name (map (first eraseSafe) targs) (map (first eraseSafe) args) (concat body') derive (ETempVar a procRet)]
-- Instantiated procedure: same as above, without the DeriveOracle shortcut.
-- NOTE(review): the incoming Bool field 'ori' is ignored and False is
-- always emitted. Confirm this is intended.
raiiTransform' f (NProcedureInstantiated (_, a) rty name targs args body derive ori ret) = do
    procRet <- nextId
    procFlag <- nextId
    pushRegion (Func procRet procFlag)
    body' <- mapM f (tempBool (Safe, a) procFlag : body)
    popRegion
    -- no finalizer
    return [NProcedureInstantiated a (eraseSafe rty) name (map eraseSafe targs) (map (first eraseSafe) args) (concat body') derive False (ETempVar a procRet)]

-- The transformations below should also work with labeled loops.
-- break: scan out to the innermost loop. Set the guard flag of every while
-- crossed (including the target loop) and the flag of every crossed region
-- (including the target), then emit NJumpToEnd.
raiiTransform' f (NBreak (_, ann)) = do
    regions<-skippedRegions ann RLoop
    let break_all_loops = concatMap (\case
            While h f-> [setFlag ann h]
            _ -> []) regions
    let break_passing_bodies = setFlag ann . flag <$> regions
    return $ break_all_loops ++ break_passing_bodies ++ [NJumpToEnd ann]

-- continue: scan out to the innermost loop. Set the flags of every crossed
-- region except the innermost one, touching no guard flag so the loop keeps
-- running, then emit NJumpToEnd.
raiiTransform' f (NContinue (_, ann)) = do
    regions<-skippedRegions ann RLoop
    let break_passing_bodies = setFlag ann . flag <$> tail regions
    return $ break_passing_bodies ++ [NJumpToEnd ann]

-- return: scan out to the enclosing function. Set the guard flag of every
-- while crossed and the flag of every crossed region except the innermost.
-- Store the value in the function's return register unless it is the unit
-- literal, then emit NJumpToEnd. The `Func v f` pattern cannot fail:
-- skippedRegions ends at an RFunc region and only Func has that type.
raiiTransform' f (NReturn (_, ann) val) = do
    regions<-skippedRegions ann RFunc
    let break_all_loops  = concatMap (\case
            While h f-> [setFlag ann h]
            _ -> []) regions
    let break_passing_bodies = setFlag ann . flag <$> tail regions
    let Func v f = last regions
    case val of
        EUnitLit _ -> return $ break_all_loops ++ break_passing_bodies ++ [NJumpToEnd ann]
        _ -> return $ break_all_loops ++ break_passing_bodies ++ [setReturnVal ann v $ eraseSafe val, NJumpToEnd ann]

-- Any other statement: just erase the safety tags.
raiiTransform' _ ast = return [eraseSafe ast]


-- | Assign expression @rhs@ to temporary register @reg@.
setReturnVal :: ann->Int->Expr ann->AST ann
setReturnVal ann reg rhs = NAssign ann (ETempVar ann reg) rhs AssignEq
-- | Set the boolean temporary register @reg@ to true.
setFlag :: ann->Int->AST ann
setFlag ann reg = setReturnVal ann reg (EBoolLit ann True)
-- | RAII-transform a tagged statement, feeding this function back into
-- 'raiiTransform'' for sub-statements.
raiiTransform :: AST (AffineSafe, Pos) -> RAIICheck [LAST]
raiiTransform stmt = raiiTransform' raiiTransform stmt


-- | Full RAII pipeline over top-level statements. It tags affine safety,
-- lowers non-affine for-loops, then applies the RAII transformation.
-- Lowering of every statement finishes before any RAII transformation
-- starts, which fixes the order in which fresh ids are handed out.
raiiCheck' :: [LAST]->RAIICheck [LAST]
raiiCheck' ast = do
    lowered <- concat <$> traverse (eliminateNonAffineForStmts . isStatementAffineSafe) ast
    concat <$> traverse raiiTransform lowered

-- | Entry point of the RAII pass. Runs 'raiiCheck'' from a fresh state
-- ('newRAIIEnv') and returns the transformed statements or the first
-- 'RAIIError'.
raiiCheck :: [LAST] -> Either RAIIError [LAST]
raiiCheck ast = evalState (runExceptT (raiiCheck' ast)) newRAIIEnv

-- 

--isStatementAffineSafe b2
--isStatementAffineSafe 
--raiiTransform :: [LAST]->[LAST]
