#! /usr/bin/env python

import datetime
import numpy as np
import matplotlib
matplotlib.use('TkAgg')

import matplotlib.pyplot as plt; plt.rcdefaults()
import matplotlib.pyplot as plt

from subprocess import Popen, PIPE, STDOUT
import random
import sys
import statistics

# Board dimensions: generate_board() produces MAX_HOUSES entries whose values
# sum to MAX_PIECES.
# NOTE(review): 14 presumably means 12 houses plus 2 stores -- confirm.
MAX_PIECES = 48
MAX_HOUSES = 14

# Settable variables with default values; the __main__ block may override
# them from the command line.
DEPTH = 9
TESTS = 15
ITERATIONS = 5
# Default plot file name. The timestamp is formatted without spaces or colons
# so the name is a valid, shell-friendly path on every platform.
OUTFILE = (f'depth_{DEPTH}_tests_{TESTS}_iter_{ITERATIONS}_'
           f'{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}_benchmark.png')

# Metric labels, in the same order test() returns the metrics after the
# program output.
PROPS = ['cycles', 'instructions', 'branches', 'branch-misses', 'time-elapsed']

def generate_board():
    """Return a random board: a list of MAX_HOUSES counts summing to MAX_PIECES.

    The first MAX_HOUSES - 1 entries are drawn at random from the pieces
    still available; the last entry takes whatever remains, which keeps the
    total exact. While 10 or more pieces remain, a draw comes from
    range(available // 3) so early entries cannot use up the pool; below 10
    it comes from range(available). Either way at least one piece is always
    left over, so the final entry is never empty.
    """
    available = MAX_PIECES
    board = []
    for _ in range(MAX_HOUSES - 1):
        pool = available if available < 10 else available // 3
        # randrange(pool) picks uniformly from 0..pool-1 without building a
        # list, and consumes the RNG the same way random.choice over that
        # range does, so a given seed still yields the same boards.
        value = random.randrange(pool)
        board.append(value)
        available -= value
    board.append(available)
    assert sum(board) == MAX_PIECES  # not less or more than max pieces
    assert len(board) == MAX_HOUSES  # not less or more than all houses
    return board

def test(program, board):
    """Run *program* on *board* under ``perf stat`` and collect its counters.

    The program is invoked as ``program DEPTH board[0] ... board[-1]`` while
    ``perf stat -B`` counts user-space cycles, instructions, branches and
    branch misses. ``LC_NUMERIC`` is forced to ``en_US.utf8`` so the
    thousands separators added by ``-B`` are commas, which are stripped
    before parsing.

    *program* is split with shell-like rules, so it may carry extra
    arguments (e.g. ``'python3 engine.py'``), just as when the command line
    was handed to a shell.

    Returns ``(program_output, cycles, instructions, branches,
    branch_misses, total_time)``. ``program_output`` is at most the first six
    lines of combined stdout/stderr that precede perf's report. Each counter
    is an ``int`` and ``total_time`` a ``float``; any metric perf did not
    report is ``None``. Raises ``ValueError`` if a reported counter is not
    numeric (for example ``<not counted>``).
    """
    import os
    import shlex

    events = ('cycles:u', 'instructions:u', 'branches:u', 'branch-misses:u')
    command = ['perf', 'stat', '-B']
    for event in events:
        command += ['-e', event]
    command += shlex.split(program) + [str(DEPTH)] + [str(v) for v in board]

    # Set the locale through the environment instead of a `VAR=value` shell
    # prefix, so no shell is needed to launch the command.
    env = dict(os.environ, LC_NUMERIC='en_US.utf8')

    program_lines = []
    report_started = False
    counters = dict.fromkeys(events)
    total_time = None

    # The context manager closes the pipe and waits for perf to exit, so no
    # zombie processes or open descriptors pile up over many runs.
    with Popen(command, stdout=PIPE, stderr=STDOUT, env=env) as process:
        for index, raw_line in enumerate(process.stdout):
            line = raw_line.decode('UTF-8')

            # perf's report header quotes the full command line, which differs
            # between the two compared programs; stop collecting program
            # output there so it cannot cause a spurious mismatch.
            if 'Performance counter stats for' in line:
                report_started = True
            if index <= 5 and not report_started:
                program_lines.append(line)

            for event in events:
                if event in line:
                    counters[event] = int(line.split()[0].replace(',', ''))
            if 'time elapsed' in line:
                total_time = float(line.split()[0])

    return (''.join(program_lines),
            counters['cycles:u'],
            counters['instructions:u'],
            counters['branches:u'],
            counters['branch-misses:u'],
            total_time)

