{-# LANGUAGE LambdaCase, TupleSections, ViewPatterns #-}
module ISQ.Lang.TypeCheck where
import ISQ.Lang.ISQv2Grammar hiding (AmbiguousSymbol, ArgNumberMismatch, RedefinedSymbol, UnsupportedType)
import ISQ.Lang.ISQv2Tokenizer ( Annotated(..) )
import Data.Bifunctor (first)
import qualified Data.Map.Lazy as Map
import qualified Data.MultiMap as MultiMap
import Data.Ix (inRange)
import Data.List.Extra ( firstJust, zip4 )
import Data.Maybe ( catMaybes, isJust, fromMaybe )
import qualified Data.Set as Set
import Control.Monad (join)
import Control.Monad.Except
    ( when,
      fix,
      MonadTrans(lift),
      ExceptT,
      void,
      foldM,
      zipWithM,
      liftEither,
      runExceptT,
      MonadError(throwError) )
import Control.Monad.State.Lazy
    ( when,
      fix,
      MonadTrans(lift),
      StateT(runStateT),
      void,
      foldM,
      zipWithM,
      gets,
      modify,
      modify',
      evalState,
      State )
import Debug.Trace ( traceM, trace )

-- | A type carrying no annotation payload; the form in which the checker stores and compares types.
type EType = Type ()
-- | An AST whose nodes are annotated with 'TypeCheckData'.
type TCAST = AST TypeCheckData

-- | Annotation attached to the nodes produced by the type checker.
data TypeCheckData = TypeCheckData{
    -- Source position of the node.
    sourcePos :: Pos,
    -- Type assigned to the node.
    termType :: EType,
    -- Integer id of the node: fresh from 'nextId' for most nodes, the symbol's
    -- id for resolved identifiers, and -1 for annotations built by 'okStmt'.
    termId :: Int
} deriving (Eq, Show)

-- | Alternative name for 'TypeCheckData'.
type TypeCheckInfo = TypeCheckData

-- | Key of a symbol-table entry: a named variable ('SymVar'), or a numbered
-- temporary variable / temporary argument (the keys looked up for ETempVar and ETempArg).
data Symbol = SymVar String | SymTempVar Int | SymTempArg Int deriving (Show, Eq, Ord)

-- | Name carried by a symbol: the identifier of a 'SymVar', and the empty
-- string for the numbered temporaries.
getSymbolName :: Symbol -> String
getSymbolName (SymVar str) = str
getSymbolName _ = ""

-- | Errors reported by the type checker. Most carry the source position
-- ('pos') of the offending construct.
data TypeCheckError =
      -- Raised by 'insertSymbol' when the symbol already exists in the innermost
      -- scope; 'pos' is the new definition, 'firstDefinedAt' the existing one.
      RedefinedSymbol { pos :: Pos, symbolName :: Symbol, firstDefinedAt :: Pos}
      -- Raised by 'getSym' when no visible scope defines the symbol.
    | UndefinedSymbol { pos :: Pos, symbolName :: Symbol}
      -- Raised by 'getSym' when the scope that defines the symbol holds more than one definition.
    | AmbiguousSymbol { pos :: Pos, symbolName :: Symbol, firstDefinedAt :: Pos, secondDefinedAt :: Pos}
      -- Raised by 'matchType' when an expression cannot be converted to satisfy any of the expected rules.
    | TypeMismatch {pos :: Pos, expectedType :: [MatchRule], actualType :: Type ()}
    | UnsupportedStatement { pos :: Pos }
      -- The type cannot be handled by the requested operation (e.g. an operand
      -- for which 'typeToInt' has no rank).
    | UnsupportedType { pos :: Pos, actualType :: Type () }
      -- The target of a declaration or assignment has a form the checker does not handle.
    | UnsupportedLeftSide { pos :: Pos }
      -- Raised when an assignment would copy a value of qubit type.
    | ViolateNonCloningTheorem { pos :: Pos }
    | GateNameError { pos :: Pos }
      -- A call or template instantiation supplied 'actualArgs' arguments where 'expectedArgs' were required.
    | ArgNumberMismatch { pos :: Pos, expectedArgs :: Int, actualArgs :: Int }
    | BadProcedureArgType { pos :: Pos, arg :: (Type (), Ident)}
    | BadProcedureReturnType { pos :: Pos, ret :: (Type (), Ident)}
    | BadGateSignature { pos :: Pos }
    | BadPermutationShape { pos :: Pos }
    | BadPermutationValue { pos :: Pos }
    | ICETypeCheckError
    | MainUndefined
      -- Raised by 'defineGlobalSym' when a global named "main" has neither of the accepted function types.
    | BadMainSignature { actualMainSignature :: Type () }
    deriving (Eq, Show)
-- | One scope level: maps each symbol to the definitions recorded for it.
type SymbolTableLayer = MultiMap.MultiMap Symbol DefinedSymbol
-- | Stack of scope levels; the head of the list is the innermost scope.
type SymbolTable = [SymbolTableLayer]

-- | Look a symbol up from the innermost scope outwards and return every
-- definition held by the first scope that knows it (empty if none does).
querySymbol :: Symbol -> SymbolTable -> [DefinedSymbol]
querySymbol sym = foldr firstHit []
  where
    -- Keep this layer's definitions if it has any, otherwise defer to the outer layers.
    firstHit layer outer = case MultiMap.lookup sym layer of
        [] -> outer
        hits -> hits

-- | Add a definition to the innermost scope, creating that scope when the
-- table is empty. Fails with 'RedefinedSymbol' if the innermost scope already
-- defines the symbol; outer scopes are not consulted.
insertSymbol :: Symbol->DefinedSymbol->SymbolTable->Either TypeCheckError SymbolTable
insertSymbol sym ast table = case table of
    [] -> insertSymbol sym ast [MultiMap.empty]
    (layer:outer) -> case MultiMap.lookup sym layer of
        (previous:_) -> Left $ RedefinedSymbol (definedPos ast) sym (definedPos previous)
        [] -> Right $ MultiMap.insert sym ast layer : outer


-- | Mutable state threaded through the type checker.
data TypeCheckEnv = TypeCheckEnv {
    -- Scope stack, innermost scope first.
    symbolTable :: SymbolTable,
    -- Next id to be handed out by 'nextId'.
    ssaAllocator :: Int,
    -- Set to True once a global named "main" has been defined.
    mainDefined :: Bool,
    -- When True the checker applies its oracle-specific rules (boolean-only
    -- operands, value rather than reference results for subscripts).
    inOracle :: Bool,

    -- Store the levels of lambda expressions.
    -- When positive, 'getSym' only searches that many innermost scopes.
    inLambda :: Int
}

-- | The type-checking monad: 'TypeCheckEnv' state with 'TypeCheckError' failures.
type TypeCheck = ExceptT TypeCheckError (State TypeCheckEnv)

-- | A symbol-table entry.
data DefinedSymbol = DefinedSymbol{
    -- Position of the definition.
    definedPos :: Pos,
    -- Type recorded for the symbol.
    definedType :: EType,
    -- Id associated with the symbol.
    definedSSA :: Int,
    -- True for entries created by 'defineGlobalSym'; identifiers resolving to
    -- such entries become global-name expressions.
    isGlobal :: Bool,
    -- Set from the last argument of 'defineGlobalSym', False for locals.
    -- NOTE(review): its meaning is not visible in this part of the file.
    isDerive :: Bool,
    -- Prefix plus name for globals, the empty string for locals.
    qualifiedName :: String
} deriving (Show)

-- | Insert a definition into the current innermost scope of the checker
-- state, failing with 'RedefinedSymbol' if that scope already defines it.
addSym :: Symbol->DefinedSymbol->TypeCheck ()
addSym k v = do
    extended <- gets symbolTable >>= liftEither . insertSymbol k v
    modify' (\env -> env{symbolTable = extended})

-- | Resolve a symbol used at the given position.
-- While inside lambda expressions only the innermost 'inLambda' scopes are
-- searched. Fails with 'UndefinedSymbol' when nothing is found and with
-- 'AmbiguousSymbol' when the defining scope holds several definitions.
getSym :: Pos->Symbol->TypeCheck DefinedSymbol
getSym pos k = do
    layers <- gets symbolTable
    depth <- gets inLambda
    let visible
            | depth <= 0 = layers
            | otherwise = take depth layers
    case querySymbol k visible of
        [] -> throwError $ UndefinedSymbol pos k
        [only] -> return only
        (one:two:_) -> throwError $ AmbiguousSymbol pos k (definedPos one) (definedPos two)

-- | Define a local (non-global) symbol of the given type in the innermost
-- scope and return the fresh id allocated for it.
defineSym :: Symbol->Pos->EType->TypeCheck Int
defineSym a b c = do
    ssa <- nextId
    ssa <$ addSym a (DefinedSymbol b c ssa False False "")

-- | Define a global symbol named @name@ whose qualified name is
-- @prefix ++ name@, and return the fresh id allocated for it.
-- A global called "main" must have one of the two accepted function types,
-- otherwise 'BadMainSignature' is raised; defining it sets 'mainDefined'.
-- The @logic@ argument is not used by this function.
defineGlobalSym :: String -> String -> Pos -> EType -> Bool -> Bool -> TypeCheck Int
defineGlobalSym prefix name b c logic d = do
    ssa <- nextId
    let isMain = name == "main"
        allowedMainTypes =
            [ Type () FuncTy [Type () Unit []]
            , Type () FuncTy [Type () Unit [], Type () (Array 0) [intType ()], Type () (Array 0) [doubleType ()]]
            ]
    when (isMain && c `notElem` allowedMainTypes) $ throwError $ BadMainSignature c
    when isMain $ modify' (\env -> env{mainDefined = True})
    addSym (SymVar name) (DefinedSymbol b c ssa True d (prefix ++ name))
    return ssa

-- | Rebind a symbol in the innermost scope, replacing any definitions that
-- scope already holds for it. The new entry is a local one whose type and id
-- are taken from the given annotation; that id is returned.
-- Unlike 'addSym' this never reports a redefinition.
setSym :: Symbol -> Pos -> TypeCheckData -> TypeCheck Int
setSym sym pos (TypeCheckData _ ty rid) = do
    let new_data = DefinedSymbol pos ty rid False False ""
        rebind layer = MultiMap.insert sym new_data (MultiMap.delete sym layer)
        -- An empty table gets a fresh layer, as 'insertSymbol' does, instead
        -- of crashing on head/tail of an empty list.
        update [] = [rebind MultiMap.empty]
        update (cur:outer) = rebind cur : outer
    modify' (\env -> env{symbolTable = update (symbolTable env)})
    return rid

-- | Enter a new, empty innermost scope.
scope :: TypeCheck ()
scope = modify pushLayer
  where
    pushLayer env = env{symbolTable = MultiMap.empty : symbolTable env}

-- | Enter a new scope and increase the lambda nesting level by one.
lambdaScope :: TypeCheck ()
lambdaScope = do
    scope
    modify (\env -> env{inLambda = inLambda env + 1})

-- | Leave the innermost scope and return the layer that was removed.
-- The scope stack must not be empty.
unscope :: TypeCheck SymbolTableLayer
unscope = do
    layers <- gets symbolTable
    modify' (\env -> env{symbolTable = tail layers})
    return (head layers)

-- | Decrease the lambda nesting level by one, then leave the innermost scope
-- and return the removed layer.
lambdaUnscope :: TypeCheck SymbolTableLayer
lambdaUnscope = do
    modify (\env -> env{inLambda = inLambda env - 1})
    unscope

-- | Type recorded in the annotation of a type-checked node.
astType :: (Annotated p)=>p TypeCheckData->EType
astType node = termType (annotation node)

-- | Return the current value of the id counter and advance it by one.
nextId :: TypeCheck Int
nextId = do
    current <- gets ssaAllocator
    current <$ modify' (\env -> env{ssaAllocator = current + 1})

-- | Rank of a scalar type in the implicit-conversion order:
-- complex = 0, double = 1, int = 2, bool = 3. References take the rank of the
-- type they refer to; every other type yields -1.
typeToInt :: Type () -> Int
typeToInt t = case t of
    Type () Complex [] -> 0
    Type () Double [] -> 1
    Type () Int [] -> 2
    Type () Bool [] -> 3
    Type () Ref [inner] -> typeToInt inner
    _ -> -1

-- | Inverse of 'typeToInt' for the four scalar ranks 0..3.
-- Any other rank is a programming error.
intToType :: Int -> Type ()
intToType level = case level of
    0 -> complexType ()
    1 -> doubleType ()
    2 -> intType ()
    3 -> boolType ()
    _ -> error "Unreachable."

-- Translate an integer stored in a Gate type into the type it stands for:
-- -1 denotes a reference to a single qubit, any other value v an array of v qubits.
intToQbit ann v
    | v == -1 = refType ann $ qbitType ann
    | otherwise = Type ann (Array v) [qbitType ann]

-- | An expression annotated with type-check results.
type TCExpr = Expr TypeCheckData

-- | Pattern a type can be tested against with 'checkRule':
--
--   * Exact t: equal to t
--   * AnyUnknownList: an array whose length field is 0
--   * AnyKnownList n: an array whose length field is n
--   * AnyList: any array
--   * AnyFunc / AnyGate / AnyRef / AnyTempl: any function, gate, reference or template type
--   * ArrayType r: any array whose element type satisfies r
--   * FixedArray r: like ArrayType r, but the length field must not be 0
data MatchRule = Exact EType | AnyUnknownList | AnyKnownList Int | AnyList | AnyFunc | AnyGate | AnyRef | AnyTempl
    | ArrayType MatchRule | FixedArray MatchRule
    deriving (Show, Eq)

-- | Test whether a type satisfies a rule, without applying any conversion.
checkRule :: MatchRule->EType->Bool
checkRule rule ty = case (rule, ty) of
    (Exact want, _) -> want == ty
    (AnyUnknownList, Type () (Array 0) [_]) -> True
    (AnyKnownList n, Type () (Array m) [_]) -> n == m
    (AnyList, Type () (Array _) [_]) -> True
    (AnyFunc, Type () FuncTy _) -> True
    (AnyGate, Type () (Gate _) _) -> True
    (AnyRef, Type () Ref [_]) -> True
    (AnyTempl, Type () (Templ _) _) -> True
    (ArrayType inner, Type () (Array _) [elt]) -> checkRule inner elt
    -- A fixed array is any array whose length field is not 0.
    (FixedArray inner, Type () (Array len) [elt]) -> len /= 0 && checkRule inner elt
    _ -> False

-- | Try to make an expression satisfy one of the wanted rules, wrapping it in
-- implicit conversions (dereference, numeric casts, list casts) as needed.
-- Returns the possibly wrapped expression, or 'Nothing' when no conversion
-- chain leads to a match.
-- NOTE(review): the array and gate branches inspect only @head wanted@, which
-- fails on an empty rule list; callers are assumed to pass at least one rule.
matchType' :: [MatchRule]->TCExpr->TypeCheck (Maybe TCExpr)
matchType' wanted e = do
    let current_type = astType e
    let pos = sourcePos $ annotation e
    -- Already satisfies one of the rules: return the expression unchanged.
    if any (`checkRule` current_type) wanted then return $ Just e
    else
        case current_type of
            -- Auto dereference rule: wrap in EDeref and retry with the referenced type.
            Type () Ref [x] -> do
                id<-nextId
                matchType' wanted (EDeref (TypeCheckData pos x id) e)
            -- Bool-to-int implicit cast, then retry.
            Type () Bool [] -> do
                id<-nextId
                matchType' wanted (EImplicitCast (TypeCheckData pos (Type () Int [] ) id) e)
            -- Int converts to bool only when bool is the single wanted rule;
            -- otherwise it is widened to double and matching is retried.
            Type () Int [] -> do
                id<-nextId
                case wanted of
                    [Exact (Type () Bool [])] -> return $ Just $ EImplicitCast (TypeCheckData pos (boolType ()) id) e
                    _ -> matchType' wanted (EImplicitCast (TypeCheckData pos (Type () Double [] ) id) e)
            -- Double converts to int or complex only when that is the single wanted rule.
            Type () Double [] -> do
                id<-nextId
                case wanted of
                    [Exact (Type () Int [])] -> return $ Just $ EImplicitCast (TypeCheckData pos (intType ()) id) e
                    [Exact (Type () Complex [])] -> return $ Just $ EImplicitCast (TypeCheckData pos (complexType ()) id) e
                    _ -> return Nothing
            -- Auto cast of an array with length field 0. Only the first rule is considered.
            -- NOTE(review): the `y` in the patterns below is a fresh binding that
            -- shadows the element type matched above, so element types are not
            -- compared here; the cast adopts the wanted type as is. Confirm this is intended.
            Type () (Array 0) [y] -> case head wanted of
                    -- A negative wanted length accepts the expression without a cast.
                    Exact (Type () (Array neg) [y]) | neg < 0 -> return $ Just e
                    Exact (Type () (Array x) [y]) -> do
                        id <- nextId
                        matchType' wanted (EListCast (TypeCheckData pos (Type () (Array x) [y]) id) e)
                    _ -> return Nothing
            -- Auto list erasure: cast any other array to length field 0 and retry
            -- (which lands in the branch above). The id is allocated before the
            -- check, so it is consumed even when no cast is built.
            Type () (Array x) [y] -> do
                id<-nextId
                case head wanted of
                    Exact (Type () (Array neg) [y]) | neg < 0 -> return $ Just e
                    _ -> matchType' wanted (EListCast (TypeCheckData pos (Type () (Array 0) [y]) id) e)
            -- Gates match when the sub-types are equal, the entry lists have the
            -- same length, and each entry equals the wanted one or the wanted one is below -1.
            Type () (Gate x) subx -> do
                case head wanted of
                    Exact (Type () (Gate y) suby) ->
                        if (subx == suby && length x == length y) && and (zipWith (\a b -> a == b || b < -1) x y)
                        then return $ Just e else return Nothing
                    _ -> return Nothing
            _ -> return Nothing
-- | Like 'matchType'', but a failed match is reported as a 'TypeMismatch'
-- carrying the position and type of the original expression.
matchType :: [MatchRule]->TCExpr->TypeCheck TCExpr
matchType wanted e =
    matchType' wanted e >>= maybe mismatch return
  where
    mismatch :: TypeCheck TCExpr
    mismatch = throwError $ TypeMismatch (sourcePos $ annotation e) wanted (astType e)

-- | Check a binary operator whose operands and result all have the one given
-- type: both operands are checked with @f@ and then coerced to that type.
exactBinaryCheck :: (Expr Pos->TypeCheck (Expr TypeCheckData)) -> EType -> Pos -> BinaryOperator -> LExpr -> LExpr -> TypeCheck (Expr TypeCheckData)
exactBinaryCheck f etype pos op lhs rhs = do
    typed_l <- f lhs
    typed_r <- f rhs
    let expect = matchType [Exact etype]
    checked_l <- expect typed_l
    checked_r <- expect typed_r
    (\ssa -> EBinary (TypeCheckData pos etype ssa) op checked_l checked_r) <$> nextId

-- | Coerce two operands to a common scalar type: the one with the lowest
-- 'typeToInt' rank among the left operand, the right operand and
-- @upper_bound@. An operand without a rank raises 'UnsupportedType'
-- (the left operand is examined first).
getCommonType :: Expr TypeCheckData -> Expr TypeCheckData -> Int -> TypeCheck (Expr TypeCheckData, Expr TypeCheckData)
getCommonType lhs rhs upper_bound = do
    li <- rank lhs
    ri <- rank rhs
    let common = intToType $ minimum [li, ri, upper_bound]
    matched_lhs <- matchType [Exact common] lhs
    matched_rhs <- matchType [Exact common] rhs
    return (matched_lhs, matched_rhs)
  where
    -- Rank of an operand, or an error naming its position and type.
    rank :: Expr TypeCheckData -> TypeCheck Int
    rank operand
        | level < 0 = throwError $ UnsupportedType (sourcePos ann) (termType ann)
        | otherwise = return level
      where
        ann = annotationExpr operand
        level = typeToInt $ astType operand

-- | Build an addition/subtraction node. When the left operand is a ket the
-- right operand must be a ket too; otherwise both are coerced to a common
-- scalar type of rank at most 2. The result has the left operand's type.
buildKetExpr :: Pos -> BinaryOperator -> Expr TypeCheckData -> Expr TypeCheckData -> TypeCheck (Expr TypeCheckData)
buildKetExpr pos op ref_lhs ref_rhs = do
    ssa <- nextId
    operands <- case astType ref_lhs of
        Type () Ket [] -> (ref_lhs,) <$> matchType [Exact $ ketType ()] ref_rhs
        _ -> getCommonType ref_lhs ref_rhs 2
    let result_ty = astType (fst operands)
    return $ uncurry (EBinary (TypeCheckData pos result_ty ssa) op) operands

-- | Build a generic binary node from already checked operands.
-- In oracle mode the left operand must be a bool or an array of bools and the
-- right operand is coerced to the same type. Otherwise both operands are
-- coerced to a common scalar type; comparisons yield bool, and the modulo
-- operator additionally requires the common type to be int.
buildBinaryExpr :: Pos -> BinaryOperator -> Expr TypeCheckData -> Expr TypeCheckData -> TypeCheck (Expr TypeCheckData)
buildBinaryExpr pos op ref_lhs ref_rhs = do
    ssa <- nextId
    logic <- gets inOracle
    if logic then oracleBinary ssa else classicalBinary ssa
  where
    oracleBinary :: Int -> TypeCheck (Expr TypeCheckData)
    oracleBinary ssa = case astType ref_lhs of
        ty@(Type () Bool []) -> sameTyped ssa ty
        ty@(Type () (Array _) [Type () Bool []]) -> sameTyped ssa ty
        other -> throwError $ UnsupportedType pos other

    -- The left operand is kept as is; the right one must match its type.
    sameTyped :: Int -> EType -> TypeCheck (Expr TypeCheckData)
    sameTyped ssa ty = do
        rhs <- matchType [Exact ty] ref_rhs
        return $ EBinary (TypeCheckData pos ty ssa) op ref_lhs rhs

    classicalBinary :: Int -> TypeCheck (Expr TypeCheckData)
    classicalBinary ssa = do
        (matched_lhs, matched_rhs) <- getCommonType ref_lhs ref_rhs 3
        let return_type = case op of
                Cmp _ -> boolType ()
                _ -> astType matched_lhs
            node = EBinary (TypeCheckData pos return_type ssa) op matched_lhs matched_rhs
        case op of
            Mod | return_type /= intType () -> throwError $ TypeMismatch pos [Exact (intType ())] return_type
            _ -> return node

-- Open-recursive expression checker: the first argument is the function used
-- for sub-expressions. Leaf nodes carry their own type; the types of inner
-- nodes are computed from their children.
typeCheckExpr' :: (Expr Pos->TypeCheck (Expr TypeCheckData))->Expr Pos->TypeCheck (Expr TypeCheckData)
-- Lambda binding: the bound name is visible only while checking the body and
-- refers to the annotation of the initializer; the node takes the body's type.
typeCheckExpr' f (ELambda pos binder@(_, ident) init eval) = do
    bound <- f init
    scope
    _ <- setSym (SymVar ident) pos (annotationExpr bound)
    body <- f eval
    _ <- unscope
    ssa <- nextId
    return $ ELambda (TypeCheckData pos (astType body) ssa) binder bound body
-- Recursive lambda: the function name is defined (as a global-style symbol)
-- in an outer scope that covers both its own body and the continuation; the
-- arguments live in an inner scope that covers only the body.
typeCheckExpr' f (ELambdaRec pos (ty, ident) args init eval) = do
    scope
    let func_ty = Type () FuncTy (ty : map (void . fst) args)
    _ <- defineGlobalSym "" ident pos func_ty False False
    scope
    let declare (arg_ty, arg_name) =
            (void arg_ty,) <$> defineSym (SymVar arg_name) pos (void arg_ty)
    args' <- mapM declare args
    body <- f init
    _ <- unscope
    rest <- f eval
    _ <- unscope
    ssa <- nextId
    return $ EResolvedRec (TypeCheckData pos ty ssa) (ty, ident) args' body rest
-- Conditional expression: the condition must be a bool and the node takes the
-- type of the then-branch.
typeCheckExpr' f (EIf pos cond ifexp elseexp) = do
    test <- f cond >>= matchType [Exact $ boolType ()]
    on_true <- f ifexp
    let ty = astType on_true
    -- The else-branch is coerced to the type of the then-branch.
    on_false <- f elseexp >>= matchType [Exact ty]
    ssa <- nextId
    return $ EIf (TypeCheckData pos ty ssa) test on_true on_false
-- Identifier: globals become a global-name node with a fresh id, locals a
-- resolved identifier that reuses the id stored with the symbol.
typeCheckExpr' f (EIdent pos ident) = do
    sym <- getSym pos (SymVar ident)
    let ty = definedType sym
    if isGlobal sym
        then (\ssa -> EGlobalName (TypeCheckData pos ty ssa) (qualifiedName sym)) <$> nextId
        else do
            let sym_ssa = definedSSA sym
            return $ EResolvedIdent (TypeCheckData pos ty sym_ssa) sym_ssa
-- Binary operators. Logical and shift operators require exactly bool / int
-- operands; addition and subtraction also accept kets; everything else goes
-- through the generic rules of buildBinaryExpr.
typeCheckExpr' f (EBinary pos op lhs rhs) = case op of
    And -> exactBinaryCheck f (boolType ()) pos op lhs rhs
    Or -> exactBinaryCheck f (boolType ()) pos op lhs rhs
    Shl -> exactBinaryCheck f (intType ()) pos op lhs rhs
    Shr -> exactBinaryCheck f (intType ()) pos op lhs rhs
    Add -> checkOperands >>= uncurry (buildKetExpr pos op)
    Sub -> checkOperands >>= uncurry (buildKetExpr pos op)
    _ -> checkOperands >>= uncurry (buildBinaryExpr pos op)
  where
    -- Check the left operand first, then the right one.
    checkOperands = (,) <$> f lhs <*> f rhs
-- Unary operators: Not needs a bool, Noti an array of bools, every other
-- operator an int, double or complex. The node takes the operand's matched type.
typeCheckExpr' f (EUnary pos op lhs) = do
    operand <- f lhs
    let accepted = case op of
            Not -> [Exact $ boolType ()]
            Noti -> [ArrayType $ Exact $ boolType ()]
            _ -> map Exact [intType (), doubleType (), complexType ()]
    matched <- matchType accepted operand
    ssa <- nextId
    return $ EUnary (TypeCheckData pos (astType matched) ssa) op matched
-- Ket term: the coefficient is coerced to complex; the base is passed through unchanged.
typeCheckExpr' f (EKet pos coeff base) = do
    amplitude <- f coeff >>= matchType [Exact $ complexType ()]
    ssa <- nextId
    return $ EKet (TypeCheckData pos (ketType ()) ssa) amplitude base
-- Vector literal: typed as a ket, contents passed through unchanged.
typeCheckExpr' f (EVector pos vec) =
    (\ssa -> EVector (TypeCheckData pos (ketType ()) ssa) vec) <$> nextId
-- Slice of an array by a range. The result is an array with the same element
-- type; its length is recorded in the type when it can be computed from
-- integer literals, otherwise the length field is 0.
typeCheckExpr' f (ESubscript pos base (ERange epos lo hi step)) = do
    base' <- f base
    base'' <- matchType [AnyList] base'
    ssa <- nextId
    let sub_ty = head $ subTypes $ astType base''
    -- Missing bounds default to: lo = 0, step = 1, hi = length of the array
    -- (a literal when the type records a positive length, otherwise an
    -- array-length expression over the unchecked base).
    let lo' = case lo of
            Just exp -> exp
            Nothing -> EIntLit epos 0
    let step' = case step of
            Just exp -> exp
            Nothing -> EIntLit epos 1
    let hi' = case hi of
            Just exp -> exp
            -- Only array types reach this point because base'' matched AnyList.
            Nothing -> case astType base'' of
                    Type () (Array x) [_] | x > 0 -> EIntLit epos x
                    Type () (Array _) [_] -> EArrayLen epos base
    -- Number of selected elements when lo, hi and step are all integer
    -- literals, -1 otherwise.
    -- NOTE(review): a literal step of 0 would make `div` fail here.
    let getSize = case lo' of
                    EIntLit _ lo -> case step' of
                            EIntLit _ step -> case hi' of
                                    -- size = ceil[(hi - lo) / step]
                                    EIntLit _ hi -> -((lo - hi) `div` step)
                                    _ -> -1
                            _ -> -1
                    _ -> -1
    case getSize of
        -- Size not known statically (or not positive): compute it with an
        -- expression. Note that the range handed to `f` carries the element
        -- count, not the upper bound, in its second position.
        n | n <= 0 -> do
            -- size = ceil[(hi - lo) / step]
            let size = EBinary epos CeilDiv (EBinary epos Sub hi' lo') step'
            range <- f $ ERange epos (Just lo') (Just size) (Just step')
            let ty = Type () (Array 0) [sub_ty]
            return $ ESubscript (TypeCheckData pos ty ssa) base'' range
        -- Statically known positive size: record it in the result type.
        size -> do
            range <- f $ ERange epos (Just lo') (Just $ EIntLit epos size) (Just step')
            let ty = Type () (Array size) [sub_ty]
            return $ ESubscript (TypeCheckData pos ty ssa) base'' range
-- Element access: the base must be an array and the offset an int. The result
-- is a reference to the element type, except in oracle mode or for Param
-- elements, where it is the element type itself.
typeCheckExpr' f (ESubscript pos base offset) = do
    array <- f base >>= matchType [AnyList]
    index <- f offset >>= matchType [Exact $ intType ()]
    ssa <- nextId
    in_oracle <- gets inOracle
    let elem_ty = head $ subTypes $ astType array
        by_value = in_oracle || elem_ty == Type () Param []
        ty = if by_value then elem_ty else refType () elem_ty
    return $ ESubscript (TypeCheckData pos ty ssa) array index
-- Gate modifiers: the callee must be a gate; the resulting gate type gains
-- one leading -1 entry for every qubit the modifiers add.
typeCheckExpr' f (EModify pos mods callee) = do
    callee' <- f callee
    callee'' <- matchType [AnyGate] callee'
    let adds = sum $ map addedQubits mods
    -- Partial: defined for gate types only.
    let extendGate (Type _ (Gate x) sub) = Type () (Gate $ replicate adds (-1) ++ x) sub
    -- NOTE(review): the template branch looks unreachable, since callee'' has
    -- already matched AnyGate, which 'checkRule' accepts for gate types only.
    let ty = case astType callee'' of
            g@(Type _ (Gate x) sub) -> extendGate g
            Type _ (Templ n) sub -> do
                let newg = extendGate $ sub !! n
                Type () (Templ n) $ take n sub ++ [newg]
    --traceM $ show ty
    ssa <- nextId
    -- NOTE(review): the node keeps callee' rather than the matched callee'',
    -- so any conversion inserted by matchType is dropped; confirm this is intended.
    return $ EModify (TypeCheckData pos ty ssa) mods callee'
-- Template instantiation: the callee must have a template type with n
-- parameters. Exactly n arguments are required, each coerced to the
-- corresponding parameter type; the node takes the type stored after the
-- n parameter types.
typeCheckExpr' f (ETemplate pos callee targs) = do
    callee' <- f callee
    callee'' <- matchType [AnyTempl] callee'
    targs' <- mapM f targs
    let Type _ (Templ n) sub_ty = astType callee''
    -- ArgNumberMismatch takes the expected count first, then the actual one.
    when (length targs /= n) $ throwError $ ArgNumberMismatch pos n (length targs)
    targs'' <- zipWithM (\a->matchType [Exact a]) (take n sub_ty) targs'
    let ty = sub_ty !! n
    ssa <- nextId
    return $ ETemplate (TypeCheckData pos ty ssa) callee'' targs''
-- Function call: the callee must have a function type; the argument count
-- must equal the parameter count and each argument is coerced to its
-- parameter type. The node takes the function's return type.
typeCheckExpr' f (ECall pos callee callArgs) = do
    fn <- f callee >>= matchType [AnyFunc]
    given <- mapM f callArgs
    let Type _ FuncTy (ret: args) = astType fn
        expected = length args
        ncall = length callArgs
    when (expected /= ncall) $ throwError $ ArgNumberMismatch pos expected ncall
    coerced <- zipWithM (\want actual -> matchType [Exact want] actual) args given
    ssa <- nextId
    return $ ECall (TypeCheckData pos ret ssa) fn coerced
-- Literals: each gets a fresh id and the type of its kind.
typeCheckExpr' _ (EIntLit pos x) =
    (\ssa -> EIntLit (TypeCheckData pos (intType ()) ssa) x) <$> nextId
typeCheckExpr' _ (EFloatingLit pos x) =
    (\ssa -> EFloatingLit (TypeCheckData pos (doubleType ()) ssa) x) <$> nextId
typeCheckExpr' _ (EImagLit pos x) =
    (\ssa -> EImagLit (TypeCheckData pos (complexType ()) ssa) x) <$> nextId
typeCheckExpr' _ (EBoolLit pos x) =
    (\ssa -> EBoolLit (TypeCheckData pos (boolType ()) ssa) x) <$> nextId
-- Cast: the operand is coerced to the target type with the implicit
-- conversion rules; no separate cast node is produced.
typeCheckExpr' f (ECast pos exp ty) = f exp >>= matchType [Exact $ void ty]
-- Range without a step: check it again with a step of 1.
typeCheckExpr' f (ERange pos lo hi Nothing) = f $ ERange pos lo hi (Just (EIntLit pos 1))
-- Range: every bound that is present must be an int.
typeCheckExpr' f (ERange pos lo hi step) = do
    let asInt bound = f bound >>= matchType [Exact (Type () Int [])]
    lo' <- traverse asInt lo
    hi' <- traverse asInt hi
    step' <- traverse asInt step
    ssa <- nextId
    return $ ERange (TypeCheckData pos (Type () IntRange []) ssa) lo' hi' step'
-- Measurement: a single qubit reference yields a bool; otherwise the operand
-- must be a qubit array, and the node becomes a call to `__measure_bundle`
-- typed as int.
typeCheckExpr' f (ECoreMeasure pos qubit) = do
    operand <- f qubit
    ssa <- nextId
    matchType' [Exact (refType () (qbitType ()))] operand >>= \case
        Just single -> return $ ECoreMeasure (TypeCheckData pos (boolType ()) ssa) single
        Nothing -> do
            bundle <- matchType [Exact $ Type () (Array 0) [qbitType ()]] operand
            fun <- f (EIdent pos "__measure_bundle")
            return $ ECall (TypeCheckData pos (intType ()) ssa) fun [bundle]
-- List literal. The result is an array whose length is the number of elements.
-- NOTE(review): `head` and `minimum` below fail on an empty list literal.
typeCheckExpr' f (EList pos lis) = do
    lis' <- mapM f lis
    -- The type of the first element selects the mode.
    (lis'', ty) <- case astType $ head lis' of
            -- Qubit references: every element is coerced to a qubit.
            Type () Ref [Type () Qbit []] -> do
                lis'' <- mapM (matchType [Exact $ qbitType ()]) lis'
                return (lis'', Type () (Array $ length lis) [qbitType ()])
            -- Scalars: the element type is the one with the lowest rank among
            -- the elements. The elements themselves are returned unconverted.
            _ -> do
                let levels = map (typeToInt . termType . annotationExpr) lis'
                let (min_level, min_idx) = minimum $ zip levels [0..]
                -- A negative minimum means some element has no rank; report
                -- the first such element.
                when (min_level < 0) $ throwError $ do
                    let ann = annotationExpr $ lis' !! min_idx
                    UnsupportedType (sourcePos ann) (termType ann)
                let ele_type = intToType min_level
                return (lis', Type () (Array $ length lis) [ele_type])
    ssa <- nextId
    return $ EList (TypeCheckData pos ty ssa) lis''
-- Nodes that only the type checker itself creates must not appear in its input.
typeCheckExpr' _ EDeref{} = error "Unreachable."
typeCheckExpr' _ EImplicitCast{} = error "Unreachable."
-- Temporary variables and arguments resolve to the id stored with their symbol.
typeCheckExpr' _ (ETempVar pos ident) =
    resolved <$> getSym pos (SymTempVar ident)
  where
    resolved sym = EResolvedIdent (TypeCheckData pos (definedType sym) (definedSSA sym)) (definedSSA sym)
typeCheckExpr' _ (ETempArg pos ident) =
    resolved <$> getSym pos (SymTempArg ident)
  where
    resolved sym = EResolvedIdent (TypeCheckData pos (definedType sym) (definedSSA sym)) (definedSSA sym)
typeCheckExpr' _ (EUnitLit pos) = do
    ssa <- nextId
    return $ EUnitLit (TypeCheckData pos (unitType ()) ssa)
typeCheckExpr' _ EResolvedIdent{} = error "Unreachable."
typeCheckExpr' _ EGlobalName{} = error "Unreachable."
typeCheckExpr' _ EListCast{} = error "Unreachable."
-- Array length: folded to an integer literal when the array type records a
-- positive length, kept as a length expression otherwise. Non-array operands
-- raise a TypeMismatch.
typeCheckExpr' f (EArrayLen pos array) = do
    array' <- f array
    ssa <- nextId
    let int_ann = TypeCheckData pos (intType ()) ssa
    case termType $ annotationExpr array' of
        Type () (Array known) [_] | known > 0 -> return $ EIntLit int_ann known
        Type () (Array _) [_] -> return $ EArrayLen int_ann array'
        other -> throwError $ TypeMismatch pos [AnyList] other

-- | Type-check an expression, using this same function for all sub-expressions.
typeCheckExpr :: Expr Pos -> TypeCheck (Expr TypeCheckData)
typeCheckExpr expr = typeCheckExpr' typeCheckExpr expr

-- | Annotation for a checked statement: unit type and the placeholder id -1.
okStmt :: Pos->TypeCheckData
okStmt pos = TypeCheckData{sourcePos = pos, termType = unitType (), termId = -1}


-- Type under which a declared variable is stored: a reference to the declared
-- type. Arrays and Param types are the exception and are stored as they are
-- (lists are passed by value and are therefore right-values).
definedRefType :: EType->EType
definedRefType x = case x of
    Type () (Array _) _ -> x
    Type () Param _ -> x
    _ -> Type () Ref [x]

-- Open-recursive statement checker: the first argument is the function used
-- for nested statements.
typeCheckAST' :: (AST Pos->TypeCheck (AST TypeCheckData))->AST Pos->TypeCheck (AST TypeCheckData)
-- Block: its statements are checked inside a fresh scope.
typeCheckAST' f (NBlock pos lis) = do
    scope
    body <- traverse f lis
    _ <- unscope
    return $ NBlock (okStmt pos) body
-- If statement: the condition must be a bool; each branch is checked in its
-- own fresh scope, the then-branch first.
typeCheckAST' f (NIf pos cond bthen belse) = do
    test <- typeCheckExpr cond >>= matchType [Exact (boolType ())]
    let inOwnScope branch = do
            scope
            checked <- mapM f branch
            _ <- unscope
            return checked
    on_true <- inOwnScope bthen
    on_false <- inOwnScope belse
    return $ NIf (okStmt pos) test on_true on_false
-- For loop: the range is checked before the loop variable (an int) is
-- defined, so the variable is visible in the body only.
typeCheckAST' f (NFor pos v r b) = do
    scope
    raw_range <- typeCheckExpr r
    loop_var <- defineSym (SymVar v) pos (intType ())
    range <- matchType [Exact (Type () IntRange [])] raw_range
    body <- mapM f b
    _ <- unscope
    return $ NResolvedFor (okStmt pos) loop_var range body
-- Statements without sub-terms only get a statement annotation.
typeCheckAST' _ (NEmpty pos) = pure $ NEmpty (okStmt pos)
typeCheckAST' _ (NPass pos) = pure $ NPass (okStmt pos)
-- Plain assertion: the asserted expression must be a bool.
typeCheckAST' _ (NAssert pos exp Nothing) =
    (\checked -> NAssert (okStmt pos) checked Nothing)
        <$> (typeCheckExpr exp >>= matchType [Exact (boolType ())])
-- Assertions carrying a third component are not expected to reach the checker.
typeCheckAST' _ (NAssert _ _ _) = error "unreachable"
-- Resolved assertion: the subject must be an array of qubits; the space is
-- passed through unchanged.
typeCheckAST' _ (NResolvedAssert pos q space) = do
    qubits <- typeCheckExpr q >>= matchType [ArrayType $ Exact (qbitType ())]
    return $ NResolvedAssert (okStmt pos) qubits space
-- Span assertion: the subject must be an array of qubits; only the first
-- group of vectors is checked and kept.
typeCheckAST' _ (NAssertSpan pos q vecs) = do
    qubits <- typeCheckExpr q >>= matchType [ArrayType $ Exact (qbitType ())]
    first_group <- traverse typeCheckExpr (head vecs)
    return $ NAssertSpan (okStmt pos) qubits [first_group]
-- Breakpoint: annotated with unit type and a fresh id.
typeCheckAST' _ (NBp pos) = NBp . TypeCheckData pos (unitType ()) <$> nextId
-- While loops are not expected to reach the checker.
typeCheckAST' _ NWhile{} = error "unreachable"
-- Call statement. Calls whose callee has a gate type are checked here; any
-- other call is checked as an ordinary call expression.
typeCheckAST' f (NCall pos c@(ECall pos2 callee args)) = do
    callee' <- typeCheckExpr callee
    case astType callee' of
        -- Gate x sub: `sub` lists the types of the classical arguments, `x`
        -- holds one entry per qubit argument (see intToQbit).
        Type _ (Gate x) sub -> do
            args' <- mapM typeCheckExpr args
            -- A first argument of Param type is expanded into three arguments:
            --   identifier `p`  -> p, length of its name, -1
            --   subscript `p[s]` -> p (forced to Param type), length of its name, s
            -- NOTE(review): other argument shapes are not covered by this case.
            args'' <- if not (null args) && astType (head args') == Type () Param [] then do
                extra <- case head args of
                        i@(EIdent ann name) -> mapM typeCheckExpr [i, EIntLit ann (length name), EIntLit ann (-1)]
                        ESubscript ann i@(EIdent _ name) s -> do
                            i' <- typeCheckExpr i
                            index <- mapM typeCheckExpr [EIntLit ann (length name), s]
                            -- Due to the current implementation of param, force the type of the first term to be Param
                            return $ i'{annotationExpr=(annotationExpr i'){termType=Type () Param []}} : index
                return $ extra ++ tail args'
                else return args'
            let nsub = length sub
            let require = length x + nsub
            let ncall = length args''
            when (ncall /= require) $ throwError $ ArgNumberMismatch pos require ncall
            -- The leading arguments are the classical ones; zipWithM stops after `sub` is exhausted.
            classic <- zipWithM (\a -> matchType [Exact a]) sub args''
            -- No qubit operands at all.
            if null x then return $ NCall (okStmt pos) $ ECall (okStmt pos2) callee' classic
            -- Every qubit entry is -1 (single qubits): check the first qubit argument for the bundle form.
            else if maximum x == -1 then case astType $ args'' !! nsub of
                Type _ (Array _) [Type _ Qbit []]-> do
                    -- Bundle operation
                    -- Use a for loop to replace original `U(classic, arr1, arr2, ...)`:
                    --   for __i in 0:min(arr1.length, arr2.length, ...):1 {
                    --      U(classic, arr1[__i], arr2[__i], ...);
                    --   }
                    -- The loop is built from the unchecked arguments and then checked through `f`.
                    let iter = "__i"
                    let lo = Just $ EIntLit pos2 0
                    let step = Just $ EIntLit pos2 1
                    let len0 = EArrayLen pos2 $ args !! nsub
                    let hi = foldl (\len arg -> EBinary pos Min len $ EArrayLen pos2 arg) len0 $ drop (nsub + 1) args
                    let range = ERange pos2 lo (Just hi) step
                    let eit = EIdent pos2 iter
                    let qargs' = map (\arg -> ESubscript pos2 arg eit) $ drop nsub args
                    let ncall = NCall pos $ ECall pos2 callee $ take nsub args ++ qargs'
                    f $ NFor pos iter range [ncall]
                _ -> do
                    -- Ordinary call: each qubit argument must match the type its gate entry stands for.
                    qubits <- zipWithM (\a -> matchType [Exact $ intToQbit () a]) x $ drop nsub args''
                    --traceM $ show qubits
                    return $ NCall (okStmt pos) $ ECall (okStmt pos) callee' $ classic ++ qubits
            else do
                -- Some entry is not -1: match every qubit argument against its entry.
                qubits <- zipWithM (\a -> matchType [Exact $ intToQbit () a]) x $ drop nsub args''
                --traceM $ show qubits
                return $ NCall (okStmt pos) $ ECall (okStmt pos) callee' $ classic ++ qubits
        _ -> do
            c' <- typeCheckExpr c
            return $ NCall (okStmt pos) c'
-- Variable declaration. Each declared item is a tuple of declared type, name,
-- optional initializer and optional length expression; the result lists, per
-- item, the stored type, the id of the new symbol and an optional expression.
typeCheckAST' f (NDefvar pos defs) = do
    in_oracle <- gets inOracle
    -- Note that the tuple component named `length` shadows the Prelude function inside def_one.
    let def_one (ty, name, initializer, length) = do
            let left_type = void ty
            -- Oracle mode: only bools and arrays of bools can be declared.
            if in_oracle then do
                let sym = SymVar name
                case initializer of
                    -- With an initializer the symbol is bound through setSym, so it
                    -- takes over the type and id of the checked initializer.
                    Just r -> do
                        r' <- typeCheckExpr r
                        case left_type of
                            Type () (Array 0) [Type () Bool []] -> do
                                r'' <- matchType [ArrayType $ Exact $ boolType ()] r'
                                rid <- setSym sym pos $ annotationExpr r''
                                return (left_type, rid, Just r'')
                            Type () Bool [] -> do
                                r''<-matchType [Exact left_type] r'
                                rid <- setSym sym pos $ annotationExpr r''
                                return (left_type, rid, Just r'')
                            other -> throwError $ UnsupportedType pos other
                    Nothing -> case length of
                        -- No initializer and no length: a plain bool.
                        Nothing -> do
                            case left_type of
                                Type () Bool [] -> do
                                    rid <- defineSym sym pos left_type
                                    return (left_type, rid, Nothing)
                                _ -> throwError $ UnsupportedType pos left_type
                        -- Literal length: a bool array of that length; the checked
                        -- length literal is returned in the expression slot.
                        Just elen@(EIntLit _ len) -> do
                            case left_type of
                                Type () (Array 0) [Type () Bool []] -> do
                                    let ty = Type () (Array len) [Type () Bool []]
                                    rid <- defineSym sym pos ty
                                    elen' <- typeCheckExpr elen
                                    return (ty, rid, Just elen')
                                other -> throwError $ UnsupportedType pos other
                        _ -> throwError $ UnsupportedLeftSide pos
            else do
                (i', ty') <- case initializer of
                    Just r -> do
                        r' <- typeCheckExpr r
                        case left_type of
                            -- Array declared with an initializer: the initializer must have an array type.
                            Type () (Array llen) [lsub] -> do
                                let right_type = termType $ annotationExpr r'
                                case right_type of
                                    Type () (Array rlen) [rsub] -> do
                                        when (typeToInt lsub < 0) $ throwError $ UnsupportedType pos left_type
                                        -- NOTE(review): lazy pattern; fails at run time when the
                                        -- checked initializer is not a list literal node.
                                        let EList ann sub = r'
                                        -- Every element is coerced to the declared element type.
                                        sub' <- mapM (matchType [Exact lsub]) sub
                                        -- A declared length of 0 takes the initializer's length.
                                        let llen' = case llen of
                                                0 -> rlen
                                                _ -> llen
                                        return (Just $ EList ann sub', Type () (Array llen') [lsub])
                                    _ -> throwError $ TypeMismatch pos [Exact left_type] right_type
                            -- Any other type: coerce the initializer and store per definedRefType.
                            _ -> do
                                r''<-matchType [Exact left_type] r'
                                return (Just r'', definedRefType left_type)
                    Nothing -> case length of
                        Nothing -> return (Nothing, definedRefType left_type)
                        -- Literal length: it is recorded in the array type and no expression is kept.
                        Just len@(EIntLit ann v) -> do
                            let array_ty = Type () (Array v) $ subTypes left_type
                            return (Nothing, array_ty)
                        -- Any other length expression: it must be an int and is returned
                        -- in the expression slot; the declared type is kept unchanged.
                        Just len -> do
                            len' <- typeCheckExpr len
                            len'' <- matchType [Exact $ intType ()] len'
                            return (Just len'', left_type)
                s <- defineSym (SymVar name) pos ty'
                return (ty', s, i')
    defs'<-mapM def_one defs
    return $ NResolvedDefvar (okStmt pos) defs'
-- Assignment statement (`=`, `+=`, `-=`).
--
-- The right-hand side is type-checked first, then the `inOracle` flag is
-- read, then the left-hand side is type-checked. Inside an oracle body only
-- boolean targets (an identifier or a subscript) are accepted and the
-- operator is not inspected; outside, the operator selects between plain
-- assignment and the compound forms.
typeCheckAST' f (NAssign pos lhs rhs op) = do
    checkedRhs <- typeCheckExpr rhs
    oracleMode <- gets inOracle
    checkedLhs <- typeCheckExpr lhs
    if oracleMode
        then oracleAssign checkedLhs checkedRhs
        else plainAssign checkedLhs checkedRhs
  where
    -- Array-of-qbit type with length 0, used to match both operands of the
    -- qubit-array compound assignments.
    qbitArray = Type () (Array 0) [qbitType ()]

    -- Assignment inside an oracle: the target must be boolean.
    oracleAssign target value = do
        target' <- matchType [Exact $ boolType ()] target
        case lhs of
            EIdent identPos ident -> do
                let sym = SymVar ident
                declared <- definedType <$> getSym identPos sym
                case declared of
                    Type () Bool [] -> do
                        value' <- matchType [Exact declared] value
                        setSym sym identPos $ annotationExpr value'
                        return $ NAssign (okStmt pos) target' value' AssignEq
                    other -> throwError $ UnsupportedType pos other
            ESubscript {} -> do
                value' <- matchType [Exact $ boolType ()] value
                return $ NAssign (okStmt pos) target' value' AssignEq
            _ -> throwError $ UnsupportedLeftSide $ annotationExpr lhs

    -- Assignment outside an oracle, dispatched on the operator.
    plainAssign target value = case op of
        AssignEq -> case astType target of
            -- A qubit array on the left expects a ket on the right.
            Type () (Array _) [Type () Qbit []] -> do
                value' <- matchType [Exact $ ketType ()] value
                return $ NAssign (okStmt pos) target value' AssignEq
            _ -> storeThroughRef target value
        AddEq -> compound "__add" Add target value
        SubEq -> compound "__sub" Sub target value

    -- Store `value` into a reference-typed target; a qbit payload is rejected
    -- with ViolateNonCloningTheorem.
    storeThroughRef target value = do
        target' <- matchType [AnyRef] target
        let Type () Ref [elemTy] = astType target'
        when (ty elemTy == Qbit) $ throwError $ ViolateNonCloningTheorem pos
        value' <- matchType [Exact elemTy] value
        return $ NAssign (okStmt pos) target' value' AssignEq

    -- Compound assignment. For qubit arrays it becomes a call to the named
    -- builtin with arguments [rhs, lhs]; otherwise it becomes
    -- `lhs = lhs <binop> rhs`, with the left-hand side type-checked a second
    -- time to serve as the assignment target.
    compound builtin binop target value =
        case termType $ annotationExpr target of
            Type () (Array _) [Type () Qbit []] -> do
                target' <- matchType [Exact qbitArray] target
                value' <- matchType [Exact qbitArray] value
                callId <- nextId
                callee <- typeCheckExpr $ EIdent pos builtin
                let call = ECall (TypeCheckData pos (unitType ()) callId) callee [value', target']
                return $ NCall (okStmt pos) call
            _ -> do
                combined <- buildBinaryExpr pos binop target value
                freshLhs <- typeCheckExpr lhs
                storeThroughRef freshLhs combined
-- NGatedef is not expected by the type checker; reaching this clause is an
-- internal error, so the panic names the constructor to ease diagnosis.
typeCheckAST' _ NGatedef{} = error "unreachable: typeCheckAST' received NGatedef"
-- Return statement: type-check the returned expression and re-annotate the node.
typeCheckAST' _ (NReturn pos expr) =
    NReturn (okStmt pos) <$> typeCheckExpr expr
-- Qubit initialisation: the target must match `FixedArray (Exact qbit)`;
-- the `state` payload is passed through unchanged.
typeCheckAST' _ (NResolvedInit pos qubit state) = do
    target <- typeCheckExpr qubit >>= matchType [FixedArray $ Exact $ qbitType ()]
    return $ NResolvedInit (okStmt pos) target state
-- Print statement: the operand must be an int, a double or a complex value.
typeCheckAST' _ (NCorePrint pos val) = do
    let printable = map Exact [intType (), doubleType (), complexType ()]
    operand <- typeCheckExpr val >>= matchType printable
    return $ NCorePrint (okStmt pos) operand
-- Measurement statement: the operand is type-checked but no further type
-- constraint is imposed in this clause.
typeCheckAST' _ (NCoreMeasure pos qubit) =
    NCoreMeasure (okStmt pos) <$> typeCheckExpr qubit
-- These constructors are not expected by the type checker; reaching one is an
-- internal error, so each panic names the constructor to ease diagnosis.
typeCheckAST' _ NContinue{} = error "unreachable: typeCheckAST' received NContinue"
typeCheckAST' _ NBreak{} = error "unreachable: typeCheckAST' received NBreak"
typeCheckAST' _ NResolvedFor{} = error "unreachable: typeCheckAST' received NResolvedFor"
typeCheckAST' _ NResolvedGatedef{} = error "unreachable: typeCheckAST' received NResolvedGatedef"
-- Guarded while loop: the loop condition and the break flag must both be
-- boolean. They are checked in that order, before the body statements.
typeCheckAST' f (NWhileWithGuard pos cond stmts breakFlag) = do
    let asBool e = typeCheckExpr e >>= matchType [Exact (boolType ())]
    loopCond <- asBool cond
    loopBreak <- asBool breakFlag
    loopBody <- mapM f stmts
    return $ NWhileWithGuard (okStmt pos) loopCond loopBody loopBreak
-- A single `case` arm: its statements are checked between a `scope` and the
-- matching `unscope`; `base` and `isket` are passed through unchanged.
typeCheckAST' f (NCase pos base stats isket) = do
    checkedStats <- scope *> mapM f stats <* unscope
    return $ NCase (okStmt pos) base checkedStats isket
-- Switch statement. The scrutinee is either an array of qbits or an int.
-- With an int scrutinee, ket-labelled cases are rejected; with a qbit-array
-- scrutinee, non-ket cases are rejected and every statement in the case and
-- default bodies must be an NCall.
typeCheckAST' f (NSwitch pos cond cases defau) = do
    scrutinee <- typeCheckExpr cond
        >>= matchType [ArrayType $ Exact $ qbitType (), Exact $ intType ()]
    let onInt = astType scrutinee == intType ()
        -- Body statements are unrestricted for an int switch; otherwise only
        -- calls are accepted.
        requireCall :: AST TypeCheckData -> TypeCheck ()
        requireCall stmt
            | onInt = return ()
            | otherwise = case stmt of
                NCall{} -> return ()
                other -> throwError $ UnsupportedStatement $ sourcePos $ annotationAST other
        checkCase :: AST Pos -> TypeCheck (AST TypeCheckData)
        checkCase branch@(NCase casePos _ _ isket) = do
            when (onInt && isket) $ throwError $ TypeMismatch casePos [Exact $ intType ()] $ ketType ()
            when (not onInt && not isket) $ throwError $ TypeMismatch casePos [Exact $ ketType ()] $ intType ()
            branch' <- f branch
            mapM_ requireCall $ body branch'
            return branch'
    checkedCases <- mapM checkCase cases
    scope
    checkedDefault <- mapM f defau
    unscope
    mapM_ requireCall checkedDefault
    return $ NSwitch (okStmt pos) scrutinee checkedCases checkedDefault
-- NResolvedProcedureWithRet is not expected by the type checker; reaching
-- this clause is an internal error, so the panic names the constructor.
typeCheckAST' _ NResolvedProcedureWithRet{} = error "unreachable: typeCheckAST' received NResolvedProcedureWithRet"
-- Conditional jump to the end of the enclosing construct: the flag must be boolean.
typeCheckAST' _ (NJumpToEndOnFlag pos flag) = do
    checkedFlag <- typeCheckExpr flag >>= matchType [Exact (boolType ())]
    return $ NJumpToEndOnFlag (okStmt pos) checkedFlag
-- Unconditional jump: nothing to check, only the annotation is converted.
typeCheckAST' _ (NJumpToEnd pos) = return (NJumpToEnd (okStmt pos))
-- Temporary variable definition. The optional initializer must match the
-- declared type exactly; the symbol is registered afterwards under
-- `definedRefType` of that type, and the node becomes a one-element
-- NResolvedDefvar.
typeCheckAST' _ (NTempvar pos def) = do
    let (declTy, tempId, initializer) = def
        valueTy = void declTy
        storedTy = definedRefType valueTy
    checkedInit <- mapM (\e -> typeCheckExpr e >>= matchType [Exact valueTy]) initializer
    slot <- defineSym (SymTempVar tempId) pos storedTy
    return $ NResolvedDefvar (okStmt pos) [(storedTy, slot, checkedInit)]
-- An already-resolved extern gate needs no checking; only its annotations
-- are converted with `okStmt`.
typeCheckAST' _ node@NResolvedExternGate{} = return (okStmt <$> node)
-- These constructors are not expected by `typeCheckAST'`; reaching one is an
-- internal error, so each panic names the constructor to ease diagnosis.
typeCheckAST' _ NExternGate{} = error "unreachable: typeCheckAST' received NExternGate"
typeCheckAST' _ NProcedureWithDerive{} = error "unreachable: typeCheckAST' received NProcedureWithDerive"
typeCheckAST' _ NResolvedDefvar{} = error "unreachable: typeCheckAST' received NResolvedDefvar"
typeCheckAST' _ NGlobalDefvar{} = error "unreachable: typeCheckAST' received NGlobalDefvar"
typeCheckAST' _ NOracle{} = error "unreachable: typeCheckAST' received NOracle"
typeCheckAST' _ NOracleTable{} = error "unreachable: typeCheckAST' received NOracleTable"

-- | Type-check one statement. This ties the open-recursion knot of
-- 'typeCheckAST'' by passing itself as the handler for nested statements.
typeCheckAST :: AST Pos -> TypeCheck (AST TypeCheckData)
typeCheckAST node = typeCheckAST' typeCheckAST node


-- | Whether a type counts as quantum data: a qbit, or an array of qbits whose
-- recorded length is not 0. Every other type, including any array with
-- length 0, is not.
isQuantumData :: Type ann -> Bool
isQuantumData t = case t of
    Type _ Qbit _ -> True
    Type _ (Array n) [Type _ Qbit _] -> n /= 0
    _ -> False

-- | Compute the operand sizes of a gate from its parameter types.
--
-- The list is walked from right to left. A legal list has all quantum
-- parameters rightmost, e.g. @[double, qbit, qbit[3]]@. A single qbit
-- contributes @-1@ and a qbit array contributes its length. The result is
-- 'Nothing' when a quantum parameter appears to the left of a non-quantum
-- one, or when a quantum parameter has an unexpected shape.
countGateSize :: [Type ann] -> Maybe [Int]
countGateSize = fmap snd . foldr step (Just (False, []))
  where
    -- The accumulator carries a flag for "a non-quantum type was already
    -- seen (to the right)" and the sizes collected so far.
    step _ Nothing = Nothing
    step t (Just (seenClassical, sizes))
        | not (isQuantumData t) = Just (True, sizes)
        | seenClassical = Nothing
        | otherwise = case t of
            Type _ Qbit [] -> Just (False, (-1) : sizes)
            Type _ (Array n) [Type _ Qbit _] -> Just (False, n : sizes)
            _ -> Nothing

-- | Like 'argType'', using the position annotation carried by the type itself.
argType :: Type Pos -> Ident -> TypeCheck EType
argType ty ident = argType' declaredAt ident ty
  where
    declaredAt = annotation ty
-- | Translate a declared procedure-argument type into the type recorded for
-- the argument.
--
-- A bare qbit is wrapped in a reference. Function types are translated
-- component-wise. Int, double, bool, unit, param and one-element-typed
-- arrays are kept as they are, with annotations erased. Anything else raises
-- 'BadProcedureArgType' at @pos@ for identifier @i@.
argType' :: Pos -> Ident -> LType -> TypeCheck EType
argType' pos i ty = case ty of
    Type _ Qbit [] -> return $ Type () Ref [erased]
    Type _ FuncTy subTy -> Type () FuncTy <$> mapM (argType' pos i) subTy
    Type _ (Array _) [_] -> unchanged
    Type _ builtin [] | keptAsIs builtin -> unchanged
    _ -> throwError $ BadProcedureArgType pos (erased, i)
  where
    erased = void ty
    unchanged = return erased
    -- Builtins that are accepted unchanged when they carry no sub-types.
    keptAsIs Int = True
    keptAsIs Double = True
    keptAsIs Bool = True
    keptAsIs Unit = True
    keptAsIs Param = True
    keptAsIs _ = False

-- | Type-check a whole list of top-level nodes in three passes:
--
-- 1. Global variable definitions ('NDefvar') are checked, each in its own
--    temporary scope, and then registered with a @prefix@-qualified name.
-- 2. Headers of gates, extern gates, oracle tables, permutations, procedures
--    and template instantiations are entered into the global symbol table.
--    Oracle-derived procedures also have their bodies checked in this pass.
-- 3. Bodies of the remaining procedures are checked.
--
-- Returns the resolved nodes, the layer of symbols flagged as global, and the
-- current value of the SSA id allocator. Throws 'MainUndefined' when @isMain@
-- is set and no main was recorded (`mainDefined`). The @qcis@ argument is not
-- used in this function.
typeCheckToplevel :: Bool -> String -> [AST Pos]-> Bool -> TypeCheck ([TCAST], SymbolTableLayer, Int)
typeCheckToplevel isMain prefix ast qcis = do

    (resolved_defvar, varlist)<-flip runStateT [] $ do
        mapM (\node->case node of
                NDefvar pos def -> do
                    -- Create separate scope to prevent cross-reference.
                    lift scope
                    p<-lift $ typeCheckAST node
                    let (NResolvedDefvar a defs') = p
                    s<-lift unscope
                    modify' (MultiMap.map (\x->x{isGlobal=True}) s:)
                    -- NOTE(review): `head` is partial; this assumes every NDefvar
                    -- carries at least one definition.
                    let (ty, _, _, _) = head def
                    let node' = case ty of
                            Type _ Param [] -> NResolvedDefParam (okStmt pos) $ map (\(_, param, _, _) -> (prefix ++ param, param)) def
                            Type _ (Array _) [Type _ Param []] -> NResolvedDefParam (okStmt pos) $ map (\(_, param, _, _) -> (prefix ++ param, param)) def
                            other -> NGlobalDefvar a (zipWith (\(a1, a2, a3) (_, a4, _, _) ->(a1, a2, prefix ++ a4, a3)) defs' def)
                    return $ Right node'
                x -> return $ Left x
            ) ast
    -- Add all vars into table.
    let vars=concatMap MultiMap.toList varlist
    let qualifiedVars = map (\tup -> do
            let sym = fst tup
            let qualified = prefix ++ getSymbolName sym
            let qualifiedData = (snd tup){qualifiedName = qualified}
            (sym, qualifiedData)) vars
    mapM_ (uncurry addSym) $ reverse qualifiedVars

    -- Resolve all gates and procedures.
    resolved_headers<-mapM (\case
            Right x->return (Right x)
            Left (NResolvedGatedef pos name matrix size qir) -> do
                defineGlobalSym prefix name pos (Type () (Gate size) []) False False
                return $ Right (NResolvedGatedef (okStmt pos) (prefix ++ name) matrix size qir)
            Left (NExternGate pos name extra size qirname) -> do
                extra'<-mapM (argType' pos "<anonymous>") extra
                defineGlobalSym prefix name pos (Type () (Gate size) extra') False False
                return $ Right $ NResolvedExternGate (okStmt pos) (prefix ++ name) (fmap void extra) size qirname
            Left (NOracleTable pos name source value size) -> do
                defineGlobalSym prefix name pos (Type () (Gate size) []) False False
                return $ Right (NOracleTable (okStmt pos) (prefix ++ name) (prefix ++ source) value size)
            Left (NPermuation pos name size value) -> do
                -- A permutation over `size` qubits must list each of the
                -- 2^size values exactly once, as integer literals.
                let maxn = 2 ^ size
                if maxn /= length value then
                        throwError $ BadPermutationShape pos
                    else do
                        let check :: Set.Set Int -> LExpr -> TypeCheck (Set.Set Int)
                            check set (EIntLit pos v) = (if inRange (0, maxn - 1) v
                                then (if Set.member v set
                                    then throwError $ BadPermutationValue pos
                                    else return $ Set.insert v set)
                                else throwError $ BadPermutationValue pos)
                            check _ other = throwError $ BadPermutationValue $ annotationExpr other
                        foldM check Set.empty value
                defineGlobalSym prefix name pos (Type () (Gate $ replicate size $ -1) []) False False
                value' <- mapM typeCheckExpr value
                return $ Right (NPermuation (okStmt pos) (prefix ++ name) size value')
            Left (NProcedureWithDerive pos ty name targs args body derive ret) -> do
                -- check arg types and return types
                ty' <- case derive of
                        Just DeriveOracle -> case ty of
                            -- TODO: consider the case that returned length depends on a template argument
                            Type _ (Array _) [Type _ Bool []] -> return $ void ty
                            _ -> throwError $ BadProcedureReturnType pos (void ty, name)
                        _ -> case ty of
                            Type _ Int [] -> return $ void ty
                            Type _ Unit [] -> return $ void ty
                            Type _ Double [] -> return $ void ty
                            Type _ Bool [] -> return $ void ty
                            _ -> throwError $ BadProcedureReturnType pos (void ty, name)
                -- A `main` declared without arguments receives two implicit array parameters.
                let new_args = if name == "main" && null args
                    then [(Type pos (Array 0) [intType pos], "main$par1"), (Type pos (Array 0) [doubleType pos], "main$par2")] else args

                -- Translate array length from identifier to its index, e.g. `foo<int N>(int a[N])` to `foo<int N>(int a[-2])`
                -- 1. Generate a Hash map whose key is an identifier and value is its index.
                let processTargs :: (Map.Map String Int, [Type ()]) -> ((LType, String), Int) -> TypeCheck (Map.Map String Int, [Type ()])
                    processTargs (tmap, tys) ((ty@(Type pos Int []), ident), idx) = return (Map.insert ident idx tmap, tys ++ [void ty])
                    processTargs (tmap, tys) ((gateTy@(Type pos (Gate _) sub), ident), _) = do
                        (tmap', sub') <- foldM processTargs (tmap, []) $ map ((, -1) . (, ident)) sub
                        gate <- case countGateSize sub' of
                            -- Report the offending gate argument type itself; the
                            -- enclosing `ty` is the procedure's return type.
                            Nothing -> throwError $ BadProcedureArgType pos (void gateTy, ident)
                            Just x -> return (Type () (Gate x) $ take (length sub' - length x) sub')
                        return (tmap', tys ++ [gate])
                    processTargs (tmap, tys) ((ty, _), idx) | idx >= 0 = throwError $ UnsupportedType pos $ void ty
                    processTargs (tmap, tys) ((Type pos (TArray i) subTy, ident), idx) = do
                        case Map.lookup i tmap of
                                Nothing -> throwError $ UndefinedSymbol pos (SymVar i)
                                Just x -> return (tmap, tys ++ [Type () (Array $ -2 - x) $ map void subTy])
                    processTargs (tmap, tys) ((ty, _), _) = return (tmap, tys ++ [void ty])
                (tmap, targs') <- foldM processTargs (Map.empty, []) $ zip targs [0..]
                -- 2. Use the Hash map to translate the arguments.
                let getMemIndex :: LType -> TypeCheck (LType, Int)
                    getMemIndex (Type pos (TArray i) subTy) = do
                        case Map.lookup i tmap of
                                Nothing -> throwError $ UndefinedSymbol pos (SymVar i)
                                Just x -> return (Type pos (Array $ -2 - x) subTy, x)
                    getMemIndex other = return (other, -1)
                idx_and_args <- mapM (getMemIndex . fst) new_args
                let args' = map fst idx_and_args
                let length_idx = if null targs' then Nothing else Just $ map snd idx_and_args
                args'' <- zipWithM argType args' $ map snd new_args

                -- Generate the type of the procedure and store it to the symbol table
                pty <- case derive of
                    Nothing -> do
                        return $ Type () FuncTy (ty': args'')
                    Just DeriveGate -> case countGateSize args' of
                        Nothing -> throwError $ BadGateSignature pos
                        Just x -> return $ Type () (Gate x) $ take (length args' - length x) args''
                    Just DeriveOracle -> do
                        -- Map the oracle function type to qubit type
                        let bool2int :: LType -> TypeCheck Int
                            bool2int (Type _ (Array x) [Type _ Bool []]) = return x
                            bool2int ty = throwError $ UnsupportedType (annotationType ty) $ void ty
                        x <- mapM bool2int args'
                        ty' <- bool2int ty
                        return $ Type () (Gate $ x ++ [ty']) []
                full_ty <- case length targs of
                        0 -> return pty
                        n -> do
                            return $ Type () (Templ n) $ targs' ++ [pty]
                --traceM $ show full_ty
                defineGlobalSym prefix name (annotation ty) full_ty False $ isJust derive

                case derive of
                    Just DeriveOracle -> do
                        -- Oracle bodies are checked right away, with `inOracle` set.
                        scope
                        let tys = targs' ++ map void args'
                            syms = map snd $ targs ++ args
                            tuple = zip tys syms
                        ids <- mapM (\(ty, i) -> defineSym (SymVar i) pos ty) tuple
                        modify' (\x->x{inOracle = True})
                        body' <- mapM typeCheckAST body
                        modify' (\x->x{inOracle = False})
                        unscope
                        return $ Right (NResolvedProcedureWithRet (okStmt pos) ty' (prefix ++ name) (zip tys ids) body' Nothing Nothing length_idx $ Just DeriveOracle)
                    _ -> do
                        -- `main` keeps its bare name; every other procedure is prefixed.
                        let procName = case name of {"main" -> "main"; x -> prefix ++ name}
                        let arg_tuple = zip args'' (fmap snd new_args)
                        let targ_tuple = zip targs' $ map snd targs
                        return $ Left (pos, ty', procName, arg_tuple, body, ret, targ_tuple, derive, length_idx)
            Left (NProcedureInstantiated pos ty name targs args body derive ori ret) -> do
                targs' <- mapM typeCheckExpr targs
                sym <- getSym pos (SymVar name)
                case definedType sym of
                    Type _ (Templ n) sub -> do
                        when (n /= length targs) $ throwError $ ArgNumberMismatch pos n $ length targs
                        targs'' <- zipWithM (\a -> matchType [Exact a]) (take n sub) targs'
                        -- Build the instance name by appending "__<argument>" for
                        -- every template argument.
                        mangled <- foldM (\name expr -> do
                            case astType expr of
                                Type _ Int [] -> do
                                    case expr of
                                        EIntLit ann val -> return $ name ++ "__" ++ show val
                                        other -> throwError $ UnsupportedStatement pos
                                Type _ (Gate x) sub -> do
                                    case expr of
                                        EGlobalName ann ident -> return $ name ++ "__" ++ ident
                                        other -> throwError $ UnsupportedStatement pos
                                other -> throwError $ UnsupportedType pos other
                            ) (prefix ++ name) targs''
                        return $ Left (pos, void ty, mangled, map (first void) args, body, ret, [], derive, Nothing)
                    other -> throwError $ TypeMismatch pos [AnyTempl] other
            Left x -> error $ "unreachable " ++ show x
        ) resolved_defvar

    -- Finally, resolve procedure bodies.
    -- Note that we need to store byval-passed values (e.g. int) into new variables.
    body<-mapM (\case
        Right x->return x
        Left (pos, ty, func_name, args, body, ret@(ETempVar pret ret_id), targs, derive, length_idx)-> do
            scope
            -- resolve return value
            ret_var<-case ty of
                Type _ Unit [] -> do
                    s<-defineSym (SymTempVar ret_id) pret (Type () Ref [ty])
                    return Nothing
                _ -> do
                    s<-defineSym (SymTempVar ret_id) pret (Type () Ref [ty])
                    return $ Just (Type () Ref [ty], s)
            -- resolve template arguments
            targs' <- mapM (\(ty, i) -> do
                s <- defineSym (SymVar i) pos ty
                return (ty, s)
                ) targs
            -- resolve args
            let processArg :: BuiltinType -> Ident -> StateT [AST Pos] TypeCheck (Type (), Int)
                processArg ty i = do
                    temp_arg <- lift nextId -- Temporary argument
                    s <- lift $ defineSym (SymTempArg temp_arg) pos $ Type () ty []
                    -- Leave the good name for defvar.
                    --real_arg<-lift $ defineSym (SymVar i) pos (refType () (intType ()))
                    modify' (++[NDefvar pos [(Type pos ty [], i, Just $ ETempArg pos temp_arg, Nothing)]])
                    return (Type () ty [], s)
            (args', new_tempvars)<-flip runStateT [] $ mapM (\(ty, i)->case ty of
                Type _ Bool [] -> processArg Bool i
                Type _ Int [] -> processArg Int i
                Type _ Double [] -> processArg Double i
                x -> do
                    s<-lift $ defineSym (SymVar i) pos x
                    return (ty, s)
                    ) args
            -- resolve body
            body'<-mapM typeCheckAST (new_tempvars++body)
            ret''<-case ty of
                Type _ Unit [] -> return Nothing
                _ -> do
                    ret'<-typeCheckExpr ret
                    ret''<-matchType [Exact ty] ret'
                    return $ Just ret''
            unscope
            return $ NResolvedProcedureWithRet (okStmt pos) ty func_name (targs' ++ args') body' ret'' ret_var length_idx derive
        Left _ -> error "unreachable"
        ) resolved_headers

    m <- gets mainDefined
    when (isMain && not m) $ throwError MainUndefined

    -- Extract global symbols
    symtable <- gets symbolTable
    -- NOTE(review): `last $ init` is partial; this assumes the symbol table
    -- holds at least two layers here.
    let topLayer = last $ init symtable
    let lis = MultiMap.toList topLayer
    let globalLis = filter (isGlobal . snd) lis
    let globalLayer = MultiMap.fromList globalLis
    ssaId <- gets ssaAllocator
    return (body, globalLayer, ssaId)

-- | Pure entry point: run 'typeCheckToplevel' from a fresh environment whose
-- symbol table is an empty layer stacked on the supplied layer @stl@, with
-- the id allocator starting at @ssaId@. The remaining environment fields are
-- initialised to @False@, @False@ and @0@.
typeCheckTop :: Bool -> String -> [LAST] -> SymbolTableLayer -> Int -> Bool -> Either TypeCheckError ([TCAST], SymbolTableLayer, Int)
typeCheckTop isMain prefix ast stl ssaId qcis =
    evalState (runExceptT checker) initialEnv
  where
    checker = typeCheckToplevel isMain prefix ast qcis
    initialEnv = TypeCheckEnv [MultiMap.empty, stl] ssaId False False 0

-- TODO: unification-based type check and type inference.

-- | Atomic type constructors.
-- NOTE(review): presumably intended for the unification-based checker named
-- in the TODO above; nothing in the visible code uses them yet.
data TyAtom
    = TInt
    | TQbit
    | TBool
    | TDouble
    | TComplex
    | TList
    | TKnownList Int
    | TUser String
    | TRange
    | TGate Int
    | TRef
    | TVal
    | TFunc
    deriving (Show, Eq)

-- | A type term: a compound of argument terms, an atom, or a numbered variable.
data Ty
    = TMultiple { tyArgs :: [Ty] }
    | TAtom TyAtom
    | TVar Int
    deriving (Show, Eq)

