#include <iostream>
#include <unordered_set>
#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/Dialect/Math/IR/Math.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/IR/Value.h>
#include <mlir/Pass/PassRegistry.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/FoldUtils.h>
#include "isq/Operations.h"
#include "isq/OpVerifier.h"
#include "isq/QTypes.h"
#include "isq/passes/Passes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "isq/GateDefTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "isq/passes/Mem2Reg.h"
#include "isqast/Operations.h"
#include "mockturtle/algorithms/simulation.hpp"
#include "mockturtle/networks/xag.hpp"
#include "isq/passes/LHRSOracle.hpp"
#include "caterpillar/synthesis/strategies/bennett_mapping_strategy.hpp"
#include "caterpillar/synthesis/strategies/eager_mapping_strategy.hpp"
#include "isq/passes/GreedyPebbling.hpp"
#include "mockturtle/algorithms/lut_mapping.hpp"
#include "mockturtle/algorithms/collapse_mapped.hpp"
#include "mockturtle/networks/klut.hpp"
#include "mockturtle/views/mapping_view.hpp"
#undef X // ghack.hpp fix

namespace isq{
namespace ir{
namespace passes{

// Attribute names and name-mangling fragments used by the rewrite patterns in this file.

// Glue placed between the parts of a mangled symbol name (template symbol,
// template arguments, and DEF_SUFFIX).
const char *SEPARATOR = "__";
// Attribute whose presence on a func.func marks it as a template; such
// functions are skipped by the patterns until a copy is instantiated.
const char *TEMPLATE = "template";
// Attribute whose presence on a func.func requests that a gate definition be
// generated from it (see genDefGateOp); it is removed once that is done.
const char *DERIVING = "deriving";
// Suffix appended (after SEPARATOR) to a function that has become the raw
// decomposition of a generated gate definition.
const char *DEF_SUFFIX = "GATEDEF";
// Array attribute on a template function; read during instantiation, where
// each non-negative entry is used as the index of the template argument that
// supplies the length of the corresponding memref input.
const char *LENGTH_DEF = "length_def";
// Attribute whose presence on a func.func marks it for RewriteLogic; calls to
// such functions are left alone by RewriteCall.
const char *LOGIC = "logic";

// Generates a gate definition from a function that carries the `deriving`
// attribute.
//
// Functions that carry `template`, or that lack `deriving`, are left untouched.
// Otherwise:
//   * the function is renamed to `<name>__GATEDEF` and loses `deriving`;
//   * an isq::ast::DefgateOp named `<name>` is created at the start of
//     `rootModule`, with the renamed function as its "decomposition_raw"
//     definition.
//
// Parameters:
//   op         - candidate function.
//   rewriter   - used for the in-place update and for creating the DefgateOp;
//                its insertion point is restored before returning.
//   rootModule - module that receives the DefgateOp.
void genDefGateOp(mlir::func::FuncOp op, mlir::PatternRewriter& rewriter, mlir::ModuleOp rootModule){
    if (op->hasAttr(TEMPLATE) || !op->hasAttr(DERIVING)) {
        return;
    }
    auto func_ty = op.getFunctionType();
    // Set once the first memref input has been seen; from then on inputs are
    // no longer recorded as gate parameters.
    bool seen_qubits = false;
    llvm::SmallVector<mlir::Attribute> par;
    llvm::SmallVector<int> sizes;
    for (mlir::Type ty : func_ty.getInputs()) {
        if (auto memty = llvm::dyn_cast_or_null<mlir::MemRefType>(ty)) {
            assert(memty.getElementType().isa<QStateType>() && memty.getRank() == 1 && "Bad memref type");
            sizes.push_back(memty.getDimSize(0));
            seen_qubits = true;
        }
        if (!seen_qubits) {
            par.push_back(mlir::TypeAttr::get(ty));
        }
    }
    auto ctx = rewriter.getContext();
    auto gate_ty = mlir::TypeAttr::get(isq::ast::GateType::get(ctx, sizes));
    const std::string old_name = op.getName().str();
    const std::string new_name = old_name + SEPARATOR + DEF_SUFFIX;

    // The function keeps its body but gives its name away to the gate definition.
    rewriter.startRootUpdate(op);
    op.setName(new_name);
    op->removeAttr(DERIVING);
    rewriter.finalizeRootUpdate(op);

    mlir::SmallVector<mlir::Attribute> def;
    auto gate_def = GateDefinition::get(ctx, mlir::StringAttr::get(ctx, "decomposition_raw"), mlir::FlatSymbolRefAttr::get(ctx, new_name));
    def.push_back(gate_def);

    // Do not leak the module-level insertion point to the caller.
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(rootModule.getBody());
    rewriter.create<isq::ast::DefgateOp>(op.getLoc(), gate_ty, mlir::StringAttr::get(ctx, old_name),
        nullptr, mlir::ArrayAttr::get(ctx, def), mlir::ArrayAttr::get(ctx, par));
}

// Pattern that turns every non-template func.func carrying the `deriving`
// attribute into a gate definition by delegating to genDefGateOp.
class PrepareInstantiated : public mlir::OpRewritePattern<mlir::func::FuncOp>{
    mlir::ModuleOp rootModule;
public:
    PrepareInstantiated(mlir::MLIRContext *ctx, mlir::ModuleOp module): mlir::OpRewritePattern<mlir::func::FuncOp>(ctx, 1), rootModule(module){
    }
    mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override{
        // Same applicability test genDefGateOp performs. It is repeated here so
        // that the pattern reports failure only when the IR is left untouched
        // and success only when it was actually changed. Because the rewrite
        // removes `deriving`, the pattern cannot match the same function twice.
        if (op->hasAttr(TEMPLATE) || !op->hasAttr(DERIVING)) {
            return mlir::failure();
        }
        genDefGateOp(op, rewriter, rootModule);
        return mlir::success();
    }
};

// Pattern that resolves an isq::ast::TemplOp (a template applied to template
// arguments) into an isq::ast::FuncOp referring to the instantiated symbol.
//
// The instantiated symbol name is the template symbol followed by one
// SEPARATOR-prefixed fragment per template argument. If no function or gate
// definition with that name exists yet, the template function is cloned and
// specialised by instantiateFunction.
class RewriteTempl : public mlir::OpRewritePattern<isq::ast::TemplOp>{
    mlir::ModuleOp rootModule;
public:
    RewriteTempl(mlir::MLIRContext *ctx, mlir::ModuleOp module): mlir::OpRewritePattern<isq::ast::TemplOp>(ctx, 1), rootModule(module){
    }

    // Returns failure (leaving the IR untouched) when the op lives inside a
    // template function or when one of its arguments is not defined by an
    // arith constant index or an isq::ast::FuncOp; returns success after
    // replacing the op.
    mlir::LogicalResult matchAndRewrite(isq::ast::TemplOp op, mlir::PatternRewriter& rewriter) const override{
        // Omit the operation in a template
        auto func = op->getParentOfType<mlir::func::FuncOp>();
        assert(func && "Template op is not nested in a func.func");
        if (func->hasAttr(TEMPLATE)) {
            return mlir::failure();
        }

        auto callee = op.getCallee();
        auto def = callee.getDefiningOp();
        assert(def && "Cannot find defining op!");
        auto ast_func = llvm::dyn_cast_or_null<isq::ast::FuncOp>(def);
        assert(ast_func && "Not defined by isqast.func");
        mlir::FlatSymbolRefAttr symbol = ast_func.getValueAttr();
        mlir::func::FuncOp temp = mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op, symbol);
        assert(temp && "Undefined symbol!");
        mlir::FunctionType function_type = llvm::dyn_cast_or_null<mlir::FunctionType>(callee.getType());
        assert(function_type && "Not a function type");

        // Get mangled name
        std::string mangled(symbol.getValue().str());
        mlir::SmallVector<mlir::Operation*> ops;
        for (mlir::Value arg : op.getArgs()) {
            mlir::Operation *arg_op = arg.getDefiningOp();

            if (auto con = llvm::dyn_cast_or_null<mlir::arith::ConstantIndexOp>(arg_op)) {
                mangled += SEPARATOR + std::to_string(con.value());
            }
            else if (auto func2 = llvm::dyn_cast_or_null<isq::ast::FuncOp>(arg_op)) {
                mangled += SEPARATOR + func2.getValue().str();
            }
            else {
                // Argument is not (yet) a known constant; nothing was modified.
                return mlir::failure();
            }
            ops.push_back(arg_op);
        }

