#include "isq/GateDefTypes.h"
#include "isq/Operations.h"
#include "isq/QTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "llvm/ADT/APFloat.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "isq/passes/Passes.h"
#include "isq/oracle/QM.h"
#include <iostream>

namespace isq{
namespace ir{
namespace passes{

using namespace qm;

/// Rewrite pattern that gives permutation-defined gates a circuit-level
/// decomposition.
///
/// For each `DefgateOp` whose definition list contains a
/// `PermutationDefinition`, the pattern synthesizes a function named
/// `<gate symbol>__ISQ_GATEDEF` at the start of the root module and appends a
/// `"decomposition"` gate definition referring to that function. If a function
/// with that name can already be found, the gate is left untouched.
class PermutationDef: public mlir::OpRewritePattern<DefgateOp>{
    /// Module that receives the generated decomposition functions.
    mlir::ModuleOp rootModule;

public:
    /// @param ctx    Context the pattern is registered in.
    /// @param module Module at whose start new functions are inserted.
    PermutationDef(mlir::MLIRContext* ctx, mlir::ModuleOp module): mlir::OpRewritePattern<DefgateOp>(ctx, 1), rootModule(module){
    }

    /// Builds the function `decomposed_name` implementing the permutation
    /// table `value` with (multi-)controlled X gates.
    ///
    /// The function takes and returns `n` qubit states, where
    /// `value.size() == 2^n`.
    ///
    /// @param rewriter        Rewriter used to create all operations; its
    ///                        insertion point is restored on return.
    /// @param decomposed_name Symbol name of the function to create.
    /// @param value           Permutation table.
    /// @return failure (without creating any IR) if the table size is not a
    ///         power of two or describes fewer than one qubit; success
    ///         otherwise.
    mlir::LogicalResult decomposePermutation(mlir::PatternRewriter& rewriter, ::mlir::StringRef decomposed_name, const std::vector<int>& value) const{
        // Validate before touching the IR so a failure leaves it unchanged.
        // Integer arithmetic avoids the floating-point log2 and its undefined
        // conversion for an empty table.
        const auto entries = value.size();
        if (entries < 2 || (entries & (entries - 1)) != 0){
            return mlir::failure();
        }
        int n = 0;
        while ((entries >> (n + 1)) != 0){
            n++;
        }

        auto ctx = rewriter.getContext();
        const mlir::Location loc = ::mlir::UnknownLoc::get(ctx);
        auto qst = QStateType::get(ctx);
        // Local copy of the module handle: this method is const.
        auto module = this->rootModule;

        mlir::PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(module.getBody());

        // Signature: n qubit states in, n qubit states out.
        mlir::SmallVector<mlir::Type> qs;
        for (int i = 0; i < n; i++){
            qs.push_back(qst);
        }
        auto funcop = mlir::func::FuncOp::create(loc, decomposed_name, mlir::FunctionType::get(ctx, qs, qs));
        //funcop.setSymVisibilityAttr(mlir::StringAttr::get(ctx, "private"));
        rewriter.insert(funcop.getOperation());
        auto entry_block = funcop.addEntryBlock();
        rewriter.setInsertionPointToStart(entry_block);

        // Current SSA value of every qubit; updated after each applied gate.
        mlir::SmallVector<mlir::Value> qubits;
        qubits.append(entry_block->args_begin(), entry_block->args_end());

        // Emits a reference to the single-qubit gate symbol "X".
        auto createXGate = [&]{
            return rewriter.create<isq::ir::UseGateOp>(
                loc,
                isq::ir::GateType::get(ctx, 1, GateTrait::General),
                ::mlir::FlatSymbolRefAttr::get(ctx, "X"),
                ::mlir::ValueRange{}
            );
        };

        if (n == 1){
            // Single qubit: the table is either the identity or a bit flip.
            if (value[0] == 1){
                auto use = createXGate();
                auto apply = rewriter.create<ApplyGateOp>(
                    loc,
                    ::mlir::ArrayRef<::mlir::Type>{qst},
                    use.getResult(),
                    ::mlir::ArrayRef{qubits[0]}
                );
                qubits[0]=apply.getResult(0);
            }
        }else{
            auto d = do_permutation(value, n);
            // The target qubit sweeps 0, 1, ..., n-1 and then back down, one
            // step per entry of `d`.
            int target = 0;
            int dt = 1;
            for (auto v : d){
                if (v.size() > 0){
                    // Qubits other than the target, in ascending order; they
                    // are indexed by the positions of the simplified terms.
                    std::vector<int> pos;
                    for (int i = 0; i < n; i++){
                        if (i == target) continue;
                        pos.push_back(i);
                    }
                    auto myqm = QM(n-1);
                    auto vnodes = myqm.simplify(v);
                    auto opt = myqm.optimize(vnodes);
                    for (auto bit: opt){
                        // Collect control qubits: '-' means the position is
                        // not a control, '1' a positive control, anything
                        // else a negative control.
                        std::vector<int> qidx;
                        mlir::SmallVector<mlir::Attribute> ctrls;
                        for (int i = 0; i < n-1; i++){
                            if (bit[i] == '-') continue;
                            qidx.push_back(i);
                            ctrls.push_back(mlir::BoolAttr::get(ctx, bit[i] == '1'));
                        }
                        int num = qidx.size();

                        auto use = createXGate();
                        mlir::Value res = use.getResult();
                        // Wrap X in a decoration when there are controls.
                        if (num > 0){
                            auto decorate = rewriter.create<isq::ir::DecorateOp>(
                                loc,
                                isq::ir::GateType::get(ctx, 1+num, GateTrait::General),
                                use.getResult(),
                                false,
                                mlir::ArrayAttr::get(ctx, mlir::ArrayRef<mlir::Attribute>(ctrls))
                            );
                            res = decorate.getResult();
                        }

                        // Operands: controls first, target last.
                        mlir::SmallVector<mlir::Value> params;
                        mlir::SmallVector<mlir::Type> types;
                        for (auto q: qidx){
                            params.push_back(qubits[pos[q]]);
                            types.push_back(qst);
                        }
                        params.push_back(qubits[target]);
                        types.push_back(qst);

                        auto apply = rewriter.create<isq::ir::ApplyGateOp>(
                            loc,
                            mlir::ArrayRef<mlir::Type>(types),
                            res,
                            mlir::ArrayRef<mlir::Value>(params)
                        );

                        // Record the new qubit values in operand order.
                        int ridx = 0;
                        for (auto q: qidx){
                            qubits[pos[q]] = apply.getResult(ridx);
                            ridx += 1;
                        }
                        qubits[target] = apply.getResult(ridx);
                    }
                }

                if (target == n-1) dt = -1;
                target += dt;
            }
        }

        rewriter.create<mlir::func::ReturnOp>(loc, qubits);
        return mlir::success();
    }

