#!/usr/bin/python3
# Madeleine Masser-Frye mmasserfrye@hmc.edu 5/22

from operator import index
import subprocess
import csv
import re
import matplotlib.pyplot as plt
import matplotlib.lines as lines
import numpy as np
from collections import namedtuple


def toNumber(cell):
    ''' helper for synthsfromcsv(): converts one CSV cell to int if possible, else float,
        else returns the original string unchanged
    '''
    try:
        return int(cell)
    except ValueError:
        pass
    try:
        return float(cell)
    except ValueError:
        return cell

def synthsfromcsv(filename):
    ''' loads every synthesis in filename into the global allSynths as Synth tuples
        numeric cells are converted to int/float; rows without any numeric cell
        (such as the header line written by synthsintocsv) and blank rows are skipped
        returns allSynths as well, for convenience
    '''
    global allSynths
    with open(filename, newline='') as csvfile:
        rows = [[toNumber(cell) for cell in row] for row in csv.reader(csvfile)]
    # a data row always has numeric width/freq/metrics; a header or blank row has none
    allSynths = [Synth(*row) for row in rows if any(not isinstance(cell, str) for cell in row)]
    return allSynths
    
# Patterns for pulling fields out of the grep output (raw strings, compiled once).
CPL_RE = re.compile(r'\d+\.\d{6}')            # critical path length; \d+ so delays >= 10 ns are not truncated
FREQ_RE = re.compile(r'_\d*_MHz')             # target freq token, e.g. _5000_MHz
MODWIDTH_RE = re.compile(r'ppa_\w*_\d*_qor')  # module and width from the qor report name
AREA_RE = re.compile(r'\d*\.\d{6}')           # design area
POWER_RE = re.compile(r'\d+\.\d+[e-]*\d+')    # numbers on the power report line
TECH_RE = re.compile(r'[a-zA-Z0-9]+nm')       # tech token ending in "nm"

def grepReports(pattern, specStr, reportGlob):
    ''' helper for synthsintocsv(): greps runs/ppa_<specStr>/reports/<reportGlob> for pattern
        returns the matching output lines
        raises subprocess.CalledProcessError if grep exits nonzero
    '''
    bashCommand = "grep '{}' runs/ppa_{}/reports/{}".format(pattern, specStr, reportGlob)
    output = subprocess.check_output(['bash', '-c', bashCommand])
    return output.decode("utf-8").split('\n')[:-1]

def parseQorLine(line):
    ''' helper for synthsintocsv(): parses one 'Critical Path Length' grep line
        returns (module, tech, width, freq, delay)
        module/width come from the ppa_<module>_<width>_qor report name, freq from the
        _<freq>_MHz token, tech from the first token ending in "nm" (with "nm" stripped)
        raises IndexError if a field is not found
    '''
    modWidth = MODWIDTH_RE.findall(line)[0][4:-4].split('_')
    freq = int(FREQ_RE.findall(line)[0][1:-4])
    delay = float(CPL_RE.findall(line)[0])
    tech = TECH_RE.findall(line)[0][:-2]
    return modWidth[0], tech, int(modWidth[1]), freq, delay

def parsePowerLine(line, delay):
    ''' helper for synthsintocsv(): returns (lpower, denergy) from one power-report grep line
        lpower is the third number on the line; denergy is the second number times delay
        returns (0, 0) when the line does not hold at least three numbers (e.g. missing report)
    '''
    power = POWER_RE.findall(line)
    try:
        return float(power[2]), float(power[1]) * delay
    except (IndexError, ValueError):
        return 0, 0

def synthsintocsv(mod=None, width=None):
    ''' writes ppaData.csv with one line for every available synthesis
        each line contains the module, tech, width, target freq, and resulting metrics
        mod (and optionally width) restrict which runs/ppa_<mod>_<width>* directories are scanned
        NOTE(review): assumes the three greps list the reports in the same order so line i
        of each refers to the same synthesis
    '''
    specStr = ''
    if mod is not None:
        specStr = mod
        if width is not None:
            specStr += '_' + str(width)
    specStr += '*'

    linesCPL = grepReports('Critical Path Length', specStr, '*qor*')
    linesDA = grepReports('Design Area', specStr, '*qor*')
    linesP = grepReports('100', specStr, '*power*')

    # newline='' as the csv module requires, otherwise rows get extra blank lines on Windows
    with open("ppaData.csv", "w", newline='') as csvfile:
        writer = csv.writer(csvfile)
        # NOTE(review): getVals() labels denergy in pJ; confirm which unit this header should say
        writer.writerow(['Module', 'Tech', 'Width', 'Target Freq', 'Delay', 'Area', 'L Power (nW)', 'D energy (mJ)'])
        for i, line in enumerate(linesCPL):
            synthMod, tech, synthWidth, freq, delay = parseQorLine(line)
            area = float(AREA_RE.findall(linesDA[i])[0])
            # a missing power line yields zero power/energy instead of aborting the export
            powerLine = linesP[i] if i < len(linesP) else ''
            lpower, denergy = parsePowerLine(powerLine, delay)
            writer.writerow([synthMod, tech, synthWidth, freq, delay, area, lpower, denergy])

