#include <iostream>
#include <unordered_set>
#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Value.h>
#include <mlir/Pass/PassRegistry.h>
#include <mlir/Support/LLVM.h>
#include "isq/Operations.h"
#include "isq/OpVerifier.h"
#include "isq/QTypes.h"
#include "isq/passes/Passes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "isq/GateDefTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "isq/passes/Mem2Reg.h"
namespace isq{
namespace ir{
namespace passes{

// Name of the unit attribute used to mark a DefgateOp that GenerateDerivingInverse
// has already visited; DerivedGatedefCleanup strips it again afterwards.
// NOTE(review): this is a mutable pointer with external linkage — consider
// `const char* const` if no other translation unit declares it `extern`; confirm.
const char* DERIVED_GATE = "derived_gate";
// Suffix appended to a gate / function symbol name to form the name of its inverse.
const char* INVERSE_SUFFIX = "__inv";

/*
* Generate the inverse version of derived gates.
*
* This rewrite pattern assumes the input has the following pattern:
*   %arg = calculate angle
*   %gate = isq.use "Ry"(%arg)
*   %load = memref.load %alloc[0]
*   %store = isq.apply %gate, %load
*   memref.store %store, %alloc[0]
*   ...
*   isq.accumulate_gphase %alloc
*   memref.dealloc %alloc
*   return
*
* The operations above the UseOp will be copied to the beginning of the inverse function.
* The ReturnOp, AccumulateGPhase, and DeallocOp will be copied to the end.
* Each load-apply-store group keeps its internal order, but the groups themselves are emitted in reverse order.
* In addition, an adjoint decorator will be wrapped around each gate.
*/
class GenerateDerivingInverse : public mlir::OpRewritePattern<DefgateOp> {
    // Module whose body receives the generated inverse function and gate definition.
    mlir::ModuleOp rootModule;
    // Caller-owned set; the symbol name of every gate that gets an inverse is recorded here.
    std::unordered_set<std::string> &pureSymbols;
public:
    GenerateDerivingInverse(mlir::MLIRContext *ctx, mlir::ModuleOp module, std::unordered_set<std::string> &pureSymbols): mlir::OpRewritePattern<DefgateOp>(ctx, 1), rootModule(module), pureSymbols(pureSymbols) {
    }
    // For a gate whose single definition parses to a DecompositionRawDefinition
    // with a decomposed function `f`, creates the function `f__inv` and a clone
    // of the gate definition named `<gate>__inv` that points at it.
    // Returns failure without touching the IR when the gate was already
    // processed or does not have such a definition.
    mlir::LogicalResult matchAndRewrite(DefgateOp op, mlir::PatternRewriter& rewriter) const override {
        // Already visited. The generated `__inv` definitions are cloned from a
        // marked op, so they carry the attribute too and are skipped here.
        if (op->hasAttr(DERIVED_GATE)) {
            return mlir::failure();
        }
        auto ctx = op->getContext();

        // The definition may be absent; only gates with exactly one definition are handled.
        auto def = op.getDefinition();
        if (!def || def->size() != 1) {
            return mlir::failure();
        }
        // Local non-const copy: this method is const, the handle is used to reach the module body.
        mlir::ModuleOp rootModule = this->rootModule;
        mlir::func::FuncOp func = nullptr;

        auto d = AllGateDefs::parseGateDefinition(op, 0, op.getType(), *def->getAsRange<GateDefinition>().begin());
        if (d == std::nullopt) {
            return mlir::failure();
        }
        if (auto decomposition = llvm::dyn_cast_or_null<DecompositionRawDefinition>(&**d)) {
            func = decomposition->getDecomposedFunc();
        }
        if (!func) {
            return mlir::failure();
        }

        // All checks passed: from here on the IR is modified and the pattern succeeds.
        // The op is marked only now so that a failed match never mutates the IR,
        // and before the clone below so that the clone inherits the mark.
        rewriter.startRootUpdate(op);
        op->setAttr(mlir::StringAttr::get(ctx, DERIVED_GATE), mlir::UnitAttr::get(ctx));
        rewriter.finalizeRootUpdate(op);

        pureSymbols.insert(op.getSymName().str());
        rewriter.setInsertionPointToStart(rootModule.getBody());
        std::string inv_name = func.getSymName().str() + INVERSE_SUFFIX;
        auto inv = rewriter.create<mlir::func::FuncOp>(func.getLoc(), inv_name, func.getFunctionType(), rewriter.getStringAttr("public"), nullptr, nullptr);
        rewriter.startRootUpdate(inv);
        auto new_block = inv.addEntryBlock();
        rewriter.setInsertionPointToStart(new_block);

        // Maps every value of the original function to its counterpart in the inverse.
        mlir::DenseMap<mlir::Value, mlir::Value> map;
        auto &block = func.getBlocks().front();
        for (unsigned i=0; i<block.getNumArguments(); i++) {
            map[block.getArgument(i)] = new_block->getArgument(i);
        }
        // insert_point: clone of the first qstate load of the group that contains
        //   the most recent gate application; set at the apply, cleared when the group ends.
        // potential_insert_point: clone of the first qstate load seen since the last group ended.
        // applied: a gate application was seen in the current group.
        // last_store: the previously visited op was a store.
        mlir::Operation *insert_point = nullptr;
        mlir::Operation *potential_insert_point = nullptr;
        bool applied = false;
        bool last_store = false;
        for (auto &src_op: func.getOps()) {
            // Branches are not copied.
            if (llvm::isa<mlir::cf::BranchOp, mlir::cf::CondBranchOp>(src_op)) {
                continue;
            }
            const bool is_store = llvm::isa<mlir::affine::AffineStoreOp, mlir::memref::StoreOp>(src_op);
            if (!is_store && last_store) {
                // A run of stores just ended, i.e. the current group is complete.
                // If it applied a gate, everything that follows is inserted in
                // front of that group, which is what reverses the group order.
                if (applied) {
                    rewriter.setInsertionPoint(insert_point);
                    insert_point = nullptr;
                    applied = false;
                }
                potential_insert_point = nullptr;
            }
            // Every operand must already have been cloned (or be a block argument).
            mlir::SmallVector<mlir::Value> oprands;
            for (auto origin: src_op.getOperands()) {
                auto oprand = map.find(origin);
                assert(oprand != map.end());
                oprands.push_back(oprand->getSecond());
            }
            auto new_op = src_op.clone();
            new_op->setOperands(oprands);
            if (auto apply = llvm::dyn_cast_or_null<ApplyGateOp>(new_op)) {
                applied = true;
                insert_point = potential_insert_point;
                // Wrap the gate in an adjoint decorator (no controls) and feed
                // the decorated gate to the cloned apply as operand 0.
                auto gate = apply.getGate();
                auto decorate = rewriter.create<DecorateOp>(apply.getLoc(), gate.getType(), gate, true, mlir::ArrayAttr::get(ctx, {}));
                apply.setOperand(0, decorate);
            }

            if (llvm::isa<mlir::func::ReturnOp, AccumulateGPhase, mlir::memref::DeallocOp, DeallocOp>(src_op)) {
                // Epilogue ops always go to the end of the inverse body.
                mlir::OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
                rewriter.setInsertionPointToEnd(new_block);
                rewriter.insert(new_op);
                rewriter.restoreInsertionPoint(ip);
            } else if (llvm::isa<mlir::memref::SubViewOp>(src_op) && insert_point) {
                // A subview that follows the apply inside a group is placed in
                // front of that group's first load.
                mlir::OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
                rewriter.setInsertionPoint(insert_point);
                rewriter.insert(new_op);
                rewriter.restoreInsertionPoint(ip);
            } else {
                rewriter.insert(new_op);
            }

            for (unsigned i=0; i<src_op.getNumResults(); i++) {
                map[src_op.getResult(i)] = new_op->getResult(i);
            }
            // Remember the first qstate load of a group as the candidate insertion point.
            if (auto load = llvm::dyn_cast_or_null<mlir::affine::AffineLoadOp>(src_op);
                load && load.getType() == QStateType::get(ctx) && potential_insert_point == nullptr) {
                potential_insert_point = new_op;
            }
            else if (auto load = llvm::dyn_cast_or_null<mlir::memref::LoadOp>(src_op);
                load && load.getType() == QStateType::get(ctx) && potential_insert_point == nullptr) {
                potential_insert_point = new_op;
            }
            last_store = is_store;
        }
        rewriter.finalizeRootUpdate(inv);

        // Clone the (now marked) gate definition, rename it and point it at the inverse function.
        auto new_def = llvm::cast_or_null<DefgateOp>(op->clone());
        rewriter.setInsertionPointToStart(rootModule.getBody());
        rewriter.insert(new_def);
        rewriter.startRootUpdate(new_def);
        auto name = new_def.getSymName().str() + INVERSE_SUFFIX;
        new_def.setSymName(name);

        auto decomp = GateDefinitionAttr::get(ctx, 
            mlir::StringAttr::get(ctx, DecompositionRawDefinition::defKindName()), 
            mlir::SymbolRefAttr::get(ctx, inv_name));
        new_def.setDefinitionAttr(mlir::ArrayAttr::get(ctx, {decomp}));
        rewriter.finalizeRootUpdate(new_def);
        return mlir::success();
    }
};

// Removes the DERIVED_GATE marker attribute from every DefgateOp that carries it.
class DerivedGatedefCleanup : public mlir::OpRewritePattern<DefgateOp>{
public:
    DerivedGatedefCleanup(mlir::MLIRContext* ctx): mlir::OpRewritePattern<DefgateOp>(ctx, 1) {
    }

