import numpy as np
import matplotlib.pyplot as plt
import os

# X-axis tick labels for the plotted metrics. The order must match the tuple
# returned by get_data(): cycles, instructions, branches, branch misses, time.
PROPS = ["cycles", "instructions", "branches", "branch-misses", "time-elapsed"]


def get_data(filename):
    """Extract the headline counters from a ``perf stat`` text report.

    Each line is checked for a set of marker substrings; when a marker is
    present, the first whitespace-separated token of that line is taken as
    the value. If a marker occurs on several lines, the last parseable one
    wins.

    Args:
        filename: Path of the report file to read.

    Returns:
        A 5-tuple ``(cycles, instructions, branches, branch_misses,
        total_time)``. The counters are ``int`` and the elapsed time is
        ``float``. An entry is ``None`` when its marker never appears or
        when its value is not numeric.
    """
    # (marker substring, converter), in the order of the returned tuple.
    # NOTE: matching is by substring, so e.g. a "ref-cycles:u" line would
    # also be picked up by the "cycles:u" marker.
    fields = (
        ("cycles:u", lambda tok: int(tok.replace(",", ""))),
        ("instructions:u", lambda tok: int(tok.replace(",", ""))),
        ("branches:u", lambda tok: int(tok.replace(",", ""))),
        ("branch-misses:u", lambda tok: int(tok.replace(",", ""))),
        ("time elapsed", float),
    )
    values = [None] * len(fields)

    with open(filename, "r") as data_file:
        for line in data_file:
            for index, (marker, convert) in enumerate(fields):
                if marker not in line:
                    continue
                try:
                    values[index] = convert(line.split()[0])
                except ValueError:
                    # perf prints placeholders such as "<not counted>" or
                    # "<not supported>" instead of a number; treat the
                    # metric as missing rather than aborting the whole run.
                    continue

    return tuple(values)


def plot_bench(bench1_name, bench1, bench2_name, bench2):
    """Save a grouped, log-scale bar chart comparing two benchmark runs.

    The chart is written to ``./graphics/<bench2_name>-stats.png`` and the
    current pyplot figure is cleared afterwards so the next plot starts
    from an empty canvas.

    Args:
        bench1_name: Legend label of the first series.
        bench1: Values of the first series, one per entry of ``PROPS``.
        bench2_name: Legend label of the second series; also used to build
            the output file name.
        bench2: Values of the second series, one per entry of ``PROPS``.
    """
    out_dir = "./graphics"
    # savefig() does not create missing directories, so make sure the
    # target exists instead of failing after all the plotting work is done.
    os.makedirs(out_dir, exist_ok=True)

    # One bar group per metric label.
    y_pos = np.arange(len(PROPS))
    width = 0.35

    plt.bar(y_pos, bench1, width, alpha=0.5, label=bench1_name)
    plt.bar(y_pos + width, bench2, width, alpha=0.5, label=bench2_name)
    # Put each tick midway between the two bars of its group.
    plt.xticks(y_pos + width / 2, PROPS)
    plt.ylabel("Cycles, Instructions per cycles, Number, Number")
    plt.yscale("log")
    plt.legend(loc="best")
    plt.title("Average of 5/50 Instances with 13 Depth")

    plt.savefig(f"{out_dir}/{bench2_name}-stats.png", dpi=200)
    plt.clf()


if __name__ == "__main__":
    # Plot every comparison report in the current directory against the
    # baseline report, first pairwise and then all together in one figure.
    basefile = "bench_results_comp_originial.txt"
    basedata = get_data(basefile)

    # savefig() does not create missing directories.
    os.makedirs("./graphics", exist_ok=True)

    comps = []
    names = []
    # os.listdir() returns entries in arbitrary order; sort so the bar and
    # legend order of the collected plot is reproducible between runs.
    for filename in sorted(os.listdir()):
        if not filename.endswith(".txt") or filename == basefile:
            continue
        # The variant name is whatever follows the third underscore of the
        # file stem; .txt files that do not follow that naming scheme are
        # not benchmark reports and are skipped.
        parts = filename.split(".")[0].split("_", 3)
        if len(parts) < 4:
            continue
        name = parts[3]

        comp_data = get_data(filename)
        comps.append(comp_data)
        names.append(name)
        plot_bench("Original", basedata, name, comp_data)

    width = 0.35
    # Spread the metric groups out to leave room for several bars each.
    y_pos = np.arange(len(PROPS)) * 2.5
    plt.bar(y_pos, basedata, width, alpha=0.5, label="Original")
    for offset, (name, comp_data) in enumerate(zip(names, comps), start=1):
        plt.bar(y_pos + width * offset, comp_data, width, alpha=0.5, label=name)
    plt.ylabel("Cycles, Instructions per cycles, Number, Number")
    # Bars of a group sit at offsets 0 .. len(comps) * width, so the middle
    # of the group is at half of the last offset.
    plt.xticks(y_pos + width * len(comps) / 2, PROPS)
    plt.yscale("log")
    plt.legend(loc="best")
    plt.title("Average of 5/50 Instances with 13 Depth")

    plt.savefig("./graphics/collected-stats.png", dpi=200)