        // Instantiate template
        auto ctx = rewriter.getContext();
        auto mangled_attr = mlir::StringAttr::get(ctx, mangled);
        const bool already_instantiated =
            mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op, mangled_attr) ||
            mlir::SymbolTable::lookupNearestSymbolFrom<isq::ast::DefgateOp>(op, mangled_attr);
        if (!already_instantiated) {
            instantiateFunction(temp, rewriter, mangled, ops);
        }

        // Look the symbol up again: instantiation may have produced either a
        // function or (through genDefGateOp) a gate definition of that name.
        mlir::Type ty;
        if (auto inst_func = mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op, mangled_attr)) {
            ty = inst_func.getFunctionType();
        }
        else if (auto inst_defgate = mlir::SymbolTable::lookupNearestSymbolFrom<isq::ast::DefgateOp>(op, mangled_attr)) {
            mlir::SmallVector<mlir::Type> in;
            for (auto attr : inst_defgate.getParameters()) {
                in.push_back(llvm::cast<mlir::TypeAttr>(attr).getValue());
            }
            ty = rewriter.getFunctionType(in, inst_defgate.getType());
        } else {
            assert(false && "Cannot find symbol");
            // Release builds: do not build an op with a null type.
            return mlir::failure();
        }
        rewriter.setInsertionPoint(op);
        auto isq_func = rewriter.create<isq::ast::FuncOp>(op.getLoc(), ty, mangled);
        // Go through the rewriter so that listeners see the replacement.
        rewriter.replaceOp(op, isq_func->getResults());
        return mlir::success();
    }

    // Clones `templ` under the symbol `name`, specialised for the template
    // arguments defined by `ops`, and inserts the clone at the start of the
    // root module. Does nothing when `ops` is empty.
    //
    // In the clone:
    //   * `template` is removed;
    //   * each of the first ops.size() entry-block arguments is replaced by a
    //     clone of the matching defining op and then erased;
    //   * when a `length_def` array attribute is present, every non-negative
    //     entry i -> k gives remaining input i the memref length taken from
    //     the constant index ops[k], after which the attribute is removed;
    //   * genDefGateOp is finally run on the clone.
    void instantiateFunction(mlir::func::FuncOp templ, mlir::PatternRewriter& rewriter, std::string name, mlir::SmallVector<mlir::Operation*> ops) const {
        const int ops_size = ops.size();
        if (ops_size == 0) {
            return;
        }

        auto ctx = rewriter.getContext();
        mlir::func::FuncOp func = templ.clone();
        func->removeAttr(TEMPLATE);
        func.setName(name);
        auto block = func.getBlocks().begin();
        for (int i=0; i<ops_size; i++) {
            rewriter.setInsertionPointToStart(&*block);
            auto cloned = rewriter.clone(*ops[i]);
            block->getArgument(i).replaceAllUsesWith(cloned->getResult(0));
        }

        // Inputs that remain after the template arguments are stripped.
        auto ty = templ.getFunctionType();
        const int num_inputs = ty.getNumInputs();
        mlir::SmallVector<mlir::Type> ins;
        for (int i=ops_size; i<num_inputs; i++) {
            ins.push_back(ty.getInput(i));
        }

        auto length_attr = llvm::dyn_cast_or_null<mlir::ArrayAttr>(func->getAttr(LENGTH_DEF));
        if (length_attr) {
            assert(length_attr.size() == ins.size() && "Illegal length_def array size");
            for (size_t i=0; i<ins.size(); i++) {
                auto int_attr = llvm::dyn_cast_or_null<mlir::IntegerAttr>(length_attr[i]);
                assert(int_attr && "Bad attribute format");
                int value = int_attr.getInt();
                if (value < 0) {
                    // Negative entry: this input keeps its original type.
                    continue;
                }
                assert(value < ops_size && "length_def refers to a missing template argument");
                auto oldty = llvm::dyn_cast_or_null<mlir::MemRefType>(ins[i]);
                assert(oldty && "The original type is not memref");
                auto con = llvm::dyn_cast_or_null<mlir::arith::ConstantIndexOp>(ops[value]);
                assert(con && "Cannot find the array size template");
                ins[i] = mlir::MemRefType::get({con.value()}, oldty.getElementType(), oldty.getLayout());
            }
            func->removeAttr(LENGTH_DEF);
        }
        for (int i=ops_size; i<num_inputs; i++) {
            block->getArgument(i).setType(ins[i - ops_size]);
        }
        func.setType(mlir::FunctionType::get(ctx, ins, ty.getResults()));

        // Drop the leading entry-block arguments that stood for template arguments.
        auto bv = llvm::BitVector(ops_size, true);
        bv.resize(num_inputs);
        block->eraseArguments(bv);

        mlir::ModuleOp rootModule = this->rootModule;
        rewriter.setInsertionPointToStart(rootModule.getBody());
        rewriter.insert(func);
        genDefGateOp(func, rewriter, rootModule);
    }
};

// Maps a type onto the isq IR GateType.
//
//   * A function type (asserted to take no inputs) is first replaced by its
//     first result type.
//   * An isq::ast::GateType becomes a General-trait GateType whose size is the
//     sum of the absolute values of its sizes.
//   * A GateType is returned unchanged.
// Any other type is dumped and trips the assertion.
GateType ReplaceGateType(mlir::MLIRContext *ctx, mlir::Type ty){
    if (auto fn_ty = llvm::dyn_cast_or_null<mlir::FunctionType>(ty)) {
        assert(fn_ty.getNumInputs() == 0);
        ty = fn_ty.getResult(0);
    }

    auto ast_gate = llvm::dyn_cast_or_null<isq::ast::GateType>(ty);
    if (ast_gate) {
        int qubits = 0;
        for (int width : ast_gate.getSizes()) {
            qubits += abs(width);
        }
        return GateType::get(ctx, qubits, GateTrait::General);
    }

    auto ir_gate = llvm::dyn_cast_or_null<GateType>(ty);
    if (ir_gate) {
        return ir_gate;
    }
    // Not a gate at all: print the offending type before asserting.
    ty.dump();
    assert(ir_gate);
    return ir_gate;
}