def getVals(tech, module, var, freq=None):
    ''' for a specified tech, module, and variable/metric
        returns a list of values for that metric in ascending width order with the appropriate units
        works at a specified target frequency or if none is given, uses the synthesis with the min delay for each width
        var must be 'delay', 'area', 'lpower', or 'denergy'; anything else raises ValueError
        widths with no matching synthesis are left out, so the list can be shorter than widths
        reads the module-level allSynths and widths
    '''
    unitsByVar = {'delay': " (ns)", 'area': " (sq microns)", 'lpower': " (nW)", 'denergy': " (pJ)"}
    if var not in unitsByVar:
        raise ValueError("unknown metric {!r}, expected one of {}".format(var, list(unitsByVar)))
    units = unitsByVar[var]

    matching = [s for s in allSynths if s.tech == tech and s.module == module]
    if freq is not None:
        # sort (width, value) pairs so values come out in ascending width order
        pairs = sorted((s.width, getattr(s, var)) for s in matching if s.freq == freq)
        metric = [value for _, value in pairs]
    else:
        metric = []
        for w in widths:
            atWidth = [s for s in matching if s.width == w]
            if not atWidth:
                # omit the width rather than repeating the previous width's value
                continue
            fastest = min(atWidth, key=lambda s: s.delay)  # first synthesis with the smallest delay
            metric.append(getattr(fastest, var))

    if 'flop' in module and var == 'area':
        metric = [m/2 for m in metric] # since two flops in each module

    return metric, units

def fitEquation(fits, coefs):
    ''' helper for genLegend(): formats the fitted curve as "c0 + c1*N + ..." with each
        coefficient rounded to 3 places
        terms follow the genFuncs() order (c, l, s, g, n) and coefs must be in that same order
    '''
    suffixes = [('c', ''), ('l', '*N'), ('s', '*N^2'), ('g', '*log2(N)'), ('n', '*Nlog2(N)')]
    chosen = [suffix for key, suffix in suffixes if key in fits]
    # joining the terms avoids a dangling leading " + " when no constant term is fitted
    return ' + '.join(str(round(c, 3)) + suffix for c, suffix in zip(coefs, chosen))

def genLegend(fits, coefs, r2, techcolor):
    ''' generates a list of two legend elements 
        labels line with fit equation and dots with tech and r squared of the fit
        techcolor: [tech name, matplotlib color]
    '''
    tech, c = techcolor
    eq = fitEquation(fits, coefs)
    legend_elements = [lines.Line2D([0], [0], color=c, label=eq),
                       lines.Line2D([0], [0], color=c, ls='', marker='o', label=tech +'  $R^2$='+ str(round(r2, 4)))]
    return legend_elements

def oneMetricPlot(module, var, freq=None, ax=None, fits='clsgn'):
    ''' module: string module name
        freq: int freq (MHz); if None, the min-delay synthesis for each width is used
        var: string delay, area, lpower, or denergy
        ax: axes to draw on; if None, draws on plt.gca(), adds a title, and calls plt.show()
        fits: constant, linear, square, log2, Nlog2
        plots given variable vs width for all matching syntheses with regression
        techs that lack a value for some width are not plotted
    '''
    singlePlot = ax is None
    if singlePlot:
        ax = plt.gca()

    fullLeg = []
    units = ''
    for combo in techcolors:
        tech, c = combo
        metric, units = getVals(tech, module, var, freq=freq)
        # regress() pairs metric with widths, so only fit techs with one value per width
        if len(metric) == len(widths):
            xp, pred, leg = regress(widths, metric, combo, fits)
            fullLeg += leg

            ax.scatter(widths, metric, color=c)
            ax.plot(xp, pred, color=c)

    ax.legend(handles=fullLeg)

    ax.set_xticks(widths)
    ax.set_xlabel("Width (bits)")
    ax.set_ylabel(str.title(var) + units)

    if singlePlot:
        titleStr = "  (target  " + str(freq)+ "MHz)" if freq is not None else " (min delay)"
        ax.set_title(module + titleStr)
        plt.show()