    // Succeeds after stripping the marker; fails (no IR change) when the op is not marked.
    mlir::LogicalResult matchAndRewrite(DefgateOp op, mlir::PatternRewriter& rewriter) const override {
        // Only ops carrying the derived-gate marker are of interest.
        if (!op->hasAttr(DERIVED_GATE)) {
            return mlir::failure();
        }
        rewriter.startRootUpdate(op);
        op->removeAttr(DERIVED_GATE);
        rewriter.finalizeRootUpdate(op);
        return mlir::success();
    }
};

/*
* Change `inv fun(...)` into `fun__inv(...)`.
*
* Matches an adjoint DecorateOp whose argument is produced by a UseGateOp that
* refers to a gate recorded in `pureSymbols`. A new UseGateOp referring to
* `<name>` + INVERSE_SUFFIX is inserted in front of the DecorateOp, which is
* then rewired to it with its adjoint flag cleared.
*/
class ReplaceUseGateWithInv : public mlir::OpRewritePattern<DecorateOp>{
    // Caller-owned set of gate names for which an inverse definition exists.
    std::unordered_set<std::string> &pureSymbols;
public:
    ReplaceUseGateWithInv(mlir::MLIRContext* ctx, std::unordered_set<std::string> &pureSymbols): mlir::OpRewritePattern<DecorateOp>(ctx, 1), pureSymbols(pureSymbols) {
    }