// Pattern that lowers an isq::ast::CallOp.
//
//   * Callee is a block argument        -> func.call_indirect.
//   * Callee names a func.func          -> func.call (calls to functions that
//                                          carry the `logic` attribute are
//                                          left alone).
//   * Callee names an isq::ast::DefgateOp (optionally wrapped in a
//     DecorateOp)                       -> use-gate (+ decorate) followed by
//                                          either a global-phase application
//                                          or load / apply-gate / store on the
//                                          qubit arrays passed as trailing
//                                          arguments.
//
// The pattern returns success whenever the call was rewritten, and failure
// only on paths that leave the IR untouched.
class RewriteCall : public mlir::OpRewritePattern<isq::ast::CallOp>{
    mlir::ModuleOp rootModule;
public:
    RewriteCall(mlir::MLIRContext *ctx, mlir::ModuleOp module): mlir::OpRewritePattern<isq::ast::CallOp>(ctx, 1), rootModule(module){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::CallOp call, mlir::PatternRewriter& rewriter) const override{
        auto ctx = rewriter.getContext();
        auto loc = call.getLoc();
        mlir::Operation::operand_range args = call.getArgs();

        auto callee = call.getCallee();
        auto def = callee.getDefiningOp();
        if (!def) {
            // When the callee is not defined by an Op, it must be defined as a block argument.
            // This is the result of translating higher-order parameters, such as the `bar` in
            //     `unit foo(unit bar()){...}`
            // Therefore, we use `func.call_indirect` to implement it.
            rewriter.setInsertionPoint(call);
            auto ind = rewriter.create<mlir::func::CallIndirectOp>(loc, callee, call.getArgs());
            for (unsigned i=0; i<call.getNumResults(); i++) {
                call.getResult(i).replaceAllUsesWith(ind.getResult(i));
            }
            rewriter.eraseOp(call);
            // The call was replaced, so the rewrite succeeded.
            return mlir::success();
        }
        isq::ast::FuncOp ast_func = nullptr;
        if (auto dec = llvm::dyn_cast_or_null<isq::ast::DecorateOp>(def)) {
            ast_func = llvm::dyn_cast_or_null<isq::ast::FuncOp>(dec.getGate().getDefiningOp());
        } else {
            ast_func = llvm::dyn_cast_or_null<isq::ast::FuncOp>(def);
        }
        if (!ast_func) {
            call.emitError("No def op. (Maybe too many template instantiation.)");
            return mlir::failure();
        }
        mlir::FlatSymbolRefAttr symbol = ast_func.getValueAttr();
        // Number of leading arguments consumed as gate parameters; stays -1 for plain function calls.
        int use_size = -1;
        isq::ast::GateType gate_ty = nullptr;
        int narg = args.size();
        mlir::Value use;
        if (mlir::func::FuncOp temp = mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(call, symbol)) {
            if (temp->hasAttr(LOGIC)) {
                return mlir::failure();
            }
            mlir::FunctionType function_type = llvm::dyn_cast_or_null<mlir::FunctionType>(callee.getType());
            assert(function_type && "Not a function type");
            assert(narg == function_type.getNumInputs() && "Unmatched argument number");

            rewriter.setInsertionPoint(call);
            auto func_call = rewriter.create<mlir::func::CallOp>(loc, symbol, function_type.getResults(), args);
            for (unsigned i=0; i<call.getNumResults(); i++) {
                call.getResult(i).replaceAllUsesWith(func_call.getResult(i));
            }
        }
        else if (auto defgate = mlir::SymbolTable::lookupNearestSymbolFrom<isq::ast::DefgateOp>(call, symbol)) {
            use_size = defgate.getParameters().size();
            gate_ty = defgate.getType();
            use = rewriter.create<UseGateOp>(loc, gate_ty, symbol, args.take_front(use_size));
        } else {
            assert(false && "unrecognized op");
            // Release builds: report the problem instead of silently dropping
            // the call. Nothing has been modified at this point.
            call.emitError("Callee symbol is neither a function nor a gate definition.");
            return mlir::failure();
        }

        if (auto dec = llvm::dyn_cast_or_null<isq::ast::DecorateOp>(def)) {
            assert(use_size >= 0);
            // Each control contributes a size entry of -1 in front of the
            // sizes of the undecorated gate.
            mlir::SmallVector<int> sizes;
            auto dec_ctrl = dec.getCtrl();
            for (size_t i=0; i<dec_ctrl.size(); i++) {
                sizes.push_back(-1);
            }
            auto oldg = llvm::dyn_cast_or_null<isq::ast::GateType>(use.getType());
            assert(oldg);
            for (auto s : oldg.getSizes()) {
                sizes.push_back(s);
            }
            gate_ty = isq::ast::GateType::get(ctx, sizes);
            use = rewriter.create<isq::ast::DecorateOp>(loc, gate_ty, use, dec.getAdjoint(), dec_ctrl);
        }
        if (narg == use_size) {
            // Every argument was a gate parameter: no qubit operands remain.
            rewriter.create<isq::ast::ApplyGPhase>(loc, use);
        }
        else if (gate_ty) {
            auto sizes = gate_ty.getSizes();
            llvm::SmallVector<mlir::Value> states;
            int nsize = sizes.size();
            // The last `nsize` arguments are the qubit arrays; load every qubit.
            for (int i=0; i<nsize; i++) {
                int nq = abs(sizes[i]);
                for (int j=0; j<nq; j++) {
                    mlir::Value idx = rewriter.create<mlir::arith::ConstantIndexOp>(loc, j);
                    auto loaded = rewriter.create<mlir::affine::AffineLoadOp>(loc, args[narg - nsize + i], mlir::ArrayRef<mlir::Value>({idx}));
                    states.push_back(loaded);
                }
            }
            llvm::SmallVector<mlir::Type> state_ty(states.size(), QStateType::get(ctx));
            auto apply = rewriter.create<isq::ast::ApplyGateOp>(loc, state_ty, use, states);
            // Store the results back in the same order they were loaded.
            int idx = 0;
            for (int i=0; i<nsize; i++) {
                for (int j=0; j<abs(sizes[i]); j++) {
                    mlir::Value offset = rewriter.create<mlir::arith::ConstantIndexOp>(loc, j);
                    rewriter.create<mlir::affine::AffineStoreOp>(loc, apply.getResult(idx++), args[narg - nsize + i], mlir::ArrayRef<mlir::Value>({offset}));
                }
            }
        }
        rewriter.eraseOp(call);
        // The call was erased, so the rewrite succeeded.
        return mlir::success();
    }

    // Converts `ty` to the isq IR GateType: a GateType is returned as is, an
    // isq::ast::GateType becomes a General-trait GateType whose size is the
    // sum of the absolute values of its sizes. Any other type is a caller
    // error and trips the assertion.
    GateType astGateToIrGate(mlir::MLIRContext *ctx, mlir::Type ty) const {
        if (auto ir = llvm::dyn_cast_or_null<GateType>(ty)) {
            return ir;
        }
        auto ast = llvm::dyn_cast_or_null<isq::ast::GateType>(ty);
        assert(ast && "Expected an isq or isqast gate type");
        int sum = 0;
        for (int s : ast.getSizes()){
            sum += abs(s);
        }
        return GateType::get(ctx, sum, GateTrait::General);
    }
};