def lstsqFit(widths, var, funcArr):
    ''' helper for regress(): least-squares fit of var against the basis functions in funcArr
        evaluated at each width
        returns (coefs, r2) with one coefficient per function in funcArr
    '''
    mat = [[func(w) for func in funcArr] for w in widths]
    # builtin float: the np.float alias was removed in NumPy 1.24
    y = np.array(var, dtype=float)
    coefs, residuals, _, _ = np.linalg.lstsq(mat, y, rcond=None)
    # lstsq returns an empty residuals array when there are no more points than terms
    # or the matrix is rank-deficient; treat that as zero residual
    resid = residuals[0] if residuals.size else 0
    r2 = 1 - resid / (y.size * y.var())
    return coefs, r2

def regress(widths, var, techcolor, fits='clsgn'):
    ''' fits a curve to the given points (var vs widths) using the terms selected by fits
        returns lists of x and y values to plot that curve and legend elements with the equation
    '''
    funcArr = genFuncs(fits)
    coefs, r2 = lstsqFit(widths, var, funcArr)

    # hard-coded x range covering the default widths (8..128) with some headroom
    xp = np.linspace(8, 140, 200)
    pred = [sum(np.multiply(coefs, [func(x) for func in funcArr])) for x in xp]

    leg = genLegend(fits, coefs, r2, techcolor)

    return xp, pred, leg

def makeCoefTable(tech):
    ''' not currently in use, may salvage later
        writes ppaFitting.csv with each line containing the coefficients of a full 'clsgn'
        least-squares fit (1, N, N^2, log2(N), Nlog2(N)) and its R^2 for a particular
        combination of module, metric, and target frequency
        combinations without a value for every width are skipped
    '''
    funcArr = genFuncs('clsgn')
    mat = [[func(w) for func in funcArr] for w in widths]

    with open("ppaFitting.csv", "w", newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Module', 'Metric', 'Freq', '1', 'N', 'N^2', 'log2(N)', 'Nlog2(N)', 'R^2'])

        for mod in ['add', 'mult', 'comparator', 'shifter']:
            for var, freq in [['delay', 5000], ['area', 5000], ['area', 10]]:
                # getVals takes the metric before the (keyword) target frequency
                metric, _ = getVals(tech, mod, var, freq=freq)
                if len(metric) != len(widths):
                    continue  # the fit needs exactly one value per width
                y = np.array(metric, dtype=float)
                coefs, residuals, _, _ = np.linalg.lstsq(mat, y, rcond=None)
                # empty residuals (e.g. as many points as terms) means zero residual
                resid = residuals[0] if residuals.size else 0
                r2 = 1 - resid / (y.size * y.var())
                writer.writerow([mod, var, freq] + coefs.tolist() + [r2])

def genFuncs(fits='clsgn'):
    ''' helper function for regress()
        returns one basis function per regression term whose letter appears in fits;
        terms are always returned in this order regardless of their order in fits:
        c -> 1, l -> N, s -> N^2, g -> log2(N), n -> N*log2(N)
    '''
    basis = (
        ('c', lambda n: 1),
        ('l', lambda n: n),
        ('s', lambda n: n**2),
        ('g', lambda n: np.log2(n)),
        ('n', lambda n: n*np.log2(n)),
    )
    return [term for letter, term in basis if letter in fits]

def noOutliers(freqs, delays, areas, tol=0.75):
    ''' returns a pared down list of freqs, delays, and areas 
        keeps only syntheses whose target freq is within tol (default 75%) of the target freq
        that produced the min delay, i.e. (1-tol) < freq/ref < (1+tol), to focus on the interesting area
        empty input (or a zero reference freq) gives three empty lists
        helper function to freqPlot()
    '''
    if not delays:
        return [], [], []
    ref = freqs[delays.index(min(delays))]
    if ref == 0:
        return [], [], []

    lo, hi = 1 - tol, 1 + tol
    kept = [(fr, d, a) for fr, d, a in zip(freqs, delays, areas) if lo < fr / ref < hi]
    return [k[0] for k in kept], [k[1] for k in kept], [k[2] for k in kept]

def freqPlot(tech, mod, width):
    ''' plots delay, area, area*delay^2, area*delay^3, and area*delay^4 against target freq
        for syntheses with specified tech, module, width
        green points met the target clock period, blue points violated it;
        targets far from the min-delay target are dropped by noOutliers()
    '''
    freqsL, delaysL, areasL = ([[], []] for i in range(3))
    for oneSynth in allSynths:
        if mod == oneSynth.module and width == oneSynth.width and tech == oneSynth.tech:
            # 1 when the delay exceeds the target clock period (1000/freq ns), i.e. timing violated
            ind = int(1000/oneSynth.delay < oneSynth.freq)
            freqsL[ind] += [oneSynth.freq]
            delaysL[ind] += [oneSynth.delay]
            areasL[ind] += [oneSynth.area]

    f, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, 1, sharex=True)
    adAxes = [(ax3, 2), (ax4, 3), (ax5, 4)]  # (axis, delay exponent) for the area*delay^k panels

    for ind in [0, 1]:
        areas = areasL[ind]
        delays = delaysL[ind]
        freqs = freqsL[ind]

        if 'flop' in mod: areas = [m/2 for m in areas] # since two flops in each module
        freqs, delays, areas = noOutliers(freqs, delays, areas)

        c = 'blue' if ind else 'green'
        ax1.scatter(freqs, delays, color=c)
        ax2.scatter(freqs, areas, color=c)
        for ax, power in adAxes:
            ax.scatter(freqs, adprodpow(areas, delays, power), color=c)

    legend_elements = [lines.Line2D([0], [0], color='green', ls='', marker='o', label='timing achieved'),
                       lines.Line2D([0], [0], color='blue', ls='', marker='o', label='slack violated')]

    ax1.legend(handles=legend_elements)

    # shared x axis: label the bottom subplot
    ax5.set_xlabel("Target Freq (MHz)")
    ax1.set_ylabel('Delay (ns)')
    ax2.set_ylabel('Area (sq microns)')
    # labels match the exponents plotted in adAxes
    ax3.set_ylabel('Area * $Delay^2$')
    ax4.set_ylabel('Area * $Delay^3$')
    ax5.set_ylabel('Area * $Delay^4$')
    ax1.set_title(mod + '_' + str(width))
    plt.show()