    /// Adds a `"decomposition"` definition to `defgate` for each permutation
    /// definition that does not yet have a generated function.
    ///
    /// @return success only if the IR was actually modified; failure when the
    ///         gate has no definitions, nothing needed to be generated, or a
    ///         definition could not be parsed before any change was made.
    mlir::LogicalResult matchAndRewrite(isq::ir::DefgateOp defgate,  mlir::PatternRewriter &rewriter) const override{
        if (!defgate.getDefinition()) return mlir::failure();
        //if (defgate.getDefinition()->size() != 2) return mlir::failure();
        auto ctx = rewriter.getContext();
        // A pattern must report success exactly when it changed the IR;
        // otherwise the greedy driver keeps re-running it.
        bool changed = false;
        // NOTE(review): `id` is never advanced, so every definition is parsed
        // with index 0 - confirm whether parseGateDefinition expects the
        // position of `def` in the list.
        int id = 0;
        for (auto def: defgate.getDefinition()->getAsRange<GateDefinition>()){
            auto d = AllGateDefs::parseGateDefinition(defgate, id, defgate.getType(), def);
            if (d == std::nullopt) return mlir::success(changed);
            auto per = llvm::dyn_cast_or_null<PermutationDefinition>(&**d);
            if (!per) continue;

            // Name and symbol of the generated decomposition function.
            auto per_decomp_name = std::string(defgate.getSymName())+"__ISQ_GATEDEF";
            auto per_decomp_sym = mlir::FlatSymbolRefAttr::get(mlir::StringAttr::get(ctx, per_decomp_name));
            auto per_decomp = mlir::SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(defgate, per_decomp_sym);
            if (per_decomp) continue; // Already generated.

            auto& per_data = per->getValue();
            if(mlir::failed(decomposePermutation(rewriter, per_decomp_name, per_data))){
                return mlir::success(changed);
            }
            rewriter.updateRootInPlace(defgate, [&]{
                auto defs = *defgate.getDefinition();
                ::mlir::SmallVector<::mlir::Attribute> new_defs;
                auto r = defs.getAsRange<::mlir::Attribute>();
                new_defs.append(r.begin(), r.end());
                new_defs.push_back(GateDefinition::get(
                    ctx,
                    ::mlir::StringAttr::get(ctx, "decomposition"),
                    per_decomp_sym
                ));
                defgate->setAttr("definition", ::mlir::ArrayAttr::get(ctx, new_defs));
            });
            changed = true;
        }
        return mlir::success(changed);
    }
};


/// Module pass that runs the `PermutationDef` pattern over the whole module
/// with the greedy pattern rewrite driver.
struct PermutationDecomposePass: public mlir::PassWrapper<PermutationDecomposePass, mlir::OperationPass<mlir::ModuleOp>>{

    /// Returns the argument string identifying this pass.
    mlir::StringRef getArgument() const final{
        return "isq-permutation-decompose";
    }

    /// Returns the one-line description of this pass.
    mlir::StringRef getDescription() const final{
        return "Using Young Subgraph to synthesize circuit for permutation.";
    }

    /// Applies `PermutationDef` to the module this pass is running on.
    void runOnOperation() override{
        mlir::ModuleOp module = this->getOperation();
        mlir::MLIRContext* context = module->getContext();

        // The pattern needs the module itself so it can insert new functions.
        mlir::RewritePatternSet patterns(context);
        patterns.add<PermutationDef>(context, module);
        mlir::FrozenRewritePatternSet frozen(std::move(patterns));

        // The driver's result is intentionally discarded.
        (void)mlir::applyPatternsAndFoldGreedily(module.getOperation(), frozen);
    }
};

void registerPermutationDecompose(){
    mlir::PassRegistration<PermutationDecomposePass>();
}

}

}
}
