module ISQ.Lang.OraclePass where
import ISQ.Lang.ISQv2Grammar
import Control.Monad.Except
import Control.Monad.Extra (concatMapM)
import Control.Monad (void)
import Control.Monad.State.Lazy (evalState, get, put, State)
import Data.Bits hiding (And, Or)
import Data.Complex
import Data.List (null)
--import Data.Either.Combinators (mapRight)
import qualified Data.Map.Lazy as Map

-- | Errors reported while lowering @oracle@ declarations; each carries the
-- source position the error is attributed to.
data OracleError
    = BadOracleShape Pos
      -- ^ The oracle's function table does not fit its declared sizes.
    | BadOracleValue Pos
      -- ^ A function-table entry is not an integer literal or is out of range.
    deriving (Show, Eq)

-- | Mangled procedure name for a generated gate definition (used below for
-- the gate derived from an oracle). The @$_ISQ_@ prefix presumably keeps it
-- out of the user identifier space — TODO confirm against the lexer.
mangleGate :: String -> String
mangleGate x = "$_ISQ_GATEDEF_" ++ x

-- | @2^x@ for a non-negative exponent.
--
-- Implemented with '(^)' (exponentiation by squaring) instead of linear
-- recursion, so a negative exponent raises a \"Negative exponent\" error
-- rather than recursing forever. As before, the result silently wraps once
-- it no longer fits in an 'Int'.
pow2 :: Int->Int
pow2 x = 2 ^ x

-- | Binary digits of a non-negative integer, most significant bit first,
-- without leading zeros (zero itself is @[0]@).
--
-- >>> toBit 6
-- [1,1,0]
--
-- Negative inputs have no representation here. They used to recurse forever
-- (@div (-1) 2 == -1@), so they now raise an explicit error instead.
toBit :: Int->[Int]
toBit x
    | x < 0 = error ("ISQ.Lang.OraclePass.toBit: negative input " ++ show x)
    | x < 2 = [x]
    -- @mod x 2@ is already a single digit, so it is appended directly.
    | otherwise = toBit (x `div` 2) ++ [x `mod` 2]

-- | 'toBit' left-padded with zeros to at least @l@ digits. Results already
-- longer than @l@ are returned as-is (never truncated).
toBit' :: Int->Int->[Int]
toBit' x l = padding ++ bits
    where
        bits = toBit x
        padding = replicate (l - length bits) 0

-- | Value of an integer-literal expression. No actual folding happens here:
-- any expression other than an 'EIntLit' is rejected with 'BadOracleValue'
-- at that expression's own position.
foldConstantValue :: LExpr->Either OracleError Int
foldConstantValue expr = case expr of
    EIntLit _ val -> Right val
    _ -> Left (BadOracleValue (annotationExpr expr))

-- | Parameter list for the gate derived from an oracle: @a@ parameters, each
-- of type @Type ann Qbit []@, with pairwise-distinct names.
--
-- The first 26 are named @a@ .. @z@. Further ones continue as @a26@, @a27@,
-- and so on. The list used to be built with @take a ['a'..'z']@, which
-- silently produced at most 26 parameters even when @a@ was larger.
getDerivingArgs :: Int -> Pos -> [(LType, Ident)]
getDerivingArgs a ann = [(Type ann Qbit [], argName) | argName <- take a argNames]
    where
        -- Single letters first (unchanged naming), then "a26", "a27", ...;
        -- those are at least three characters long, so they can never
        -- collide with the single-letter names.
        argNames = [[c] | c <- ['a' .. 'z']] ++ ['a' : show i | i <- [26 :: Int ..]]

-- | Indices (ascending) of the entries of @fx@ whose bit @y@ is 1, where
-- bits are taken from 'toBit'' with width @m@ and @y = 0@ is the most
-- significant one.
getMValue :: [Int]->Int->Int->[Int]
getMValue fx m y = [idx | (val, idx) <- zip fx [0 ..], toBit' val m !! y == 1]

-- | Oracle truth table split by output bit. Row @y@ (for @0 <= y < m@, most
-- significant bit first) is @getMValue fx m y@, i.e. the inputs whose
-- value has that bit set. Yields no rows when @m <= 0@.
getOracleVale :: [Int]->Int->[[Int]]
getOracleVale fx m = [getMValue fx m bit | bit <- [0 .. m - 1]]

-- Replace each oracle with a derived gate (uses multi-controlled X).
-- | Lower a single top-level node.
--
-- @NOracle ann name n m fx@ is an oracle over @n@ input and @m@ output bits.
-- Its function table @fx@ must have exactly @2^n@ entries, each an integer
-- literal in @[0, 2^m)@. It is replaced by two nodes:
--
--   * an 'NProcedureWithDerive' named @mangleGate name@ with @n + m@ qbit
--     parameters and a unit-literal body, and
--   * an 'NOracleTable' whose rows (one per output bit, most significant
--     first) list the inputs that set that bit.
--
-- Every other node is returned unchanged.
--
-- Errors: 'BadOracleShape' if @n@ or @m@ is negative or @fx@ has the wrong
-- length. 'BadOracleValue' if an entry is not an integer literal (at that
-- entry's position) or is out of range (at the oracle's position).
passOracle' :: LAST -> Either OracleError [LAST]
passOracle' (NOracle ann name n m fx)
    -- Negative sizes are rejected before 'pow2' ever sees them.
    | n < 0 || m < 0 || pow2 n /= length fx = Left $ BadOracleShape ann
    | otherwise = do
        fv <- mapM foldConstantValue fx
        if all inRange fv
            then
                let proc_name = mangleGate name
                in Right [
                    NProcedureWithDerive ann (unitType ann) proc_name [] (getDerivingArgs (n+m) ann) [] Nothing $ EUnitLit ann,
                    NOracleTable ann name proc_name (getOracleVale fv m) $ replicate (n + m) $ -1
                    ]
            else Left $ BadOracleValue ann
    where
        -- The lower bound matters: a negative entry would make 'toBit'
        -- loop forever once the table is consumed.
        inRange x = x >= 0 && x < pow2 m
passOracle' x = Right [x]

-- | Apply 'passOracle'' to every top-level node in order, splicing each
-- expansion in place. Stops at the first node that fails.
passOracle :: [LAST] -> Either OracleError [LAST]
passOracle program = concat <$> traverse passOracle' program