def adprodpow(areas, delays, pow):
    ''' helper function for freqPlot
        returns a new list holding areas[i] * delays[i]**pow for every index i of areas
    '''
    return [area * delays[i]**pow for i, area in enumerate(areas)]

def plotPPA(mod, freq=None):
    ''' draws a 2x2 grid for the given module: width vs delay, area, leakage power, and
        dynamic energy, each with its own regression fit, with every tech in techcolors overlaid
        freq: target freq (MHz); if None, the min-delay synthesis for each width is used
    '''
    fig, axs = plt.subplots(2, 2)
    # (metric, grid position, fit terms) for each panel
    panels = [
        ('delay', (0, 0), 'clg'),
        ('area', (0, 1), 's'),
        ('lpower', (1, 0), 'c'),
        ('denergy', (1, 1), 's'),
    ]
    for var, pos, fits in panels:
        oneMetricPlot(mod, var, ax=axs[pos], fits=fits, freq=freq)

    if freq is None:
        titleStr = " (min delay)"
    else:
        titleStr = "  (target  " + str(freq)+ "MHz)"
    plt.suptitle(mod + titleStr)
    plt.show()

# One synthesis result, in the same column order as ppaData.csv.
Synth = namedtuple("Synth", ["module", "tech", "width", "freq", "delay", "area", "lpower", "denergy"])

# [tech, plot color] pairs; every plot overlays one series per entry
techcolors = [
    ['sky90', 'green'],
    ['tsmc28', 'blue'],
]

# widths (bits) plotted on the x axis and used for the regressions
widths = [8, 16, 32, 64, 128]

# regenerate ppaData.csv from the synthesis reports, then load it into allSynths
synthsintocsv()

synthsfromcsv('ppaData.csv') # your csv here!

### examples
# oneMetricPlot('add', 'delay')
# freqPlot('sky90', 'add', 8)
# plotPPA('add')