#!/usr/bin/env python3
from typing import Union, List, Tuple

from scipy import stats
import numpy as np
import os
import argparse
import matplotlib.pyplot as plt

# p-value below which `significance` reports a statistically significant
# difference between its two samples.
THRESHOLD = 0.05

# Ordered (abbreviated commit hash, size) pairs. `generate_result_pairs`
# walks this list to pair every commit with its successor.
COMMITS = [
    ("6ac0d27", 7),
    ("1bbf978", 8),
    ("e31a477", 11),
    ("8093b18", 11),
    ("021b72b", 14),
    ("9a3ed05", 14),
    ("7244e11", 14),
    ("9de3efa", 14),
    ("2429bed", 14),
    ("2732d73", 14),
    ("f0adf9f", 14),
    ("f890895", 14),
    ("2d90970", 15),
    ("46097e7", 15),
    ("37a2f8b", 15),
]


def parse_file(file: Union[argparse.FileType, str]) -> np.ndarray:
    """Read one float per line from *file* into a 1-D array.

    Args:
        file: Either an object with a ``readlines`` method (such as the open
            file that an ``argparse.FileType("r")`` argument yields) or a
            path, which is opened for reading and closed again here.

    Returns:
        A float array with one entry per non-blank line. A decimal comma is
        accepted in place of a decimal point.

    Raises:
        ValueError: If a non-blank line cannot be converted to ``float``.
    """
    # Decide up front instead of catching AttributeError, so an unrelated
    # AttributeError raised while reading is not mistaken for "this is a path".
    if hasattr(file, "readlines"):
        timing_strings = file.readlines()
    else:
        with open(file, "r") as infile:
            timing_strings = infile.readlines()

    stripped = (line.strip() for line in timing_strings)
    # Blank lines (typically a trailing newline at the end of the file) carry
    # no measurement and would otherwise make float("") raise.
    values = [float(text.replace(",", ".")) for text in stripped if text]
    return np.array(values, dtype=float)


def boxplot(x1, x2, t1=None, t2=None):
    """Draw notched box plots of *x1* and *x2* side by side on a shared y axis.

    Args:
        x1: Data for the left-hand plot.
        x2: Data for the right-hand plot.
        t1: Optional title for the left-hand plot.
        t2: Optional title for the right-hand plot.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can call ``.show()``
        (or another pyplot function) on the result.
    """
    _, axes = plt.subplots(1, 2, figsize=(9, 4), sharey=True)

    bplots = []
    for axis, data, title in zip(axes, (x1, x2), (t1, t2)):
        bplots.append(axis.boxplot(data, notch=True, vert=True, patch_artist=True))
        if title is not None:
            # Each title goes on its own axes; the second title used to be
            # written to the first axes, overwriting t1.
            axis.set_title(title)

    # Within each plot, fill the boxes with these colours in order.
    colors = ["y", "c", "royalblue"]
    for bplot in bplots:
        for patch, color in zip(bplot["boxes"], colors):
            patch.set_facecolor(color)

    plt.tight_layout()
    return plt


def main():
    """Parse the command line and invoke the selected sub-command's handler."""
    cli_args = parse_args()
    # Each sub-parser stores its handler function under ``func``.
    handler = cli_args.func
    handler(cli_args)


def significance(args):
    """Test whether two samples differ significantly and show their box plots.

    Args:
        args: Parsed command-line namespace; ``args.x1`` and ``args.x2`` are
            the two inputs handed to ``parse_file`` (one float per line).

    Prints the p-value of an unpaired t-test and a verdict based on
    ``THRESHOLD``, then opens a box-plot window for both samples.
    """
    x1 = parse_file(args.x1)
    x2 = parse_file(args.x2)

    print("Unpaired t-test for H0 = x1 and x2 have the same distribution")
    # ttest_ind is the independent-samples (unpaired) test the message
    # announces. The previously used ttest_rel is the paired test: it treats
    # the i-th entries of x1 and x2 as a matched pair and fails when the two
    # samples differ in length.
    t_result = stats.ttest_ind(x1, x2)
    print("p-value: {}".format(t_result.pvalue))
    if t_result.pvalue < THRESHOLD:
        print("The difference between x1 and x2 is statistically significant")
    else:
        print("No significant difference between x1 and x2")

    plot = boxplot(x1, x2)
    plot.show()


def generate_filename(result_dir, commit_hash, size):
    """Return the path ``<result_dir>/<commit_hash>_<size>.csv``."""
    basename = "{}_{}.csv".format(commit_hash, size)
    return os.path.join(result_dir, basename)


def generate_result_pairs(result_dir) -> List[Tuple[str, str]]:
    """Build (before, after) result-file paths for each consecutive commit pair.

    Both files of a pair use the size of the *earlier* commit, so the two
    results being compared were produced at the same size.

    Args:
        result_dir: Directory that contains the result files.

    Returns:
        One ``(before_path, after_path)`` tuple per pair of neighbouring
        entries in ``COMMITS``; empty when ``COMMITS`` has fewer than two
        entries.
    """
    return [
        (generate_filename(result_dir, older_hash, older_size),
         generate_filename(result_dir, newer_hash, older_size))
        # Zipping the list with itself shifted by one yields neighbouring pairs.
        for (older_hash, older_size), (newer_hash, _) in zip(COMMITS, COMMITS[1:])
    ]


def speedup(args):
    """Print the speedup between each pair of consecutive commits and overall.

    For every (before, after) file pair from ``generate_result_pairs`` the
    medians of the two files are compared. The speedup is ``before / after``
    and the percentage is the signed relative change of the median (negative
    when the later commit has the smaller median).

    Args:
        args: Parsed command-line namespace; ``args.result_dir`` is the
            directory holding the result files.
    """
    result_pairs = generate_result_pairs(args.result_dir)
    total_speedup = 1
    for before_file, after_file in result_pairs:
        before = np.median(parse_file(before_file))
        after = np.median(parse_file(after_file))
        # Named step_speedup so the local does not shadow this function.
        step_speedup = before / after
        total_speedup *= step_speedup
        # Signed change, printed with an explicit sign. The old format
        # hard-coded a leading "-" and so showed a double negative whenever
        # the later commit was slower.
        percent_change = (after / before - 1) * 100
        print("Speedup from {:23s} to {:23s}: {:7.2f}  =   {:+9.1f} %".format(
            before_file, after_file, step_speedup, percent_change))

    print(f"Total speedup is {total_speedup}")


def parse_args(argv=None):
    """Parse the command line for the ``speedup`` and ``significance`` commands.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``func`` holds the handler of the chosen
        sub-command and ``command`` its name.
    """
    parser = argparse.ArgumentParser()
    # A required sub-command needs a dest: without one, some Python 3
    # releases raise TypeError instead of printing a usage error when the
    # sub-command is omitted.
    subparsers = parser.add_subparsers(dest="command", help='sub-command help')
    subparsers.required = True

    speedup_parser = subparsers.add_parser("speedup")
    speedup_parser.add_argument("result_dir")
    speedup_parser.set_defaults(func=speedup)

    significance_parser = subparsers.add_parser("significance")
    significance_parser.add_argument("x1", type=argparse.FileType("r"), help="File with 1 float per line")
    significance_parser.add_argument("x2", type=argparse.FileType("r"), help="File with 1 float per line")
    significance_parser.set_defaults(func=significance)

    return parser.parse_args(argv)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()