#include "isq/oracle/QM.h"

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

using std::string;
using std::set;
using std::vector;
using std::pair;
using std::map;
using std::make_pair;


// Returns the first index j >= i such that a[j] != 1, or a.size() when every
// element from i onward equals 1. A negative i is returned unchanged.
//
// Taken by const reference: callers invoke this repeatedly on large flag
// vectors, and a by-value parameter would copy the whole vector on each call.
int get_lower(const vector<int>& a, int i){
    while (i >= 0 && static_cast<vector<int>::size_type>(i) < a.size() && a[i] == 1) ++i;
    return i;
}


// Collects the positions at which two stage lists differ, each identified by
// A's value at that position with bit m deleted: bits below m are kept and
// bits above m are shifted down by one.
//
// A, B : stage lists compared element-wise; only their common prefix is
//        compared, so a length mismatch cannot read out of bounds.
// m    : index of the bit being removed.
// Returns the set of compressed values of A[i] for every i with A[i] != B[i].
set<int> get_diff(const vector<int>& A, const vector<int>& B, int m){
    const int low_bits = (1 << m) - 1;  // mask of the bits strictly below m
    const vector<int>::size_type len = std::min(A.size(), B.size());
    set<int> diff;
    for (vector<int>::size_type i = 0; i < len; ++i){
        if (A[i] != B[i]){
            diff.insert(((A[i] >> (m + 1)) << m) + (A[i] & low_bits));
        }
    }
    return diff;
}


// Pairs up positions of A whose values agree on every bit below n except
// bit m.
//
// Returns coil such that coil[i] == j and coil[j] == i whenever A[i] and A[j]
// match under that mask. A position with no partner keeps the value 0. If a
// masked value occurs more than twice, each later occurrence pairs with the
// most recent earlier one.
//
// The result has max(1 << n, A.size()) entries, so it never writes past its
// end even when A is longer than 1 << n.
vector<int> get_pair(const vector<int>& A, int n, int m){
    const int mask = ((1 << n) - 1) ^ (1 << m);

    vector<int> coil(1 << n);
    if (A.size() > coil.size()) coil.resize(A.size());

    // (v & mask) only keeps bits below n, so it always indexes this table.
    // last_seen[key] is the latest position with that masked value, or -1.
    vector<int> last_seen(1 << n, -1);

    const int count = static_cast<int>(A.size());
    for (int idx = 0; idx < count; ++idx){
        const int key = A[idx] & mask;
        const int prev = last_seen[key];
        if (prev >= 0){
            coil[idx] = prev;
            coil[prev] = idx;
        }
        last_seen[key] = idx;
    }
    return coil;
}

// Performs one decomposition step on bit m.
//
// Positions of A whose values agree on every bit below n except bit m are
// paired (A_pair), and likewise for P (P_pair). Starting from an unvisited
// position, the walk alternates one A-pair edge, then one P-pair edge, until
// it reaches an already visited position.
//
// Along each such cycle, bit m of the visited positions is rewritten
// alternately in both A_next and P_next. The pattern starts from the current
// bit-m value of the cycle's first position. As a result, the two positions
// of every A-pair, and of every P-pair, end up with opposite values of bit m.
//
// A, P : value lists, assumed to have 1 << n entries and to be permutations of
//        0 .. 2^n - 1 so every position has a partner (TODO confirm with callers).
// n    : number of bits.
// m    : index of the bit being processed.
// Returns the updated pair (A_next, P_next); the inputs are not modified.
pair<vector<int>, vector<int>> next_ap_pair(const vector<int>& A, const vector<int>& P, int n, int m){
    const int k = 1 << n;
    // Masks that force bit m to 0 or to 1, respectively.
    const int clear_m = (k - 1) ^ (1 << m);
    const int set_m = 1 << m;

    vector<int> A_next(A);
    vector<int> P_next(P);

    const vector<int> A_pair = get_pair(A_next, n, m);
    const vector<int> P_pair = get_pair(P_next, n, m);

    vector<char> visited(k, 0);
    vector<int> cycle;  // reused across cycles to avoid reallocating
    int start = 0;
    while (start < k){
        // Trace the alternating A-pair / P-pair cycle through `start`.
        cycle.clear();
        int cur = start;
        while (!visited[cur]){
            const int mate = A_pair[cur];
            visited[cur] = 1;
            visited[mate] = 1;
            cycle.push_back(cur);
            cycle.push_back(mate);
            cur = P_pair[mate];
        }

        // Alternate bit m along the cycle, starting from the first position's current bit.
        int b = (A_next[cycle[0]] >> m) & 1;
        for (int pos : cycle){
            if (b == 0){
                A_next[pos] &= clear_m;
                P_next[pos] &= clear_m;
            } else {
                A_next[pos] |= set_m;
                P_next[pos] |= set_m;
            }
            b ^= 1;
        }

        // Advance to the next unvisited position. The cursor only moves
        // forward, so this scan costs O(k) in total.
        while (start < k && visited[start]) ++start;
    }
    return make_pair(std::move(A_next), std::move(P_next));
}



