/*! @file random-sort.cpp
 *  Compare list shuffles produced by several different methods, measuring both
 *  their running time and how far the resulting permutations are from uniform.
 */

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

//! Shared pseudo-random engine for every shuffle in this file. It is
//! default-constructed (std::mt19937's default seed), so runs are repeatable.
std::mt19937 generator;

//! Outcome histogram: (position, value) -> number of trials in which `value`
//! ended up at index `position`.
using Positions = std::map<std::pair<int, int>, int>;

//! Statistics collected by test() for one shuffle method and list length.
//! Every member has a default initializer so that a default-initialized
//! Result (e.g. `Result out;`) never exposes an indeterminate `time`, which
//! would otherwise be read by `out.time += ...`.
struct Result {
    int len = 0;                                       //!< list length N
    std::chrono::duration<float, std::micro> time{};   //!< time per shuffle
    double chi_2 = 0;                                  //!< bias metric computed by test()
    Positions pos;                                     //!< (position, value) histogram
    float expected = 0;                                //!< expected count per histogram cell
};

//! Common signature of the shuffles under test: permute the vector in place.
using SortFunc = std::function<void(std::vector<int>&)>;

//! "Comparator" that ignores both arguments: it returns true exactly when the
//! lowest bit of the next output of `generator` is set, i.e. a pseudo-random
//! coin flip. It is deliberately not a strict weak ordering.
bool coin_flip(int, int)
{
    const auto draw = generator();
    return (draw & 1u) != 0;
}

//! Shuffle attempt: std::stable_sort driven by the random coin_flip comparator.
//! NOTE: coin_flip is not a strict weak ordering, so this formally violates the
//! std::stable_sort precondition; presumably kept on purpose so its bias can be
//! measured alongside the other methods in main().
void randomsort_std(std::vector<int>& values)
{
    auto first = values.begin();
    auto last = values.end();
    std::stable_sort(first, last, &coin_flip);
}

//! Top-down merge sort of [begin, end) whose merge step uses the random
//! coin_flip comparator, producing a (not necessarily uniform) shuffle.
//! NOTE: std::inplace_merge expects both halves sorted w.r.t. the comparator;
//! with a random comparator that precondition is formally violated.
//! @tparam T container whose iterator type delimits the range; only
//!           bidirectional iteration is required (as std::inplace_merge needs).
template<typename T> void random_merge_sort(typename T::iterator begin, typename T::iterator end)
{
    const auto length = std::distance(begin, end);
    if (length < 2) {
        return;
    }

    // Use T's own iterator type (previously hard-coded to
    // std::vector<int>::iterator) and std::next, so non-random-access
    // containers work too.
    const typename T::iterator mid = std::next(begin, length / 2);
    random_merge_sort<T>(begin, mid);
    random_merge_sort<T>(mid, end);
    std::inplace_merge(begin, mid, end, coin_flip);
}

//! Shuffle attempt: recursive merge sort whose merges consult the random
//! coin_flip comparator (delegates to random_merge_sort over the whole vector).
void randomsort_merge(std::vector<int>& values)
{
    using Container = std::vector<int>;
    random_merge_sort<Container>(std::begin(values), std::end(values));
}

//! Shuffle attempt: for every pair of positions (j, i) with j < i, visited
//! with i as the outer index, swap the two elements whenever the random
//! coin_flip comparator returns true.
void randomsort_bubble(std::vector<int>& values)
{
    using Index = std::vector<int>::size_type;
    const Index count = values.size();
    for (Index i = 0; i < count; ++i) {
        for (Index j = 0; j < i; ++j) {
            int& later = values[i];
            int& earlier = values[j];
            if (coin_flip(later, earlier)) {
                std::swap(later, earlier);
            }
        }
    }
}

//! Fisher-Yates shuffle: for each position i, swap in an element chosen
//! uniformly from positions [i, size).
//! Uses std::uniform_int_distribution rather than `generator() % range`: the
//! modulo form is slightly biased toward small remainders whenever `range`
//! does not evenly divide the engine's 2^32 output values, and this file
//! exists to measure shuffle bias.
void pick(std::vector<int>& values)
{
    const std::size_t n = values.size();
    // Stop before the last index: it could only ever swap with itself.
    for (std::size_t i = 0; i + 1 < n; ++i) {
        std::uniform_int_distribution<std::size_t> choose(i, n - 1);
        std::swap(values[i], values[choose(generator)]);
    }
}

//! Random-key ("prefix") shuffle: pair every value with a random key, sort the
//! pairs (by key, then by value), and write the values back in the new order.
//! Key collisions fall back to comparing values, so colliding entries keep
//! ascending value order.
void prefix(std::vector<int>& values)
{
    using Keyed = std::pair<int, int>;

    std::vector<Keyed> keyed;
    keyed.reserve(values.size());
    for (const int value : values) {
        // The engine's unsigned output is converted to int, the key type.
        const int key = static_cast<int>(generator());
        keyed.push_back(Keyed(key, value));
    }

    std::sort(keyed.begin(), keyed.end());

    std::transform(keyed.begin(), keyed.end(), values.begin(),
                   [](const Keyed& entry) { return entry.second; });
}