class RewriteLogic : public mlir::OpRewritePattern<mlir::func::FuncOp>{
    mlir::ModuleOp rootModule;
public:
    RewriteLogic(mlir::MLIRContext *ctx, mlir::ModuleOp module): mlir::OpRewritePattern<mlir::func::FuncOp>(ctx, 1), rootModule(module){
    }
    mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override{
        if (op->hasAttr(TEMPLATE) || !op->hasAttr(LOGIC)) {
            return mlir::failure();
        }

        // Build the XAG.
        mockturtle::xag_network xag; // the xag to be built
        auto hash = [](const std::pair<std::string, int> &p){
            return std::hash<std::string>()(p.first) * 31 + std::hash<int>()(p.second);
        };
        std::unordered_map<std::pair<std::string, int>, mockturtle::xag_network::signal, decltype(hash)> symbol_table(8, hash);

        // A helper function that get the SSA identifier linked to a value
        mlir::AsmState state(op);
        auto value2str = [&](mlir::Value value) -> std::string {
            return std::to_string((size_t)llvm::hash_value(value.getImpl()));
        };

        // Create input signals based on the function inputs.
        unsigned int input_num = op.getNumArguments();
        for (int i=0; i<input_num; i++) {
            mlir::Value arg = op.getArgument(i);
            std::string str = value2str(arg);
            int width = getBitWidth(arg);
            for (int j=0; j<width; j++) {
                symbol_table[{str, j}] = xag.create_pi();
            }
        }

        // Binary operator processing template
        auto binary = [&]<typename T>(mlir::Value lhs, mlir::Value rhs, mlir::Value res,
            T create) {
            std::string lname = value2str(lhs);
            std::string rname = value2str(rhs);
            std::string res_name = value2str(res);
            symbol_table[{res_name, -1}] = (xag.*create)(symbol_table.at({lname, -1}), symbol_table.at({rname, -1}));
        };

        // Binary vector operator processing template
        auto vec_binary = [&]<typename T>(mlir::Value lhs, mlir::Value rhs, mlir::Value res,
            T create) {
            std::string lname = value2str(lhs);
            std::string rname = value2str(rhs);
            std::string res_name = value2str(res);
            int width = getBitWidth(res);
            for (int j=0; j<width; j++) {
                symbol_table[{res_name, j}] = (xag.*create)(symbol_table.at({lname, j}), symbol_table.at({rname, j}));
            }
        };

        // Store the value of for-loop iterator
        mlir::DenseMap<mlir::Value, int> value_map;
        std::function<int(mlir::Value)> valueToInt = [&](mlir::Value value) {
            if (value_map.contains(value)) {
                return value_map[value];
            }
            auto op = value.getDefiningOp<mlir::arith::ConstantOp>();
            assert(op);
            mlir::IntegerAttr attr = op.getValue().dyn_cast_or_null<mlir::IntegerAttr>();
            assert(attr);
            return (int)attr.getInt();
        };

        std::function<bool(mlir::Operation&)> process_op = [&](mlir::Operation &it)->bool {
            if (isq::ast::NotOp notop = llvm::dyn_cast<isq::ast::NotOp>(it)) {
                std::string operand = value2str(notop.getOperand());
                std::string res = value2str(notop.getResult());
                symbol_table[{res, -1}] = xag.create_not(symbol_table.at({operand, -1}));
            }
            else if (isq::ast::NotvOp notvop = llvm::dyn_cast<isq::ast::NotvOp>(it)) {
                std::string operand = value2str(notvop.getOperand());
                mlir::Value result = notvop.getResult();
                std::string res = value2str(result);
                int width = getBitWidth(result);
                for (int i=0; i<width; i++) {
                    symbol_table[{res, i}] = xag.create_not(symbol_table.at({operand, i}));
                }
            }
            else if (isq::ast::AndOp binop = llvm::dyn_cast<isq::ast::AndOp>(it)) {
                binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_and);
            }
            else if (isq::ast::OrOp binop = llvm::dyn_cast<isq::ast::OrOp>(it)) {
                binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_or);
            }
            else if (isq::ast::XorOp binop = llvm::dyn_cast<isq::ast::XorOp>(it)) {
                binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_xor);
            }
            else if (isq::ast::XnorOp binop = llvm::dyn_cast<isq::ast::XnorOp>(it)) {
                binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_xnor);
            }
            else if (isq::ast::AndvOp binop = llvm::dyn_cast<isq::ast::AndvOp>(it)) {
                vec_binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_and);
            }
            else if (isq::ast::OrvOp binop = llvm::dyn_cast<isq::ast::OrvOp>(it)) {
                vec_binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_or);
            }
            else if (isq::ast::XorvOp binop = llvm::dyn_cast<isq::ast::XorvOp>(it)) {
                vec_binary(binop.getLhs(), binop.getRhs(), binop.getResult(), &mockturtle::xag_network::create_xor);
            }
            // Only process boolean values (with type `i1`), leaving other arith.constant (array index) untouched.
            else if (mlir::arith::ConstantOp con = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(it)) {
                mlir::BoolAttr attr = con.getValue().dyn_cast_or_null<mlir::BoolAttr>();
                if (attr) {
                    symbol_table[{value2str(con.getResult()), -1}] = xag.get_constant(attr.getValue());
                }
            }
            else if (mlir::memref::LoadOp load = llvm::dyn_cast<mlir::memref::LoadOp>(it)) {
                mlir::Value voffset = *load.getIndices().begin();
                int ioffset = valueToInt(voffset);
                std::string res = value2str(load.getResult());
                std::string memref = value2str(load.getMemref());
                symbol_table[{res, -1}] = symbol_table.at({memref, ioffset});
            }
            else if (mlir::memref::StoreOp store = llvm::dyn_cast<mlir::memref::StoreOp>(it)) {
                mlir::Value voffset = *store.getIndices().begin();
                int ioffset = valueToInt(voffset);
                std::string val = value2str(store.getValue());
                std::string memref = value2str(store.getMemref());
                symbol_table[{memref, ioffset}] = symbol_table.at({val, -1});
            }
            else if (auto ret = llvm::dyn_cast<mlir::func::ReturnOp>(it)) {
                mlir::Value oprand = ret.getOperand(0);
                std::string str = value2str(oprand);
                int width = getBitWidth(oprand);
                for (int i=0; i<width; i++) {
                    mockturtle::xag_network::signal sig = symbol_table.at({str, i});
                    xag.create_po(sig);
                }
            }
            else if (mlir::affine::AffineForOp forop = llvm::dyn_cast_or_null<mlir::affine::AffineForOp>(it)) {
                int lower = forop.getConstantLowerBound();
                int upper = forop.getConstantUpperBound();
                int step = forop.getStep();
                mlir::Value it = forop.getInductionVar();

                // Mimic the behavior of a for-loop
                for (int i=lower; i<upper; i+=step) {
                    value_map[it] = i;
                    for (auto &op : forop.getLoopBody().getOps()) {
                        process_op(op);
                    }
                }
                value_map.erase(it);
            }
            else if (mlir::scf::ExecuteRegionOp erop = llvm::dyn_cast_or_null<mlir::scf::ExecuteRegionOp>(it)) {
                for (auto &op : erop.getOps()) {
                    if (!process_op(op)) {
                        return false;
                    }
                }
            }
            else if (auto forop = llvm::dyn_cast_or_null<isq::ast::ForOp>(it)) {
                return false;
            }
            return true;
        };

        // Process each statement in the function body.
        for (mlir::Operation &it : op.getRegion().getOps()) {
            if (!process_op(it)) {
                return mlir::failure();
            }
        }

        // Convert XAG to quantum circuit. 
        // caterpillar::eager_mapping_strategy<mockturtle::xag_network> strategy;
        caterpillar::greedy_pebbling_mapping_strategy<mockturtle::xag_network> strategy;
        tweedledum::netlist<caterpillar::stg_gate> circ;
        caterpillar::logic_network_synthesis_stats stats;
        tweedledum::stg_from_pprm stg_fn;
        caterpillar::logic_network_synthesis_params ps;
        caterpillar::detail::logic_network_synthesis_impl_oracle<tweedledum::netlist<caterpillar::stg_gate>, 
            mockturtle::xag_network, tweedledum::stg_from_pprm> impl( circ, xag, strategy, stg_fn, ps, stats );
        impl.run();
        
        // Construct MLIR-style circuit. 
        mlir::MLIRContext *ctx = op.getContext();
        mlir::Location loc = op.getLoc(); // The location of the oracle function in the source code.

        // Construct function signature.
        mlir::SmallVector<::mlir::Type> argtypes;
        mlir::SmallVector<::mlir::Type> returntypes;
        isq::ir::QStateType qstate = isq::ir::QStateType::get(ctx);
        mlir::AffineExpr d0, s0, s1;
        mlir::bindDims(ctx, d0);
        mlir::bindSymbols(ctx, s0, s1);
        auto affine_map = mlir::AffineMap::get(1, 2, d0 * s1 + s0);
        for (int i=0; i<input_num; i++) {
            mlir::Value arg = op.getArgument(i);
            int width = getBitWidth(arg);
            mlir::MemRefType memref_i_qstate = mlir::MemRefType::get(mlir::ArrayRef<int64_t>{width}, qstate, affine_map);
            argtypes.push_back(memref_i_qstate);
        }
        auto po_num = xag.num_pos();
        mlir::MemRefType memref_o_qstate = mlir::MemRefType::get(mlir::ArrayRef<int64_t>{po_num}, qstate, affine_map);
        argtypes.push_back(memref_o_qstate);
        mlir::FunctionType functype = mlir::FunctionType::get(ctx, argtypes, returntypes);

        // Debug information.
        /*
        std::ofstream fout("/mnt/c/Users/14796/Desktop/workspace/isqv2/debugoutput.txt");
        std::streambuf *oldcout;
        oldcout = std::cout.rdbuf(fout.rdbuf());
        std::cout << "******xag description*******" << std::endl;
        xag.foreach_node( [&]( auto node ) {
            std::cout << "index: " << xag.node_to_index(node) << std::endl;
            xag.foreach_fanin(node, [&]( auto child ) {
                std::cout << "  child: " << child.index << std::endl;
                std::cout << "  complemented: " << (child.complement ? "y" : "n") << std::endl;
            });
        } );
        xag.foreach_pi( [&]( auto pi ) {
            std::cout << "pi: " << pi << std::endl;
        } );
        xag.foreach_po( [&]( auto node, auto index ) {
            std::cout << "po: " << xag.get_node(node) << (((mockturtle::xag_network::signal)node.data).complement ? " complemented" : "") << std::endl;
            xag.foreach_fanin(index, [&]( auto child ) {
                std::cout << "  child: " << child.index << std::endl;
                std::cout << "  complemented: " << (child.complement ? "y" : "n") << std::endl;
            });
        } );
        
        std::cout << "******circuit description*******" << std::endl;
        std::cout << "num_gates: " << circ.num_gates() << std::endl;
        circ.foreach_cgate( [&]( auto n ) {
            std::cout << n.gate << std::endl;
            n.gate.foreach_control( [&]( auto c ) {
                std::cout << "  control: " << c << " " << (c.is_complemented() ? "complemented" : "") << std::endl;
            } );
            n.gate.foreach_target( [&]( auto t ) {
                std::cout << "  target: " << t << " " << (t.is_complemented() ? "complemented" : "") << std::endl;
            } );
        } );
        std::cout << "num_qubits: " << circ.num_qubits() << std::endl;
        */
        // Create a FuncOp that represent the quantum circuit.
        mlir::func::FuncOp funcop = rewriter.create<mlir::func::FuncOp>(loc, op.getSymName(), functype);
        funcop->setAttr(DERIVING, rewriter.getUnitAttr());
        mlir::Block *entry_block = funcop.addEntryBlock(); // Arguments are automatically created based on the function signature.
        mlir::OpBuilder builder(entry_block, entry_block->begin());
        
        // Load arguments. 
        std::unordered_map<uint32_t, int> qubit_to_wire;
        std::vector<mlir::Value> memref;
        std::vector<mlir::Value> offset;
        int idx = 0;
        for (int i = 0; i <= input_num; i++) {
            auto arg = entry_block->getArgument(i);
            int width = getBitWidth(arg);
            std::string str = (i == input_num ? "" : value2str(op.getArgument(i)));
            for (int j = 0; j < width; j++) {
                memref.push_back(arg);
                mlir::arith::ConstantIndexOp index = builder.create<mlir::arith::ConstantIndexOp>(loc, j);
                offset.push_back(index);
                // Construct qubit-to-wire mapping. 
                auto xagnode = symbol_table[{str, j}];
                uint32_t qubit = (i == input_num ? stats.o_indexes[j] : stats.i_indexes[xag.pi_index(xagnode.index)]);
                qubit_to_wire[qubit] = idx++;
            }
        }

        // Ancilla allocation. 
        int ancilla_num = circ.num_qubits() - qubit_to_wire.size();
        mlir::MemRefType memref_ancilla_qstate = mlir::MemRefType::get(mlir::ArrayRef<int64_t>{ancilla_num}, qstate);
        mlir::memref::AllocOp ancillas = builder.create<mlir::memref::AllocOp>(loc, memref_ancilla_qstate);
        for (int i = 0; i < ancilla_num; i++) {
            memref.push_back(ancillas);
            mlir::arith::ConstantIndexOp index = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
            offset.push_back(index);
        }
        circ.foreach_cqubit( [&] ( tweedledum::qubit_id qubit_id ) {
            uint32_t qubit = qubit_id.index();
            if (qubit_to_wire.find(qubit) == qubit_to_wire.end()) {
                qubit_to_wire[qubit] = idx++;
            }
        } );
        
        // Load the quantum gates. The last argument is the parameters of the gate, e.g., `theta` for Rz(theta, q);
        mlir::Value x_gate = builder.create<isq::ir::UseGateOp>(loc, isq::ir::GateType::get(ctx, 1, GateTrait::General),
            mlir::FlatSymbolRefAttr::get(ctx, "X"), mlir::ValueRange{}).getResult();
        mlir::Value cnot_gate = builder.create<isq::ir::UseGateOp>(loc, isq::ir::GateType::get(ctx, 2, GateTrait::General),
            mlir::FlatSymbolRefAttr::get(ctx, "CNOT"), mlir::ValueRange{}).getResult();
        mlir::Value toffoli_gate = builder.create<isq::ir::UseGateOp>(loc, isq::ir::GateType::get(ctx, 3, GateTrait::General),
            mlir::FlatSymbolRefAttr::get(ctx, "Toffoli"), mlir::ValueRange{}).getResult();

        // Gates application template
        auto apply_x = [&](tweedledum::qubit_id index) {
            int idx = qubit_to_wire[index];
            auto load = builder.create<mlir::memref::LoadOp>(loc, qstate, memref[idx], mlir::ArrayRef<mlir::Value>{offset[idx]});
            isq::ir::ApplyGateOp applied_x = builder.create<isq::ir::ApplyGateOp>(loc, mlir::ArrayRef<mlir::Type>{qstate},
                x_gate, mlir::ArrayRef<mlir::Value>(load));
            builder.create<mlir::memref::StoreOp>(loc, applied_x.getResult(0), memref[idx], mlir::ValueRange{offset[idx]});
        };

        auto apply_cnot = [&](tweedledum::qubit_id cindex, tweedledum::qubit_id tindex) {
            int cidx = qubit_to_wire[cindex];
            int tidx = qubit_to_wire[tindex];
            if (cindex.is_complemented()) apply_x(cidx);
            auto cload = builder.create<mlir::memref::LoadOp>(loc, qstate, memref[cidx], mlir::ValueRange{offset[cidx]});
            auto tload = builder.create<mlir::memref::LoadOp>(loc, qstate, memref[tidx], mlir::ValueRange{offset[tidx]});
            isq::ir::ApplyGateOp applied_cnot = builder.create<isq::ir::ApplyGateOp>(loc, mlir::ArrayRef<mlir::Type>{qstate, qstate},
                cnot_gate, mlir::ArrayRef<mlir::Value>({cload, tload}));
            builder.create<mlir::memref::StoreOp>(loc, applied_cnot.getResult(0), memref[cidx], mlir::ValueRange{offset[cidx]});
            builder.create<mlir::memref::StoreOp>(loc, applied_cnot.getResult(1), memref[tidx], mlir::ValueRange{offset[tidx]});
            if (cindex.is_complemented()) apply_x(cindex);
        };

        auto apply_toffoli = [&](tweedledum::qubit_id cindex_1, tweedledum::qubit_id cindex_2, tweedledum::qubit_id tindex) {
            int cidx1 = qubit_to_wire[cindex_1];
            int cidx2 = qubit_to_wire[cindex_2];
            int tidx = qubit_to_wire[tindex];
            if (cindex_1.is_complemented()) apply_x(cindex_1);
            if (cindex_2.is_complemented()) apply_x(cindex_2);
            auto cload_1 = builder.create<mlir::memref::LoadOp>(loc, qstate, memref[cidx1], mlir::ValueRange{offset[cidx1]});
            auto cload_2 = builder.create<mlir::memref::LoadOp>(loc, qstate, memref[cidx2], mlir::ValueRange{offset[cidx2]});
            auto tload = builder.create<mlir::memref::LoadOp>(loc, qstate, memref[tidx], mlir::ValueRange{offset[tidx]});
            isq::ir::ApplyGateOp applied_toffoli = builder.create<isq::ir::ApplyGateOp>(loc, mlir::ArrayRef<mlir::Type>{qstate, qstate, qstate},
                toffoli_gate, mlir::ArrayRef<mlir::Value>({cload_1, cload_2, tload}));
            builder.create<mlir::memref::StoreOp>(loc, applied_toffoli.getResult(0), memref[cidx1], mlir::ValueRange{offset[cidx1]});
            builder.create<mlir::memref::StoreOp>(loc, applied_toffoli.getResult(1), memref[cidx2], mlir::ValueRange{offset[cidx2]});
            builder.create<mlir::memref::StoreOp>(loc, applied_toffoli.getResult(2), memref[tidx], mlir::ValueRange{offset[tidx]});
            if (cindex_1.is_complemented()) apply_x(cindex_1);
            if (cindex_2.is_complemented()) apply_x(cindex_2);
        };

        // Apply gates to qstates. The last argument is the qstates to be applied on.
        circ.foreach_cgate( [&]( auto n ) {
            if (n.gate.is(tweedledum::gate_set::pauli_x)) {
                //std::cout << "apply x gate on " << n.gate.targets()[0] << std::endl;
                apply_x(n.gate.targets()[0]);
            } else if (n.gate.is(tweedledum::gate_set::cx)) {
                //std::cout << "apply cnot gate from " << n.gate.controls()[0] << " to " << n.gate.targets()[0] << std::endl;
                apply_cnot(n.gate.controls()[0], n.gate.targets()[0]);
            } else if (n.gate.is(tweedledum::gate_set::mcx)) {
                //std::cout << "apply toffoli gate from " << n.gate.controls()[0] << ", " <<  n.gate.controls()[1] << " to " << n.gate.targets()[0] << std::endl;
                apply_toffoli(n.gate.controls()[0], n.gate.controls()[1], n.gate.targets()[0]);
            } else {
                assert(false && "unknown gate");
            }
        } );

        builder.create<mlir::memref::DeallocOp>(loc, ancillas);
        builder.create<mlir::func::ReturnOp>(loc); // dummy terminator
        genDefGateOp(funcop, rewriter, rootModule);

        rewriter.eraseOp(op); // Remove original logic.func op
        return mlir::success();
    }