def plot_bench(bench1_name, bench1, bench2_name, bench2):
    """Save a grouped bar chart comparing two benchmark results to OUTFILE.

    bench1 and bench2 each hold one value per entry of PROPS, in PROPS
    order; bench1_name and bench2_name label the two bar series in the
    legend. The y axis is logarithmic so that event counts and the elapsed
    time, which differ by many orders of magnitude, fit on one chart.
    """
    y_pos = np.arange(len(PROPS))
    width = 0.35

    # Draw on a dedicated figure so repeated calls don't overlay each other,
    # and close it afterwards to release its memory.
    fig, ax = plt.subplots()
    ax.bar(y_pos, bench1, width, alpha=0.5, label=bench1_name)
    ax.bar(y_pos + width, bench2, width, alpha=0.5, label=bench2_name)
    # Centre each metric's label between its pair of bars.
    ax.set_xticks(y_pos + width / 2)
    ax.set_xticklabels(PROPS)
    # The bars mix event counts with an elapsed time, so use a generic label.
    ax.set_ylabel('Value (log scale)')
    ax.set_yscale('log')
    ax.legend(loc='best')
    ax.set_title(f'Average of {TESTS} Instances each {ITERATIONS} with {DEPTH} Depth')

    fig.savefig(OUTFILE)
    plt.close(fig)

def main(program1, program2):
    """Benchmark two Oware programs on the same random boards and plot the results.

    For each of TESTS boards from generate_board(), both programs are run
    ITERATIONS times through test(). Each program's metrics are averaged over
    those runs and printed per board, then averaged over all boards, printed,
    and passed to plot_bench(). Returns early without plotting if the two
    programs ever produce different output for the same board, or if TESTS or
    ITERATIONS is less than 1.
    """
    print('Legend: cycles, instructions, branches, branch-misses, total-time\n')

    # With no boards or no runs there is nothing to average or plot.
    if TESTS < 1 or ITERATIONS < 1:
        print('Error: TESTAMOUNT and ITERATIONS must be at least 1')
        return

    per_test_1 = []  # per-board mean metrics of program1
    per_test_2 = []  # per-board mean metrics of program2

    for i in range(TESTS):
        board = generate_board()
        runs_1 = []  # one list of metrics per iteration
        runs_2 = []

        for _ in range(ITERATIONS):
            output1, *values1 = test(program1, board)
            output2, *values2 = test(program2, board)

            # Programs are not comparable if they produce different outputs
            if output1 != output2:
                print('Error: Both programs return different outputs')
                print(output1)
                print(output2)
                return

            runs_1.append(values1)
            runs_2.append(values2)

        # zip(*runs) regroups the runs metric by metric before averaging.
        results_1 = [statistics.mean(metric) for metric in zip(*runs_1)]
        results_2 = [statistics.mean(metric) for metric in zip(*runs_2)]

        print('Test: ', str(i + 1), ':\n')
        print(program1, ': ', results_1)
        print(program2, ': ', results_2)
        print('\n')

        per_test_1.append(results_1)
        per_test_2.append(results_2)

    # Average (rather than sum) across boards so the numbers match the
    # "Average of ..." title that plot_bench() puts on the chart.
    global_result1 = [statistics.mean(metric) for metric in zip(*per_test_1)]
    global_result2 = [statistics.mean(metric) for metric in zip(*per_test_2)]

    print(program1, ': ', global_result1)
    print(program2, ': ', global_result2)
    print('\n')
    plot_bench(program1, global_result1, program2, global_result2)

if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('USAGE:', sys.argv[0], 'OWARE_PROGRAM1', 'OWARE_PROGRAM2', '[SEED]',
              '[TESTAMOUNT]', '[ITERATIONS]', '[DEPTH]', '[OUTPUT-FILE]',
              file=sys.stderr)
        # sys.exit is always available; the bare exit() builtin is only
        # installed by the site module.
        sys.exit(1)

    # Optional positional overrides, in the order shown in the usage line.
    if len(sys.argv) > 3:
        random.seed(int(sys.argv[3]))
    if len(sys.argv) > 4:
        TESTS = int(sys.argv[4])
    if len(sys.argv) > 5:
        ITERATIONS = int(sys.argv[5])
    if len(sys.argv) > 6:
        DEPTH = int(sys.argv[6])
    if len(sys.argv) > 7:
        OUTFILE = sys.argv[7]
    else:
        # The default OUTFILE was built at import time from the default
        # settings; rebuild it so the name reflects any overrides above.
        OUTFILE = (f'depth_{DEPTH}_tests_{TESTS}_iter_{ITERATIONS}_'
                   f'{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}_benchmark.png')

    main(sys.argv[1], sys.argv[2])
