#include <iostream>
#include <unordered_set>
#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Value.h>
#include <mlir/Pass/PassRegistry.h>
#include <mlir/Support/LLVM.h>
#include "isq/Operations.h"
#include "isq/OpVerifier.h"
#include "isq/QTypes.h"
#include "isq/passes/Passes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "isq/GateDefTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "isq/passes/Mem2Reg.h"
namespace isq{
namespace ir{
namespace passes{

// Name of the unit attribute that InlineDeriving attaches to a DefgateOp once
// it has visited it, and that InlineDerivingCleanup removes again afterwards.
const char* INLINED = "inlined";

/*
* Inline all derived gates.
*
* For a DefgateOp that has a decomposition definition, every ApplyGateOp inside
* the decomposed function whose gate operand is produced by a UseGateOp naming
* another decomposed gate is replaced by a copy of that gate's decomposed
* function body. Referenced gates are processed first (recursively), so the
* copied body has itself already been inlined.
*
* Every visited DefgateOp is tagged with the INLINED unit attribute so it is
* processed at most once; this also bounds the recursion. The symbol name of
* every gate that has a decomposition is added to `pureSymbols`.
*/
class InlineDeriving : public mlir::OpRewritePattern<DefgateOp> {
    // Caller-owned set receiving the names of gates that have a decomposition.
    // Held by reference, so it must outlive this pattern.
    std::unordered_set<std::string> &pureSymbols;

    // Returns the decomposed function of the first definition of `gate` that is
    // a DecompositionDefinition, or a null FuncOp when the gate has no
    // definition list or none of its definitions is a decomposition.
    static mlir::func::FuncOp findDecomposition(DefgateOp gate) {
        // A gate may come without any definition; dereferencing the optional
        // in that case would be undefined behaviour.
        if (!gate.getDefinition()) {
            return nullptr;
        }
        auto defs = *gate.getDefinition();
        auto id = 0;
        for (auto def: defs.getAsRange<GateDefinition>()) {
            auto d = AllGateDefs::parseGateDefinition(gate, id, gate.getType(), def);
            if (d == std::nullopt) {
                llvm_unreachable("bad");
            }
            if (auto decomp = llvm::dyn_cast_or_null<DecompositionDefinition>(&**d)) {
                return decomp->getDecomposedFunc();
            }
            id++;
        }
        return nullptr;
    }

public:
    // `ctx` is the context the pattern is registered in; `pureSymbols` is the
    // output set described on the member above. Pattern benefit is fixed to 1.
    InlineDeriving(mlir::MLIRContext *ctx, std::unordered_set<std::string> &pureSymbols): mlir::OpRewritePattern<DefgateOp>(ctx, 1), pureSymbols(pureSymbols) {
    }