private:
    // Returns the size of dimension 0 of `val`'s memref type.
    // `val` must have a MemRefType; this is only checked by an assertion.
    int getBitWidth(mlir::Value val) const {
        mlir::Type value_type = val.getType();
        auto memref_type = value_type.dyn_cast<mlir::MemRefType>();
        assert(memref_type && "Returned value is not of MemRefType");
        // Only the leading dimension is consulted.
        return memref_type.getDimSize(0);
    }
};

// Fully unrolls an `isq::ast::ForOp` whose lower bound, upper bound and step
// are all produced by `arith.constant` index operations.
//
// The pattern does not apply (returns failure, IR untouched) when:
//   * the loop is not nested inside a `func.func`,
//   * the enclosing `func.func` carries the TEMPLATE attribute,
//   * any of the three bounds is not a constant index, or
//   * the step is not strictly positive (unrolling would never terminate).
//
// Otherwise, for every iteration value the body operations up to (but not
// including) the first `isq::ast::YieldOp` are cloned in front of the loop,
// with the first region argument (if any) mapped to a fresh constant index,
// and the loop itself is erased.
class UnrollFor : public mlir::OpRewritePattern<isq::ast::ForOp>{
public:
    UnrollFor(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::ForOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::ForOp op, mlir::PatternRewriter& rewriter) const override{
        // Locate the closest enclosing func.func; loops inside templates are
        // left alone.
        auto func = op->getParentOfType<mlir::func::FuncOp>();
        if (!func || func->hasAttr(TEMPLATE)) {
            return mlir::failure();
        }

        // All three loop parameters must be compile-time constants.
        auto lb_op = llvm::dyn_cast_or_null<mlir::arith::ConstantIndexOp>(op.getLb().getDefiningOp());
        auto hb_op = llvm::dyn_cast_or_null<mlir::arith::ConstantIndexOp>(op.getHb().getDefiningOp());
        auto step_op = llvm::dyn_cast_or_null<mlir::arith::ConstantIndexOp>(op.getStep().getDefiningOp());
        if (!lb_op || !hb_op || !step_op) {
            return mlir::failure();
        }

        // Keep the full 64-bit width of the index constants.
        const int64_t lower = lb_op.value();
        const int64_t upper = hb_op.value();
        const int64_t step = step_op.value();
        if (step <= 0) {
            // A zero or negative step would make the unrolling loop below
            // spin forever.
            return mlir::failure();
        }

        rewriter.setInsertionPoint(op);
        auto loc = op.getLoc();
        auto &region = op.getBody();
        // Unroll the loop now that the bounds are known values.
        for (int64_t i = lower; i < upper; i += step) {
            mlir::IRMapping irmap;
            if (region.getNumArguments() > 0) {
                auto idx = rewriter.create<mlir::arith::ConstantIndexOp>(loc, i);
                irmap.map(region.getArgument(0), idx.getResult());
            }
            for (auto &body_op: region.getOps()) {
                if (llvm::isa<isq::ast::YieldOp>(body_op)) {
                    break;
                }
                rewriter.clone(body_op, irmap);
            }
        }
        rewriter.eraseOp(op);
        return mlir::success();
    }
};