//! Run `func` on `trials` fresh copies of [0, 1, ..., len-1] and collect
//! timing and positional-bias statistics.
//!
//! @param func    shuffle to evaluate; it permutes its argument in place.
//! @param len     list length N.
//! @param trials  number of independent shuffles.
//! @return Result with:
//!   - time:     mean wall-clock duration of one call to `func`;
//!   - pos:      histogram of (position, value) outcomes (after the bias pass
//!               it holds an entry, possibly 0, for every one of the N*N cells);
//!   - expected: trials / N, the count every cell would have under a
//!               perfectly uniform shuffle;
//!   - chi_2:    mean over all N*N cells of ((count - expected) / expected)^2.
//!               NOTE: a normalised squared relative deviation, not the
//!               textbook Pearson statistic sum((O - E)^2 / E).
//! For len <= 0 or trials == 0 no work is done and a zeroed Result is returned
//! (instead of dividing by zero / allocating a negative-sized vector).
Result test(SortFunc func, int len, unsigned int trials)
{
    Result out{};
    out.len = len;
    if (len <= 0 || trials == 0) {
        return out;
    }

    const double expected = static_cast<double>(trials) / len;
    out.expected = static_cast<float>(expected);

    // steady_clock is monotonic; high_resolution_clock is not guaranteed to be.
    using Clock = std::chrono::steady_clock;
    // Sum in floating-point microseconds and divide once at the end: dividing
    // each integral per-call duration by `trials` truncated short calls to 0.
    std::chrono::duration<double, std::micro> total{0};

    for (unsigned int k = 0; k < trials; ++k) {
        std::vector<int> values(static_cast<std::size_t>(len));
        for (int i = 0; i < len; ++i) {
            values[static_cast<std::size_t>(i)] = i;
        }

        const auto start = Clock::now();
        func(values);
        const auto finish = Clock::now();
        total += finish - start;

        for (std::size_t n = 0; n < values.size(); ++n) {
            ++out.pos[std::make_pair(static_cast<int>(n), values[n])];
        }
    }
    out.time = total / trials;

    double sum = 0;
    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++) {
            const int count = out.pos[std::make_pair(i, j)];
            const double relative = (count - expected) / expected;
            sum += relative * relative;
        }
    }
    out.chi_2 = sum / (static_cast<double>(len) * len);

    return out;
}

//! Benchmark driver with two report modes, selected at compile time:
//!  - `#if 0` branch: a Markdown table of time per element and chi_2 for each
//!    method at growing sizes, plus one binary PPM ("P6") heat map per
//!    (method, size): red where a (position, value) cell occurs more often than
//!    expected, blue where less often, green near the expected count.
//!  - `#else` branch (the one compiled): a tab-separated table of chi_2, one row
//!    per list length n in [4, 512), one column per method.
int main()
{
    const std::vector<std::pair<std::string, SortFunc>> tests {
        {"pick", pick},
        {"prefix", prefix},
        {"sort_bubble", randomsort_bubble},
        {"sort_std", randomsort_std},
        {"sort_merge", randomsort_merge},
    };

#if 0
    // Map a fraction in [0, 1] to a byte. Going through unsigned char avoids
    // the undefined double->char conversion of 255.0 when char is signed.
    const auto to_byte = [](double fraction) {
        return static_cast<char>(static_cast<unsigned char>(fraction * 255));
    };

    std::cout << "| Algorithm  | N | Time (&mu;s/N) | $$\\chi^2$$ |   |" << std::endl;
    std::cout << "|------------|---|---------------:|------------:|---|" << std::endl;
    for (const auto& t : tests) {
        bool first = true;
        for (int size = 4; size < 250; size = size * 5 / 4) {
            auto r = test(t.second, size, 100000);

            if (first) {
                std::cout << "| " << t.first << " | ";
                first = false;
            } else {
                std::cout << "|| ";
            }

            std::stringstream fname;
            fname << t.first << '-' << size;

            std::cout << size << " | " << r.time.count() / size << " | " << r.chi_2
                      << " | ![{16,16,force_width=16,force_height=16,div_class=None}]("
                      << fname.str() + ".png \"" + t.first << ", N=" << size << "\")"
                      << " |" << std::endl;

            // PPM pixel data is raw bytes: open in binary mode so no newline
            // translation can corrupt it.
            std::ofstream image(fname.str() + ".pnm", std::ios::binary);
            image << "P6" << std::endl << r.len << ' ' << r.len << ' ' << 255 << std::endl;
            for (int i = 0; i < r.len; i++) {
                for (int j = 0; j < r.len; j++) {
                    const int count = r.pos[std::make_pair(i, j)];
                    const double d = (count - r.expected) / r.expected;
                    // Distinct names: a local `r` here would shadow the Result `r`.
                    const double red = std::max(0., std::min(1., d));
                    const double blue = std::max(0., std::min(1., -d));
                    const double green = std::max(0., 1 - red - blue);
                    image.put(to_byte(red));
                    image.put(to_byte(green));
                    image.put(to_byte(blue));
                }
            }
        }
    }
#else
    std::cout << "# n";
    for (const auto& t : tests) {
        std::cout << '\t' << t.first;
    }
    std::cout << std::endl;
    for (int i = 4; i < 512; i++) {
        // No trailing tab here: each value below brings its own separator, so
        // data rows line up with the single-tab header.
        std::cout << i;
        for (const auto& t : tests) {
            const auto r = test(t.second, i, 100000);
            std::cout << '\t' << r.chi_2;
        }
        // std::endl flushes each row, keeping partial output visible during
        // the long run.
        std::cout << std::endl;
    }
#endif
    return 0;
}