    // Returns failure() only when `op` was already tagged INLINED (nothing is
    // touched in that case); returns success() whenever the IR was modified,
    // including the case where only the INLINED tag was added.
    mlir::LogicalResult matchAndRewrite(DefgateOp op, mlir::PatternRewriter& rewriter) const override {
        if (op->hasAttr(INLINED)) {
            return mlir::failure();
        }

        // Tag before doing anything else so that recursive calls reaching this
        // gate again (mutually referencing gates) stop immediately.
        auto ctx = op->getContext();
        rewriter.startRootUpdate(op);
        op->setAttr(mlir::StringAttr::get(ctx, INLINED), mlir::UnitAttr::get(ctx));
        rewriter.finalizeRootUpdate(op);

        mlir::func::FuncOp func = findDecomposition(op);
        if (!func) {
            // Nothing to inline, but the attribute was added above: a pattern
            // must not report failure after it has modified the IR.
            return mlir::success();
        }
        pureSymbols.insert(op.getSymName().str());

        rewriter.startRootUpdate(func);
        auto ops = func.getOps();
        auto it = ops.begin();
        while (it != ops.end()) {
            // Advance before touching the op: it may be erased below.
            auto &candidate = *it++;
            auto apply = llvm::dyn_cast_or_null<ApplyGateOp>(candidate);
            if (!apply) {
                continue;
            }
            UseGateOp use = llvm::dyn_cast_or_null<UseGateOp>(apply.getGate().getDefiningOp());
            if (!use) {
                continue;
            }
            auto defgate = mlir::SymbolTable::lookupNearestSymbolFrom<DefgateOp>(use.getOperation(), use.getName());
            if (!defgate || !defgate.getDefinition()) {
                continue;
            }

            // Recursively inline the callee first. The result is irrelevant
            // here: failure just means the callee was already visited.
            (void)matchAndRewrite(defgate, rewriter);
            mlir::func::FuncOp sub = findDecomposition(defgate);
            if (!sub) {
                continue;
            }
            // A gate applying itself cannot be expanded: cloning the body into
            // the very range being iterated would never terminate.
            if (sub.getOperation() == func.getOperation()) {
                continue;
            }

            // Map values of `sub` to values of `func`. The arguments of `sub`
            // are bound to the use-site parameters first, followed by the
            // apply-site arguments.
            mlir::DenseMap<mlir::Value, mlir::Value> map;
            auto pars = use.getParameters();
            const unsigned par_size = pars.size();
            for (unsigned i = 0; i < par_size; i++) {
                map[sub.getArgument(i)] = pars[i];
            }
            auto args = apply.getArgs();
            const unsigned arg_size = args.size();
            for (unsigned i = 0; i < arg_size; i++) {
                map[sub.getArgument(par_size + i)] = args[i];
            }

            rewriter.setInsertionPoint(apply);
            for (auto &bodyOp : sub.getOps()) {
                if (auto ret = llvm::dyn_cast_or_null<mlir::func::ReturnOp>(bodyOp)) {
                    // The returned values take the place of the apply results.
                    const unsigned nres = apply.getNumResults();
                    assert(nres == ret.getNumOperands());
                    for (unsigned i = 0; i < nres; i++) {
                        // lookup() does not insert a null entry on a miss.
                        mlir::Value replacement = map.lookup(ret.getOperand(i));
                        assert(replacement && "returned value has no mapping");
                        apply.getResult(i).replaceAllUsesWith(replacement);
                    }
                } else {
                    auto new_op = bodyOp.clone();
                    // Redirect every reference to a value of `sub` -- on the
                    // clone itself and on ops nested in its regions -- to the
                    // corresponding value in `func`. Values created inside the
                    // clone are not keys of `map` and are left alone.
                    new_op->walk([&](mlir::Operation *nested) {
                        for (mlir::OpOperand &operand : nested->getOpOperands()) {
                            auto mapped = map.find(operand.get());
                            if (mapped != map.end()) {
                                operand.set(mapped->getSecond());
                            } else {
                                // Operands of the clone itself must always
                                // come from `sub` and hence be mapped.
                                assert(nested != new_op && "operand of inlined op has no mapping");
                            }
                        }
                    });
                    rewriter.insert(new_op);
                    const unsigned num_results = bodyOp.getNumResults();
                    for (unsigned i = 0; i < num_results; i++) {
                        map[bodyOp.getResult(i)] = new_op->getResult(i);
                    }
                }
            }
            rewriter.eraseOp(apply);
        }
        rewriter.finalizeRootUpdate(func);
        return mlir::success();
    }
};

// Remove the INLINED attribute.
//
// Pattern that strips the INLINED marker from every DefgateOp that carries it.
// InlineDerivingPass runs it after InlineDeriving so the marker does not stay
// in the module.
class InlineDerivingCleanup : public mlir::OpRewritePattern<DefgateOp>{
public:
    // Registers the pattern in `ctx` with benefit 1.
    InlineDerivingCleanup(mlir::MLIRContext* ctx): mlir::OpRewritePattern<DefgateOp>(ctx, 1) {
    }

    // Succeeds after removing the INLINED attribute from `op`; fails without
    // touching the op when the attribute is absent.
    mlir::LogicalResult matchAndRewrite(DefgateOp op, mlir::PatternRewriter& rewriter) const override {
        // Check for INLINED notation.
        if (!op->hasAttr(INLINED)) {
            return mlir::failure();
        }
        // Wrap the in-place change so the rewriter is notified of it.
        rewriter.startRootUpdate(op);
        op->removeAttr(INLINED);
        rewriter.finalizeRootUpdate(op);
        return mlir::success();
    }
};

// Module-level pass that first applies InlineDeriving to the whole module and
// then applies InlineDerivingCleanup to drop the INLINED markers again.
struct InlineDerivingPass : public mlir::PassWrapper<InlineDerivingPass, mlir::OperationPass<mlir::ModuleOp>>{
    void runOnOperation() override{
        mlir::ModuleOp module = this->getOperation();
        auto context = module->getContext();
        // Filled by InlineDeriving with the names of gates that have a decomposition.
        std::unordered_set<std::string> pureFuncs;

        // Stage 1: inline the decomposed gates. The driver's result is ignored.
        {
            mlir::RewritePatternSet inlinePatterns(context);
            inlinePatterns.add<InlineDeriving>(context, pureFuncs);
            mlir::FrozenRewritePatternSet frozenInline(std::move(inlinePatterns));
            (void)mlir::applyPatternsAndFoldGreedily(module.getOperation(), frozenInline);
        }

        // Stage 2: remove the bookkeeping attribute left behind by stage 1.
        {
            mlir::RewritePatternSet cleanupPatterns(context);
            cleanupPatterns.add<InlineDerivingCleanup>(context);
            mlir::FrozenRewritePatternSet frozenCleanup(std::move(cleanupPatterns));
            (void)mlir::applyPatternsAndFoldGreedily(module.getOperation(), frozenCleanup);
        }
    }

    // Command-line name of the pass.
    mlir::StringRef getArgument() const final{
        return "isq-inline";
    }

    // One-line description of the pass.
    mlir::StringRef getDescription() const final{
        return "Inline the derived gates.";
    }
};

void registerInlineDerivingPass(){
    mlir::PassRegistration<InlineDerivingPass>();
}

}
}
}