// Statically evaluates the integer program formed by `ops` and returns the
// value passed to the first `isq::ast::ReturnOp` that is reached.
//
// ops        - operations to interpret, visited in order.
// rewriter   - not used directly; only forwarded to recursive invocations.
// value_map  - integer value known for each SSA value; extended while
//              evaluating.
// letrec_map - `isq::ast::LetrecOp`s seen so far, keyed by symbol name; used
//              to resolve the callee of an `isq::ast::CallOp`.
//
// Throws the int -1 when the program cannot be evaluated: an operand is
// neither in `value_map` nor an integer `arith.constant`, a division or
// remainder has a zero divisor, zero is raised to a negative power, or a
// power does not fit in `int`. RewriteLet and RewriteLetrec catch this and
// report a failed match.
int processLambdaOps(llvm::iterator_range<mlir::Region::OpIterator> ops, mlir::PatternRewriter& rewriter,
    mlir::DenseMap<mlir::Value, int> &value_map, std::unordered_map<std::string, isq::ast::LetrecOp> &letrec_map){
    // Resolve an SSA value to an int: first from `value_map`, then from an
    // integer arith.constant defining it.
    auto valueToInt = [&](mlir::Value value) -> int {
        auto known = value_map.find(value);
        if (known != value_map.end()) {
            return known->second;
        }
        auto constant = value.getDefiningOp<mlir::arith::ConstantOp>();
        if (!constant) {
            throw -1;
        }
        mlir::IntegerAttr attr = constant.getValue().dyn_cast_or_null<mlir::IntegerAttr>();
        if (!attr) {
            // Non-integer constant: cannot be evaluated here.
            throw -1;
        }
        return attr.getInt();
    };
    // Reject a zero divisor instead of invoking undefined behaviour.
    auto nonZero = [](int divisor) -> int {
        if (divisor == 0) {
            throw -1;
        }
        return divisor;
    };
    // Signed division rounding towards negative infinity (arith.floordivsi).
    auto floorDiv = [](int lhs, int rhs) -> int {
        int quotient = lhs / rhs;
        if ((lhs % rhs != 0) && ((lhs < 0) != (rhs < 0))) {
            --quotient;
        }
        return quotient;
    };
    // Exact integer power. A negative exponent yields 0 unless the base is
    // 1 or -1; 0 raised to a negative power is rejected.
    auto integerPow = [](int base, int exponent) -> int {
        if (exponent == 0 || base == 1) {
            return 1;
        }
        if (base == -1) {
            return (exponent % 2 == 0) ? 1 : -1;
        }
        if (exponent < 0) {
            if (base == 0) {
                throw -1;
            }
            return 0;
        }
        if (base == 0) {
            return 0;
        }
        // |base| >= 2 here, so the loop either finishes or leaves the int
        // range within a few dozen iterations.
        long long acc = 1;
        for (int k = 0; k < exponent; ++k) {
            acc *= base;
            if (acc != static_cast<int>(acc)) {
                throw -1;
            }
        }
        return static_cast<int>(acc);
    };

    for (mlir::Operation &op : ops) {
        if (mlir::arith::ConstantIndexOp con_index = llvm::dyn_cast_or_null<mlir::arith::ConstantIndexOp>(op)) {
            value_map[con_index.getResult()] = con_index.value();
        }
        else if (mlir::arith::ConstantIntOp con = llvm::dyn_cast_or_null<mlir::arith::ConstantIntOp>(op)) {
            value_map[con.getResult()] = con.value();
        }
        else if (auto add = llvm::dyn_cast_or_null<mlir::arith::AddIOp>(op)) {
            const int lhs = valueToInt(add.getLhs());
            const int rhs = valueToInt(add.getRhs());
            value_map[add.getResult()] = lhs + rhs;
        }
        else if (auto sub = llvm::dyn_cast_or_null<mlir::arith::SubIOp>(op)) {
            const int lhs = valueToInt(sub.getLhs());
            const int rhs = valueToInt(sub.getRhs());
            value_map[sub.getResult()] = lhs - rhs;
        }
        else if (auto mul = llvm::dyn_cast_or_null<mlir::arith::MulIOp>(op)) {
            const int lhs = valueToInt(mul.getLhs());
            const int rhs = valueToInt(mul.getRhs());
            value_map[mul.getResult()] = lhs * rhs;
        }
        else if (auto div = llvm::dyn_cast_or_null<mlir::arith::DivSIOp>(op)) {
            const int lhs = valueToInt(div.getLhs());
            const int rhs = nonZero(valueToInt(div.getRhs()));
            value_map[div.getResult()] = lhs / rhs;
        }
        else if (auto fdiv = llvm::dyn_cast_or_null<mlir::arith::FloorDivSIOp>(op)) {
            const int lhs = valueToInt(fdiv.getLhs());
            const int rhs = nonZero(valueToInt(fdiv.getRhs()));
            value_map[fdiv.getResult()] = floorDiv(lhs, rhs);
        }
        else if (auto rem = llvm::dyn_cast_or_null<mlir::arith::RemSIOp>(op)) {
            const int lhs = valueToInt(rem.getLhs());
            const int rhs = nonZero(valueToInt(rem.getRhs()));
            value_map[rem.getResult()] = lhs % rhs;
        }
        else if (auto ipow = llvm::dyn_cast_or_null<mlir::math::IPowIOp>(op)) {
            const int lhs = valueToInt(ipow.getLhs());
            const int rhs = valueToInt(ipow.getRhs());
            value_map[ipow.getResult()] = integerPow(lhs, rhs);
        }
        else if (auto comp = llvm::dyn_cast_or_null<mlir::arith::CmpIOp>(op)) {
            const int lhs = valueToInt(comp.getLhs());
            const int rhs = valueToInt(comp.getRhs());
            mlir::arith::CmpIPredicate predicate = comp.getPredicate();
            // Initialised so the unsupported-predicate path stays defined
            // when assertions are compiled out.
            bool v = false;
            switch (predicate) {
            case mlir::arith::CmpIPredicate::eq:
                v = lhs == rhs;
                break;
            case mlir::arith::CmpIPredicate::ne:
                v = lhs != rhs;
                break;
            case mlir::arith::CmpIPredicate::slt:
                v = lhs < rhs;
                break;
            case mlir::arith::CmpIPredicate::sle:
                v = lhs <= rhs;
                break;
            case mlir::arith::CmpIPredicate::sgt:
                v = lhs > rhs;
                break;
            case mlir::arith::CmpIPredicate::sge:
                v = lhs >= rhs;
                break;
            default:
                comp.dump();
                assert(false && "Unexpected predicate");
            };
            value_map[comp.getResult()] = v;
        }
        else if (auto asrt = llvm::dyn_cast_or_null<AssertOp>(op)) {
            // Evaluate outside of assert() so the condition is resolved the
            // same way in debug and release builds.
            const int cond = valueToInt(asrt.getCond());
            assert(cond);
            (void)cond;
        }
        else if (auto cast = llvm::dyn_cast_or_null<mlir::arith::IndexCastOp>(op)) {
            value_map[cast.getResult()] = valueToInt(cast.getIn());
        }
        else if (auto ret = llvm::dyn_cast_or_null<isq::ast::ReturnOp>(op)) {
            return valueToInt(ret.getArg());
        }
        else if (auto let = llvm::dyn_cast_or_null<isq::ast::LetOp>(op)) {
            // The recursive call may grow `value_map`; finish it before
            // taking a reference into the map.
            const int body_value = processLambdaOps(let.getBody().getOps(), rewriter, value_map, letrec_map);
            value_map[let.getResult()] = body_value;
        }
        else if (auto letrec = llvm::dyn_cast_or_null<isq::ast::LetrecOp>(op)) {
            letrec_map[letrec.getSymNameAttr().str()] = letrec;
            const int eval_value = processLambdaOps(letrec.getEval().getOps(), rewriter, value_map, letrec_map);
            value_map[letrec.getResult()] = eval_value;
        }
        else if (llvm::isa<isq::ast::FuncOp>(op)) {
            // Function references carry no integer value; they are resolved
            // when the corresponding CallOp is evaluated.
            continue;
        }
        else if (auto call = llvm::dyn_cast_or_null<isq::ast::CallOp>(op)) {
            auto func = llvm::dyn_cast_or_null<isq::ast::FuncOp>(call.getCallee().getDefiningOp());
            assert(func && "cannot find defining FuncOp");
            auto name = func.getValue().str();
            assert(letrec_map.contains(name) && "cannot find defining LetrecOp");
            auto callee = letrec_map[name];
            auto pars = call.getArgs();
            auto &init = letrec.getInit();
            auto nargs = init.getNumArguments();
            // The callee body is evaluated in a fresh environment that only
            // binds its formal arguments.
            mlir::DenseMap<mlir::Value, int> sub_map;
            assert(pars.size() == nargs && "Wrong parameter number");
            for (unsigned i = 0; i < nargs; i++) {
                sub_map[init.getArgument(i)] = valueToInt(pars[i]);
            }
            const int call_value = processLambdaOps(init.getOps(), rewriter, sub_map, letrec_map);
            value_map[call.getResult(0)] = call_value;
        }
        else if (auto ifop = llvm::dyn_cast_or_null<isq::ast::IfOp>(op)) {
            // Only the taken branch is evaluated.
            auto &branch = valueToInt(ifop.getCond()) ? ifop.getIf() : ifop.getElse();
            const int branch_value = processLambdaOps(branch.getOps(), rewriter, value_map, letrec_map);
            value_map[ifop.getResult()] = branch_value;
        }
        else {
            op.dump();
            assert(false && "Unrecogonized operation!");
        }
    }
    assert(false && "Should return in an isq::ast::ReturnOp");
    return -1;
}

