#include "isq/Dialect.h"
#include "isq/Operations.h"
#include "isq/GateDefTypes.h"
#include "isq/QSynthesis.h"
#include "Utils.h"

namespace isq{
namespace ir{
/*
 * Extract the value of a complex number created by complex::CreateOp.
 *
 * Both the real and the imaginary operand of `create` must be defined by an
 * arith::ConstantOp whose value is a FloatAttr (i.e. the coefficient has been
 * canonicalized to constants). These preconditions are checked with assert()
 * only, so they are not enforced in NDEBUG builds.
 *
 * Returns {real, imag} as doubles.
 */
std::pair<double, double> getValueFromComplexCreateOp(mlir::complex::CreateOp create) {
    assert(create && "The coefficient is not canonicalized!");

    auto real_op = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(create.getReal().getDefiningOp());
    assert(real_op && "The real part of the coefficient is not canonicalized!");
    auto real_attr = real_op.getValue().dyn_cast<mlir::FloatAttr>();
    // A constant carrying a non-float attribute would otherwise be
    // dereferenced as a null FloatAttr below.
    assert(real_attr && "The real part of the coefficient is not a float constant!");
    double real = real_attr.getValue().convertToDouble();

    auto imag_op = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(create.getImaginary().getDefiningOp());
    assert(imag_op && "The imaginary part of the coefficient is not canonicalized!");
    auto imag_attr = imag_op.getValue().dyn_cast<mlir::FloatAttr>();
    assert(imag_attr && "The imaginary part of the coefficient is not a float constant!");
    double imag = imag_attr.getValue().convertToDouble();

    return {real, imag};
}

// Accumulate the amplitudes of a ket expression into `amplitude`, recursively.
//
// `value` must be produced by one of:
//   - KetOp: adds pre * coeff to amplitude[basis]. The coefficient must be a
//            canonicalized complex::CreateOp (see getValueFromComplexCreateOp).
//   - VecOp: adds pre * v[i] to amplitude[i] for every entry of the vector.
//   - AddOp: accumulates both operands with the factor `pre`.
//   - SubOp: accumulates lhs with `pre` and rhs with `-pre`.
// `pre` is an integer factor applied to every contribution. It is negated
// when descending into the right-hand side of a SubOp.
//
// Contributions are added to the existing contents of `amplitude`. The caller
// is therefore expected to pass a zero-initialized vector sized to the Hilbert
// space. NOTE(review): confirm at call sites.
//
// Returns false after emitting an error diagnostic if a basis index or vector
// does not fit in `amplitude`, or if `value` is produced by anything else
// (including a block argument). Returns true otherwise.
bool getAmplitude(mlir::Value value, int pre, llvm::SmallVector<Eigen::dcomplex> &amplitude){
    mlir::Operation *operation = value.getDefiningOp();

    if (auto ket = llvm::dyn_cast_or_null<isq::ir::KetOp>(operation)) {
        auto create = llvm::dyn_cast_or_null<mlir::complex::CreateOp>(ket.getCoeff().getDefiningOp());
        std::pair<double, double> coeff = getValueFromComplexCreateOp(create);
        auto basis = ket.getBasis();
        if (basis >= amplitude.size()) {
            ket.emitError("The basis value is not within the Hilbert space!");
            return false;
        }
        amplitude[basis] += Eigen::dcomplex(pre * coeff.first, pre * coeff.second);
        return true;
    }

    if (auto vec = llvm::dyn_cast_or_null<isq::ir::VecOp>(operation)) {
        llvm::SmallVector<Eigen::dcomplex> entries = vec.getVec().toMatrixVal()[0];
        if (entries.size() > amplitude.size()) {
            vec.emitOpError("The vector is too long!");
            return false;
        }
        // Accumulate and apply the sign, like the KetOp branch. The old plain
        // assignment dropped `pre` (breaking `a - vec`) and clobbered terms
        // already accumulated from sibling operands (breaking `ket + vec`).
        // Convert explicitly, since int * std::complex<double> has no operator.
        const double factor = static_cast<double>(pre);
        for (size_t i = 0; i < entries.size(); ++i) {
            amplitude[i] += factor * entries[i];
        }
        return true;
    }

    if (auto add = llvm::dyn_cast_or_null<isq::ir::AddOp>(operation)) {
        if (!getAmplitude(add.getLhs(), pre, amplitude)) {
            return false;
        }
        return getAmplitude(add.getRhs(), pre, amplitude);
    }

    if (auto sub = llvm::dyn_cast_or_null<isq::ir::SubOp>(operation)) {
        if (!getAmplitude(sub.getLhs(), pre, amplitude)) {
            return false;
        }
        return getAmplitude(sub.getRhs(), -pre, amplitude);
    }

    // Report at the value's location rather than through an op handle. Every
    // cast above failed, so those handles are null here. `operation` itself is
    // also null when `value` is a block argument.
    mlir::emitError(value.getLoc(), "Unexpected operation!");
    return false;
}
}
}