// Debug helper: writes every element of A to std::cout, each followed by a
// single space. No newline is emitted and the stream is not flushed.
// Takes a const reference so const vectors and temporaries can be printed.
void print_info(const vector<int>& A){
    for (const int v : A){
        std::cout << v << ' ';
    }
}

vector<set<int>> qm::do_permutation(vector<int> P, int N){

    /*
        Decomposes the permutation P of {0, 1, ..., 2^N - 1} stage by stage.
        Reference paper: "Young Subgroups for Reversible Computers"
        (A. De Vos, Y. Van Rentergem).

        Forward chain:
          A_0 = identity, P_0 = P.
          A_i, P_i = next_ap_pair(A_{i-1}, P_{i-1}, N, N - i), for 1 <= i < N.

        The stage sequence is A_0, ..., A_{N-1}, P_{N-1}, ..., P_0, i.e. 2N
        stages with 2N - 1 transitions. Transition t exchanges bit M[t], where
        M = N-1, N-2, ..., 0, 1, ..., N-1.

        Returns 2N - 1 sets. The t-th set holds, for every position where
        stage t and stage t+1 differ, stage t's value at that position with
        bit M[t] deleted (see get_diff).

        Assumes P.size() == 1 << N. NOTE(review): not validated here.
        Returns an empty result when N <= 0: there is no bit to exchange, and
        for N < 0 the shift 1 << N would be invalid.
    */
    vector<set<int>> ans;
    if (N <= 0) return ans;

    // Forward chain of (A, P) stages.
    vector<vector<int>> A_list;
    vector<vector<int>> P_list;
    A_list.reserve(2 * static_cast<std::size_t>(N));
    P_list.reserve(static_cast<std::size_t>(N));

    vector<int> identity(1 << N);
    std::iota(identity.begin(), identity.end(), 0);
    A_list.push_back(std::move(identity));
    P_list.push_back(std::move(P));

    for (int i = 1; i < N; i++){
        auto ap = next_ap_pair(A_list.back(), P_list.back(), N, N - i);
        A_list.push_back(std::move(ap.first));
        P_list.push_back(std::move(ap.second));
    }

    // Append the P stages in reverse order after the A stages.
    for (auto it = P_list.rbegin(); it != P_list.rend(); ++it){
        A_list.push_back(std::move(*it));
    }

    // Bit exchanged between each pair of adjacent stages.
    vector<int> M;
    M.reserve(2 * static_cast<std::size_t>(N) - 1);
    for (int bit = N - 1; bit >= 0; bit--) M.push_back(bit);
    for (int bit = 1; bit < N; bit++) M.push_back(bit);

    ans.reserve(M.size());
    for (std::size_t t = 0; t + 1 < A_list.size(); t++){
        ans.push_back(get_diff(A_list[t], A_list[t + 1], M[t]));
    }
    return ans;
}