    mlir::LogicalResult matchAndRewrite(DecorateOp op, mlir::PatternRewriter& rewriter) const override {
        if (!op.getAdjoint()) {
            return mlir::failure();
        }

        // A single level of DecorateOp is expected here. If the decorated gate
        // is not produced by a UseGateOp (e.g. a block argument or another op),
        // this is simply not a match; report failure instead of dereferencing null.
        auto use = llvm::dyn_cast_or_null<UseGateOp>(op.getArgs().getDefiningOp());
        if (!use) {
            return mlir::failure();
        }

        // Verify that the gate is derived from a user-defined procedure
        mlir::StringAttr name = use.getName().getLeafReference();
        if (!pureSymbols.contains(name.str())) {
            return mlir::failure();
        }

        // Create a new UseGate with mangled name
        auto use_inv = use.clone();
        use_inv.setNameAttr(mlir::SymbolRefAttr::get(op->getContext(), name.str() + INVERSE_SUFFIX));
        rewriter.setInsertionPoint(op);
        rewriter.insert(use_inv);

        // Remove adjoint from the original op
        rewriter.startRootUpdate(op);
        op.setOperand(use_inv);
        op.setAdjoint(false);
        rewriter.finalizeRootUpdate(op);
        return mlir::success();
    }
};

// Module pass that runs four greedy rewrite stages in sequence: generate the
// inverse definitions, strip the DERIVED_GATE markers, canonicalize DecorateOps,
// and finally redirect adjoint uses to the generated inverse gates.
struct GenerateDerivingInversePass : public mlir::PassWrapper<GenerateDerivingInversePass, mlir::OperationPass<mlir::ModuleOp>>{
    void runOnOperation() override{
        // Gate names that received an inverse; filled by stage 1, read by stage 4.
        std::unordered_set<std::string> pureFuncs;
        mlir::ModuleOp m = this->getOperation();
        auto ctx = m->getContext();

        // Freezes the given patterns and applies them greedily to the module.
        // The driver's result is deliberately ignored.
        auto runStage = [&m](mlir::RewritePatternSet &&patterns) {
            mlir::FrozenRewritePatternSet frozen(std::move(patterns));
            (void)mlir::applyPatternsAndFoldGreedily(m.getOperation(), frozen);
        };

        // Stage 1: build `__inv` functions and gate definitions.
        {
            mlir::RewritePatternSet generate(ctx);
            generate.add<GenerateDerivingInverse>(ctx, m, pureFuncs);
            runStage(std::move(generate));
        }

        // Stage 2: drop the marker attribute left behind by stage 1.
        {
            mlir::RewritePatternSet cleanup(ctx);
            cleanup.add<DerivedGatedefCleanup>(ctx);
            runStage(std::move(cleanup));
        }

        // Stage 3: canonicalize DecorateOp so that multi-level DecorateOps can be merged into one.
        {
            mlir::RewritePatternSet canonicalize(ctx);
            DecorateOp::getCanonicalizationPatterns(canonicalize, ctx);
            runStage(std::move(canonicalize));
        }

        // Stage 4: replace adjoint uses of derived gates with their `__inv` counterpart.
        {
            mlir::RewritePatternSet replace(ctx);
            replace.add<ReplaceUseGateWithInv>(ctx, pureFuncs);
            runStage(std::move(replace));
        }
    }
    // Command-line name of the pass.
    mlir::StringRef getArgument() const final{
        return "isq-gen-deriving-inverse";
    }
    // One-line description of the pass.
    mlir::StringRef getDescription() const final{
        return "Generate the inverse for the gates that are derived from procedures.";
    }
};

void registerGenerateDerivingInverse(){
    mlir::PassRegistration<GenerateDerivingInversePass>();
}

}
}
}