// Folds an outermost `isq::ast::LetOp` into an `arith.constant` index by
// evaluating its body with processLambdaOps.
//
// A let nested directly inside another LetOp/LetrecOp is skipped; it is
// evaluated as part of its parent. If the body cannot be evaluated
// (processLambdaOps throws), the pattern fails and leaves the IR unchanged.
class RewriteLet : public mlir::OpRewritePattern<isq::ast::LetOp>{
public:
    RewriteLet(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::LetOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::LetOp let, mlir::PatternRewriter& rewriter) const override{
        auto par = let->getParentOp();
        if (llvm::isa<isq::ast::LetOp, isq::ast::LetrecOp>(par)) {
            return mlir::failure();
        }
        mlir::DenseMap<mlir::Value, int> value_map;
        std::unordered_map<std::string, isq::ast::LetrecOp> letrec_map;
        int ret = 0;
        try {
            ret = processLambdaOps(let.getBody().getOps(), rewriter, value_map, letrec_map);
        }
        catch (...) {
            // processLambdaOps signals "not statically evaluable" by throwing.
            return mlir::failure();
        }
        rewriter.setInsertionPoint(let);
        auto con = rewriter.create<mlir::arith::ConstantIndexOp>(let.getLoc(), ret);
        // Replace through the rewriter so the driver is notified about the
        // updated users and the erased op.
        rewriter.replaceOp(let, con.getResult());
        return mlir::success();
    }
};

// Folds an outermost `isq::ast::LetrecOp` into an `arith.constant` index by
// evaluating its `eval` region with processLambdaOps. The letrec itself is
// registered under its symbol name so calls inside `eval` can resolve it.
//
// A letrec nested directly inside another LetOp/LetrecOp is skipped; it is
// evaluated as part of its parent. If evaluation is not possible
// (processLambdaOps throws), the pattern fails and leaves the IR unchanged.
class RewriteLetrec : public mlir::OpRewritePattern<isq::ast::LetrecOp>{
public:
    RewriteLetrec(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::LetrecOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::LetrecOp letrec, mlir::PatternRewriter& rewriter) const override{
        if (llvm::isa<isq::ast::LetOp, isq::ast::LetrecOp>(letrec->getParentOp())) {
            return mlir::failure();
        }
        mlir::DenseMap<mlir::Value, int> value_map;
        std::unordered_map<std::string, isq::ast::LetrecOp> letrec_map;
        letrec_map[letrec.getSymNameAttr().str()] = letrec;
        int ret = 0;
        try {
            ret = processLambdaOps(letrec.getEval().getOps(), rewriter, value_map, letrec_map);
        }
        catch (...) {
            // processLambdaOps signals "not statically evaluable" by throwing.
            return mlir::failure();
        }
        rewriter.setInsertionPoint(letrec);
        auto con = rewriter.create<mlir::arith::ConstantIndexOp>(letrec.getLoc(), ret);
        // Replace through the rewriter so the driver is notified about the
        // updated users and the erased op.
        rewriter.replaceOp(letrec, con.getResult());
        return mlir::success();
    }
};

// Lowers an `isq::ast::ForOp` to an `scf.for` with the same lower bound,
// upper bound and step and no iteration arguments.
//
// The body operations up to (but not including) the first
// `isq::ast::YieldOp` are cloned into the new loop, with the first region
// argument (if any) mapped to the scf.for induction variable, and the new
// body is terminated by an `scf.yield`. The original loop is erased.
class RewriteFor : public mlir::OpRewritePattern<isq::ast::ForOp>{
public:
    RewriteFor(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::ForOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::ForOp op, mlir::PatternRewriter& rewriter) const override{
        auto &region = op.getBody();
        // Body builder invoked by scf::ForOp::build with the new induction
        // variable; the iteration-argument range is unused.
        auto buildFor = [&](mlir::OpBuilder &builder, mlir::Location body_loc, mlir::Value iter, mlir::ValueRange){
            mlir::IRMapping irmap;
            // Same guard as UnrollFor: the body may not declare an argument.
            if (region.getNumArguments() > 0) {
                irmap.map(region.getArgument(0), iter);
            }
            for (auto &body_op: region.getOps()) {
                if (llvm::isa<isq::ast::YieldOp>(body_op)) {
                    break;
                }
                builder.clone(body_op, irmap);
            }
            builder.create<mlir::scf::YieldOp>(body_loc);
        };
        rewriter.setInsertionPoint(op);
        rewriter.create<mlir::scf::ForOp>(op.getLoc(), op.getLb(), op.getHb(), op.getStep(), std::nullopt, buildFor);
        rewriter.eraseOp(op);
        // The IR was modified, so the pattern must report success.
        return mlir::success();
    }
};

// Erases every `func.func` that carries the TEMPLATE attribute.
// Functions without the attribute are left untouched (match failure).
class RemoveTemplates : public mlir::OpRewritePattern<mlir::func::FuncOp>{
public:
    RemoveTemplates(mlir::MLIRContext *ctx): mlir::OpRewritePattern<mlir::func::FuncOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp op, mlir::PatternRewriter& rewriter) const override{
        if (!op->hasAttr(TEMPLATE)) {
            return mlir::failure();
        }
        rewriter.eraseOp(op);
        // The IR was modified, so the pattern must report success.
        return mlir::success();
    }
};

