import matplotlib.pyplot as plt

# Roofline diagram: machine/code ceilings plus measured test points,
# saved to ../graphics/rl-1-u-m.png.

# --- Ceilings ---------------------------------------------------------------
PEAK_GFLOPS = 67.2               # peak compute performance, GFLOP/s
PEAK_BANDWIDTH = 11.6352903      # peak STREAM bandwidth, GB/s
# NOTE(review): fraction of PEAK_GFLOPS the code is taken to be able to reach;
# how 0.028125 was derived is not shown here — confirm against the analysis.
REACHABLE_FRACTION = 0.028125
# NOTE(review): second STREAM bandwidth figure (legend reuses the "peak STREAM
# bandwidth" wording of the original) — confirm what run it comes from.
REACHABLE_BANDWIDTH = 11.4023097  # GB/s

reachable_gflops = PEAK_GFLOPS * REACHABLE_FRACTION

# --- Axes ticks -------------------------------------------------------------
# X ticks are powers of two from 2**-3 (1/8) to 2**9 (512) FLOPs/Byte.
# Locations and labels come from the same exponent range so their lengths
# always match (matplotlib rejects mismatched tick/label counts).
x_exponents = range(-3, 10)
x_tick_locations = [2.0 ** e for e in x_exponents]
x_tick_labels = [f"1/{2 ** -e}" if e < 0 else str(2 ** e) for e in x_exponents]

# Y ticks: 1 .. 4096 GFLOP/s in powers of two.
y_ticks = [2 ** e for e in range(13)]

# --- Roofs ------------------------------------------------------------------
# Sample the roofs at explicit arithmetic-intensity values (not list indices,
# which would put a point at x = 0 on a log axis). Both roofs are straight
# lines on log-log axes, so sampling at the tick positions is exact.
ai = x_tick_locations

plt.figure()
plt.plot(ai, [PEAK_GFLOPS] * len(ai), color='red', lw=1,
         label=f"peak performance ({PEAK_GFLOPS:g} GFLOP/s)")
# Memory roof: attainable GFLOP/s = bandwidth (GB/s) * AI (FLOPs/Byte).
plt.plot(ai, [PEAK_BANDWIDTH * x for x in ai], color='blue', lw=1,
         label=f"peak STREAM bandwidth ({PEAK_BANDWIDTH:.2f} GB/s)")
plt.plot(ai, [reachable_gflops] * len(ai), color='xkcd:light red', lw=1,
         label=("maximum reachable performance of the code "
                f"({reachable_gflops:.2f} GFLOP/s)"))
plt.plot(ai, [REACHABLE_BANDWIDTH * x for x in ai], color='xkcd:light blue', lw=1,
         label=f"peak STREAM bandwidth ({REACHABLE_BANDWIDTH:.2f} GB/s)")

# --- Measurements -----------------------------------------------------------
tests = ["100", "204800", "819200", "kdd_cup"]

test_ai = [3.6562866,
           5.32997116507857,
           3.82922662688575,
           3.90582489675034]

test_gflops = [1.8281433,
               1.9304232,
               1.9302247,
               1.9296322]

plt.scatter(test_ai, test_gflops)
for name, x, y in zip(tests, test_ai, test_gflops):
    plt.annotate(name, xy=(x, y))

# --- Layout & output --------------------------------------------------------
plt.yscale("log")
plt.xscale("log")
plt.xticks(x_tick_locations, x_tick_labels)
plt.yticks(y_ticks, [str(v) for v in y_ticks])
plt.xlabel("AI (FLOPs/Byte)")
plt.ylabel("Attainable GFLOP/s")
plt.legend(loc='upper center')
plt.title("Roofline diagram")
plt.savefig("../graphics/rl-1-u-m.png")