// Replaces an `isq::ast::FuncOp` by a `func.constant` with the same result
// type and the same value attribute.
class RewriteFunc : public mlir::OpRewritePattern<isq::ast::FuncOp>{
public:
    RewriteFunc(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::FuncOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::FuncOp op, mlir::PatternRewriter& rewriter) const override{
        rewriter.setInsertionPoint(op);
        auto con = rewriter.create<mlir::func::ConstantOp>(op.getLoc(), op.getType(), op.getValueAttr());
        // Replace through the rewriter so the driver tracks the change, and
        // report success because the IR was modified.
        rewriter.replaceOp(op, con.getResult());
        return mlir::success();
    }
};

// Replaces an `isq::ast::DecorateOp` by a `DecorateOp` with the same result
// type, adjoint and ctrl attributes. The type of the gate operand is first
// rewritten in place with ReplaceGateType.
class UpdateDecorate : public mlir::OpRewritePattern<isq::ast::DecorateOp>{
public:
    UpdateDecorate(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::DecorateOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::DecorateOp op, mlir::PatternRewriter& rewriter) const override{
        auto gate = op.getGate();
        // NOTE(review): this mutates the operand's type directly rather than
        // through the rewriter; kept as in the original.
        gate.setType(ReplaceGateType(rewriter.getContext(), gate.getType()));
        rewriter.setInsertionPoint(op);
        auto dec = rewriter.create<DecorateOp>(op.getLoc(), op.getType(), gate, op.getAdjointAttr(), op.getCtrlAttr());
        // Replace through the rewriter and report success: the IR changed.
        rewriter.replaceOp(op, dec.getResult());
        return mlir::success();
    }
};

// Replaces an `isq::ast::DefgateOp` by a `DefgateOp` that keeps the symbol
// name, definition attribute and parameters, uses the gate type produced by
// ReplaceGateType, and is created with the string attribute "nested" and an
// empty array attribute for the remaining two arguments.
class UpdateDefGate : public mlir::OpRewritePattern<isq::ast::DefgateOp>{
public:
    UpdateDefGate(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::DefgateOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::DefgateOp op, mlir::PatternRewriter& rewriter) const override{
        auto ctx = rewriter.getContext();
        auto ty = mlir::TypeAttr::get(ReplaceGateType(ctx, op.getType()));
        rewriter.setInsertionPoint(op);
        rewriter.create<DefgateOp>(op.getLoc(), ty, op.getSymNameAttr(), mlir::StringAttr::get(ctx, "nested"), mlir::ArrayAttr{}, op.getDefinitionAttr(), op.getParameters());
        rewriter.eraseOp(op);
        // The IR was modified, so the pattern must report success.
        return mlir::success();
    }
};

// Replaces an `isq::ast::ApplyGPhase` by an `ApplyGPhase` on the same gate
// operand. The type of the gate operand is first rewritten in place with
// ReplaceGateType.
class UpdateApplyGPhase : public mlir::OpRewritePattern<isq::ast::ApplyGPhase>{
public:
    UpdateApplyGPhase(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::ApplyGPhase>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::ApplyGPhase op, mlir::PatternRewriter& rewriter) const override{
        mlir::Value gate = op.getGate();
        // NOTE(review): this mutates the operand's type directly rather than
        // through the rewriter; kept as in the original.
        gate.setType(ReplaceGateType(rewriter.getContext(), gate.getType()));
        // Made explicit for consistency with the sibling Update* patterns.
        rewriter.setInsertionPoint(op);
        rewriter.create<ApplyGPhase>(op.getLoc(), gate);
        rewriter.eraseOp(op);
        // The IR was modified, so the pattern must report success.
        return mlir::success();
    }
};

// Replaces an `isq::ast::ApplyGateOp` by an `ApplyGateOp` with the same
// argument operands and result types. The type of the gate operand is first
// rewritten in place with ReplaceGateType.
class UpdateApplyGate : public mlir::OpRewritePattern<isq::ast::ApplyGateOp>{
public:
    UpdateApplyGate(mlir::MLIRContext *ctx): mlir::OpRewritePattern<isq::ast::ApplyGateOp>(ctx, 1){
    }
    mlir::LogicalResult matchAndRewrite(isq::ast::ApplyGateOp op, mlir::PatternRewriter& rewriter) const override{
        mlir::Value gate = op.getGate();
        // NOTE(review): this mutates the operand's type directly rather than
        // through the rewriter; kept as in the original.
        gate.setType(ReplaceGateType(rewriter.getContext(), gate.getType()));
        // The new op produces results of exactly the old result types.
        mlir::SmallVector<mlir::Type> ty;
        for (mlir::OpResult r : op.getR()) {
            ty.push_back(r.getType());
        }
        // Made explicit for consistency with the sibling Update* patterns.
        rewriter.setInsertionPoint(op);
        auto apply = rewriter.create<ApplyGateOp>(op.getLoc(), ty, gate, op.getArgs());
        // Replace every result through the rewriter and report success: the
        // IR changed.
        rewriter.replaceOp(op, apply->getResults());
        return mlir::success();
    }
};

// Module pass that runs five greedy rewrite stages in sequence. Every stage
// uses the same GreedyRewriteConfig (at most 20 driver iterations) and the
// result of each stage is deliberately ignored.
struct InstantiateTemplatePass : public mlir::PassWrapper<InstantiateTemplatePass, mlir::OperationPass<mlir::ModuleOp>>{
    void runOnOperation() override{
        mlir::ModuleOp m = this->getOperation();
        auto ctx = m->getContext();
        mlir::GreedyRewriteConfig config;
        // An unlimited iteration count (kNoLimit) was tried previously; the
        // driver is capped instead.
        config.maxIterations = 20;

        // Freezes a pattern set and applies it greedily to the whole module.
        auto runStage = [&](mlir::RewritePatternSet &&patterns) {
            mlir::FrozenRewritePatternSet frozen(std::move(patterns));
            (void)mlir::applyPatternsAndFoldGreedily(m.getOperation(), frozen, config);
        };

        // Stage 1: PrepareInstantiated.
        {
            mlir::RewritePatternSet stage(ctx);
            stage.add<PrepareInstantiated>(ctx, m);
            runStage(std::move(stage));
        }
        // Stage 2: unrolling, let/letrec folding, logic and template
        // rewriting, together with arith/math canonicalization patterns.
        {
            mlir::RewritePatternSet stage(ctx);
            stage.add<UnrollFor>(ctx);
            stage.add<RewriteLet>(ctx);
            stage.add<RewriteLetrec>(ctx);
            stage.add<RewriteLogic>(ctx, m);
            mlir::arith::AddIOp::getCanonicalizationPatterns(stage, ctx);
            mlir::arith::SubIOp::getCanonicalizationPatterns(stage, ctx);
            mlir::arith::MulIOp::getCanonicalizationPatterns(stage, ctx);
            mlir::arith::DivSIOp::getCanonicalizationPatterns(stage, ctx);
            mlir::arith::CeilDivSIOp::getCanonicalizationPatterns(stage, ctx);
            mlir::math::IPowIOp::getCanonicalizationPatterns(stage, ctx);
            stage.add<RewriteTempl>(ctx, m);
            runStage(std::move(stage));
        }
        // Stage 3: RemoveTemplates.
        {
            mlir::RewritePatternSet stage(ctx);
            stage.add<RemoveTemplates>(ctx);
            runStage(std::move(stage));
        }
        // Stage 4: RewriteCall and RewriteFor.
        {
            mlir::RewritePatternSet stage(ctx);
            stage.add<RewriteCall>(ctx, m);
            stage.add<RewriteFor>(ctx);
            runStage(std::move(stage));
        }
        // Stage 5: RewriteFunc and the Update* patterns.
        {
            mlir::RewritePatternSet stage(ctx);
            stage.add<RewriteFunc>(ctx);
            stage.add<UpdateDefGate>(ctx);
            stage.add<UpdateDecorate>(ctx);
            stage.add<UpdateApplyGate>(ctx);
            stage.add<UpdateApplyGPhase>(ctx);
            runStage(std::move(stage));
        }
    }
    // Command-line name of the pass.
    mlir::StringRef getArgument() const final{
        return "instantiate-template";
    }
    // One-line description of the pass.
    mlir::StringRef getDescription() const final{
        return "Instantiate template function with given arguments.";
    }
};

void registerInstantiateTemplate(){
    mlir::PassRegistration<InstantiateTemplatePass>();
}

}
}
}