Module MNIST

Expand source code
import CommandMap as CM
import PyTerminal as PT

import importlib
import random
import time
import numpy as np
import io
import torch
import math

import RRAM, DNN

from tkinter import *
from tkinter import font

from PIL import Image, ImageTk


def conf_network(network, verbal):
    """ Configure MNIST network type

    Loads the weights, RRAM mapping, images, targets and simulated
    predictions of the selected network from its data folder, imports the
    `model` module found in that folder, derives a per-layer description from
    `model.Net` and sends it to the board through the `DNN.nn_conf_*` calls.
    The loaded data is published through module-level globals.

    Args:
        network (str): Network type ('MLP', 'MLP2' or 'CONV')
        verbal (bool): Whether to print the configured network or not.

    Raises:
        ValueError: If `network` is not one of the known network types.
    """
    # `sys` is not imported at module level, so import it here instead of
    # relying on a wildcard import to leak it into the namespace.
    import sys

    # Some global variables
    global folder_dir
    global weights, targets, sim_preds
    global images, image_len
    global mapping
    global layers

    # Raw strings: the backslashes are path separators, not escape sequences
    data_folders = {
        'MLP':  r'D:\Dropbox (GaTech)\GaTech\ICSRL\Projects\9. RRAM\Evaluation Board\MNIST\data_256_fc32_fc10',
        'MLP2': r'D:\Dropbox (GaTech)\GaTech\ICSRL\Projects\9. RRAM\Evaluation Board\MNIST\data_484_fc64_fc10',
        'CONV': r'D:\Dropbox (GaTech)\GaTech\ICSRL\Projects\9. RRAM\Evaluation Board\MNIST\data_256_conv16_conv32_fc10',
    }

    # Validate before touching any global state
    if network not in data_folders:
        raise ValueError(f'Unknown network type: {network!r}')

    # Replace the previously configured folder on the import path (if any)
    if 'folder_dir' in globals() and folder_dir in sys.path:
        sys.path.remove(folder_dir)
    folder_dir = data_folders[network]
    sys.path.append(folder_dir)

    # Load network model
    weights = torch.load(folder_dir + '\\weights.pt')

    # Read the RRAM mapping (first line of the file)
    # NOTE(security): eval() executes whatever expression the file contains;
    # only use mapping files from a trusted source.
    with open(folder_dir + '\\mapping.txt', 'r') as f:
        mapping = eval(f.readline())

    # Load input data
    images = np.uint8(torch.load(folder_dir + '\\images.pt'))
    image_len = images.shape[1]

    # Load targets
    targets = np.uint8(torch.load(folder_dir + '\\targets.pt'))

    # Load predictions
    sim_preds = np.uint8(torch.load(folder_dir + '\\sim_preds.pt'))

    # Load the model and configure the network; reload so that the `model`
    # module of the newly selected folder replaces a previously imported one
    import model
    model = importlib.reload(model)

    layer = 0
    layers = []
    trainable_layer = 0
    input_length = image_len
    scale = 1

    for name, module in model.Net.named_modules(model.Net()):
        # Skip the quantization stubs and the top-level network container
        if isinstance(module, (torch.ao.quantization.stubs.QuantStub,
                               torch.ao.quantization.stubs.DeQuantStub,
                               model.Net)):
            continue

        info = {}
        layers.append(info)

        # If it's a trainable layer (containing weights)
        if hasattr(module, 'weight'):
            info['rrams'] = mapping[trainable_layer]
            trainable_layer += 1
            if isinstance(module, torch.nn.modules.linear.Linear):
                packed = weights[name + '._packed_params._packed_params'][0]
                info['weights'] = packed
                scale = int(weights[name + '.scale']/packed.q_scale())
            elif isinstance(module, torch.nn.modules.conv.Conv2d):
                info['weights'] = weights[name + '.weight']
                scale = int(weights[name + '.scale']/weights[name + '.weight'].q_scale())
        else:
            info['rrams'] = []
            info['weights'] = []

        if isinstance(module, torch.nn.modules.linear.Linear):
            info.update({
                'type'          : CM.CM_DNN_TYPE_LINEAR,
                'input_length'  : module.in_features,
                'input_channel' : 0,
                'kernel_length' : module.in_features,
                'kernel_channel': 0,
                'kernel_number' : module.out_features,
                'stride'        : 0,
                'output_length' : module.out_features,
                'output_channel': 0,
                'output_q_scale': 1,
                'output_q_zp'   : 0,
            })
        elif isinstance(module, torch.nn.modules.conv.Conv2d):
            info.update({
                'type'          : CM.CM_DNN_TYPE_CONV,
                'input_length'  : input_length,
                'input_channel' : module.in_channels,
                'kernel_length' : module.kernel_size[0],
                'kernel_channel': module.in_channels,
                'kernel_number' : module.out_channels,
                'stride'        : module.stride[0],
                'output_length' : int((input_length - module.kernel_size[0])/module.stride[0] + 1),
                'output_channel': module.out_channels,
                'output_q_scale': 1,
                'output_q_zp'   : 0,
            })
        elif isinstance(module, torch.nn.modules.pooling.MaxPool2d):
            prev_channels = layers[layer-1]['output_channel']
            info.update({
                'type'          : CM.CM_DNN_TYPE_MAXPOOL,
                'input_length'  : input_length,
                'input_channel' : prev_channels,
                'kernel_length' : module.kernel_size,
                'kernel_channel': 1,
                'kernel_number' : prev_channels,
                'stride'        : module.stride,
                'output_length' : int((input_length - module.kernel_size)/module.stride + 1),
                'output_channel': prev_channels,
                'output_q_scale': 1,
                'output_q_zp'   : 0,
            })
        elif isinstance(module, torch.nn.modules.activation.ReLU):
            prev_channels = layers[layer-1]['output_channel']
            info.update({
                'type'          : CM.CM_DNN_TYPE_RELU,
                'input_length'  : input_length,
                'input_channel' : prev_channels,
                'kernel_length' : 0,
                'kernel_channel': 0,
                'kernel_number' : 0,
                'stride'        : 0,
                'output_length' : input_length,
                'output_channel': prev_channels,
                # Uses the scale computed for the most recent trainable layer
                'output_q_scale': scale,
                'output_q_zp'   : 0,
            })

        # Any other module type leaves the entry without 'output_length',
        # so the lookup below raises KeyError for unsupported layers.
        input_length = info['output_length']
        layer += 1

    # The network always ends with an ARGMAX layer
    layers.append({
        'rrams'         : [],
        'weights'       : [],
        'type'          : CM.CM_DNN_TYPE_ARGMAX,
        'input_length'  : input_length,
        'input_channel' : layers[layer-1]['output_channel'],
        'kernel_length' : 0,
        'kernel_channel': 0,
        'kernel_number' : 0,
        'stride'        : 0,
        'output_length' : 0,
        'output_channel': 0,
        'output_q_scale': 1,
        'output_q_zp'   : 0,
    })

    # Push the layer descriptions to the DNN module
    DNN.nn_clear(False)
    for layer_index, layer_info in enumerate(layers):
        for index_r, row in enumerate(layer_info['rrams']):
            for index_c, col in enumerate(row):
                DNN.nn_conf_rrams   (str(layer_index), str(index_r), str(index_c), str(col), False)
        DNN.nn_conf_type    (str(layer_index), str(layer_info['type']), False)
        DNN.nn_conf_input   (str(layer_index), str(layer_info['input_length']), str(layer_info['input_channel']), False)
        DNN.nn_conf_kernel  (str(layer_index), str(layer_info['kernel_length']), str(layer_info['kernel_channel']), str(layer_info['kernel_number']), str(layer_info['stride']), False)
        DNN.nn_conf_output  (str(layer_index), str(layer_info['output_length']), str(layer_info['output_channel']), False)
        DNN.nn_conf_output_q(str(layer_index), str(layer_info['output_q_scale']), str(layer_info['output_q_zp']), False)
    if verbal:
        DNN.nn_print(True)


def upload_weights(verbal):
    """ Upload MNIST weights on the current network

    Walks through the globally configured `layers` and writes the quantized
    weights of every layer whose type is 0 (linear) or 1 (convolution) into
    RRAM. Layers of any other type are left untouched.

    Args:
        verbal (bool): Whether to read back every written byte and print the
            mismatches in a table or not.
    """
    def print_table_header():
        print('╔════════════╦══════╦══════╦══════╗')
        print('║ (row, col) ║ Gold ║ Read ║ Diff ║')
        print('╟────────────╫──────╫──────╫──────╢')

    def print_table_footer():
        print('╚════════════╩══════╩══════╩══════╝')

    def write_and_check(address, golden, row_label, col_label):
        # Write one byte; in verbal mode read it back and report a mismatch
        RRAM.write_byte_iter(str(address), str(golden), True)
        if not verbal:
            return
        readout = int(RRAM.read_byte(str(address), '0', '0x1', False))
        if readout != golden:
            print(f'║ ({row_label:>3}, {col_label:>3}) ║ {golden:>4} ║ {readout:>4} ║ {golden-readout:>4} ║')

    def upload_linear(layer_weights, layer_type, rram_indices):
        # The weight matrix is cut into tiles of 256 channels x 32 kernels,
        # each tile going to the RRAM given by rram_indices[tile_ch][tile_n]
        quantized = np.int8(layer_weights.int_repr())
        kernel_count, channel_count = quantized.shape

        for tile_ch in range(math.ceil(channel_count/256)):
            ch_offset = tile_ch*256
            for tile_n in range(math.ceil(kernel_count/32)):
                n_offset = tile_n*32
                rram = rram_indices[tile_ch][tile_n]
                print(f"Writing Type '{layer_type}' weights[{tile_ch}][{tile_n}] to RRAM #{rram}")
                RRAM.switch(str(rram), False)

                if verbal:
                    print_table_header()
                for ch in range(min(256, channel_count - ch_offset)):
                    for n in range(min(32, kernel_count - n_offset)):
                        golden = int(quantized[n_offset+n][ch_offset+ch])
                        write_and_check(ch * 256 + n, golden, ch, n)
                if verbal:
                    print_table_footer()

    def upload_conv(layer_weights):
        # Kernels are packed side by side: each kernel occupies
        # k_width*k_height columns and one row per input channel
        quantized = np.int8(layer_weights.int_repr())
        kernel_count, channel_count, k_width, k_height = quantized.shape
        kernel_area = k_width*k_height
        blocks_per_row = int(32 / kernel_area)

        if verbal:
            print_table_header()
        for n in range(kernel_count):
            brow, bcol = divmod(n, blocks_per_row)
            for krow in range(k_width):
                for kcol in range(k_height):
                    for ch in range(channel_count):
                        row = brow * channel_count + ch
                        col = bcol * kernel_area + krow * k_width + kcol
                        golden = int(quantized[n][ch][krow][kcol])
                        write_and_check(row * 256 + col, golden, row, col)
        if verbal:
            print_table_footer()

    for layer_info in layers:
        layer_weights = layer_info['weights']
        layer_type = layer_info['type']
        rram_indices = layer_info['rrams']
        if layer_type == 0: # MLP
            upload_linear(layer_weights, layer_type, rram_indices)
        elif layer_type == 1: # CONV
            upload_conv(layer_weights)


def upload_image(index, verbal):
    """ Upload a MNIST image on the current network

    Only the non-zero pixels are sent; each one is addressed by its
    row-major position inside the image.

    Args:
        index (str): Index of the image
        verbal (bool): Whether to print the uploaded image or not.
    """
    pixels = images[int(index)]

    # Upload the image
    DNN.in_conf_len(str(image_len), True)
    for row in range(image_len):
        row_base = row*image_len
        for col in range(image_len):
            value = pixels[row][col]
            if value == 0:
                continue
            DNN.in_fill(str(row_base + col), str(value), True)

    # Print the image if required
    if verbal:
        DNN.in_print(True)


def test_inference(network, WL_start, WL_end, count, verbal):
    """ Test MNIST inference

    Configures the network, then for every WL scheme in
    [WL_start, WL_end] uploads the first `count` images, runs the inference
    on each of them and prints an accuracy/duration summary.

    Args:
        network (str): Network type
        WL_start (str): Starting WL Scheme (1~9)
        WL_end (str): Ending WL Scheme (WL_start~9)
        count (str): Number of images to run inference on
        verbal (bool or str): Whether to print the result for each image or
            not. A string (as received from the command line) is evaluated
            to obtain the flag.
    """
    # Accept a real bool as documented; eval() only takes strings.
    # NOTE(security): eval() executes arbitrary expressions; the string form
    # must only come from a trusted command source.
    if isinstance(verbal, str):
        verbal = eval(verbal)
    WL_start = int(WL_start)
    WL_end = int(WL_end)
    count = int(count)
    conf_network(network, False)

    # The references do not depend on the WL scheme, so prepare them once
    local_targets = np.resize(targets, count)
    local_sim_preds = np.resize(sim_preds, count)

    for WL in range(WL_start, WL_end+1):
        print(f'[INFO] WL Scheme: {WL}')
        # Read the image and do the inference
        if verbal:
            print('╔═══════╦══════╦═════╦════╗')
            print('║ Index ║ Gold ║ Sim ║ TC ║')
            print('╟───────╫──────╫─────╫────╢')
        tc_preds = np.empty(count, dtype=np.uint8)
        tick = time.time()
        for index in range(count):
            if not verbal:
                print(f'\r\tImage {index}', end='')

            # Upload image
            upload_image(index, False)

            # Inference
            tc_preds[index] = int(DNN.forward(str(WL), False))

            # Print the result
            if verbal:
                print(f'║ {index:>5} ║ {local_targets[index]:>4} ║ {local_sim_preds[index]:>3} ║ {tc_preds[index]:>2} ║')
        passed_time = time.time()-tick
        if verbal:
            print('╟───────╨──────╨─────╨────╢')
        else:
            print('')
            print('╔═════════════════════════╗')
        print(f'║ Pred Acc:   {np.sum(local_sim_preds == local_targets):5d}/{count:5d} ║')
        print(f'║   TC Acc:   {np.sum(tc_preds == local_targets):5d}/{count:5d} ║')
        print(f'║ Duration: {passed_time:9.2f} sec ║')
        print('╚═════════════════════════╝')


class GUI:
    """ MNIST GUI demo

    Shows a control panel (image index, network type, WL scheme,
    operations), a drawing canvas and a result panel (golden label,
    inference duration, prediction). Creating an instance builds the window
    and enters the Tk main loop, so the constructor only returns once the
    window is closed.
    """

    def __init__(self):
        """ MNIST GUI Demo Class

        """

        # Initialize the main Panel
        self.master = Tk()
        self.master.title("MNIST GUI")

        # Change default font
        font.nametofont("TkDefaultFont").configure(family="Arial", size=12)

        # Sub panels
        frm_controls = Frame(self.master, padx=5)
        frm_canvas = Frame(self.master, padx=5)
        frm_results = Frame(self.master, padx=5)
        frm_controls.pack(side=LEFT)
        frm_canvas.pack(side=LEFT)
        frm_results.pack(side=LEFT)

        # Sub frames in Control Panel
        frm_image_index = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_network = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_wl_scheme = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_operation = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_image_index.pack(pady=5)
        frm_network.pack(pady=5)
        frm_wl_scheme.pack(pady=5)
        frm_operation.pack(pady=5)

        # Sub frames in Result Panel
        frm_golden = Frame(frm_results, borderwidth=5, relief='ridge')
        frm_duration = Frame(frm_results, borderwidth=5, relief='ridge')
        frm_prediction = Frame(frm_results, borderwidth=5, relief='ridge')
        frm_golden.pack(pady=5)
        frm_duration.pack(pady=5)
        frm_prediction.pack(pady=5)

        # In Image Index Frame
        Label(frm_image_index, text='Image Index', width=14, font='Arial 14 bold').pack()

        self.txt_image_index = Entry(frm_image_index, width=6, font='Arial 14')
        self.txt_image_index.pack(side=LEFT, padx=5)

        # The PhotoImage objects are stored on self so they stay referenced
        # for as long as the buttons that display them
        self.btn_random_icon = ImageTk.PhotoImage(Image.open('Applications/btn_random.png').resize((20, 20)))
        Button(frm_image_index, image=self.btn_random_icon, command=self.image_random).pack(side=RIGHT, padx=5)

        self.btn_load_icon = ImageTk.PhotoImage(Image.open('Applications/btn_load.png').resize((20, 20)))
        Button(frm_image_index, image=self.btn_load_icon, command=self.image_load).pack(side=RIGHT, padx=5)

        # In Network Type Frame
        Label(frm_network, text='Network Type', width=14, font='Arial 14 bold').pack()

        networks = ('MLP', 'MLP2')
        self.network_var = StringVar(value=networks[0])
        OptionMenu(frm_network, self.network_var, *networks).pack(side=LEFT, padx=5)
        # Configure the default network right away, without the pop-up
        self.network_change(False)

        self.btn_display_network_icon = ImageTk.PhotoImage(Image.open('Applications/btn_display_network_icon.png').resize((20, 20)))
        Button(frm_network, image=self.btn_display_network_icon, command=self.network_print).pack(side=RIGHT, padx=5)

        self.btn_config_network_icon = ImageTk.PhotoImage(Image.open('Applications/btn_load.png').resize((20, 20)))
        Button(frm_network, image=self.btn_config_network_icon, command=self.network_change).pack(side=RIGHT, padx=5)

        # In WL Scheme Frame
        Label(frm_wl_scheme, text='WL Scheme', width=14, font='Arial 14 bold').pack()

        self.sld_WL = Scale(frm_wl_scheme, from_=1, to=9, orient=HORIZONTAL)
        self.sld_WL.pack()

        # In Operation Frame
        Label(frm_operation, text='Operation', width=14, font='Arial 14 bold').pack()

        self.btn_clear_icon = ImageTk.PhotoImage(Image.open('Applications/btn_clear.png').resize((20, 20)))
        Button(frm_operation, image=self.btn_clear_icon, command=self.clear).pack(side=LEFT, padx=5)

        Button(frm_operation, text="Inference", command=self.image_inference, font='Arial 12').pack(side=RIGHT, padx=5)

        # In Canvas Panel
        self.old_xy = None
        self.canvas = Canvas(frm_canvas, width=400, height=400, bg='white', borderwidth=10, relief='ridge')
        self.canvas.bind('<B1-Motion>', self.canvas_paint)
        self.canvas.bind('<ButtonRelease-1>', self.canvas_reset)
        self.canvas.pack()

        # In Golden Result Frame
        Label(frm_golden, text='Golden', width=12, font='Arial 14 bold').pack()

        self.text_golden = StringVar(value='N/A')
        Label(frm_golden, textvariable=self.text_golden, font=('Arial', 48)).pack()

        # In Duration Frame
        Label(frm_duration, text='Duration (s)', width=12, font='Arial 14 bold').pack()

        self.text_duration = StringVar(value='N/A')
        Label(frm_duration, textvariable=self.text_duration, font=('Arial', 32)).pack()

        # In Prediction Frame
        Label(frm_prediction, text='Prediction', width=12, font='Arial 14 bold').pack()

        self.text_prediction = StringVar(value='N/A')
        Label(frm_prediction, textvariable=self.text_prediction, font=('Arial', 48)).pack()

        # Make it not resizable and place it at center
        self.master.resizable(False, False)
        self.window_center(self.master)
        self.master.mainloop()

    def window_center(self, window):
        """ Place the window at the center of the monitor

        Args:
            window (Tk or Toplevel): The target window
        """
        window.update_idletasks()
        width = window.winfo_width()
        height = window.winfo_height()
        # Account for the window decoration (side frame and title bar)
        frm_width = window.winfo_rootx() - window.winfo_x()
        win_width = width + 2 * frm_width
        titlebar_height = window.winfo_rooty() - window.winfo_y()
        win_height = height + titlebar_height + frm_width
        x = window.winfo_screenwidth() // 2 - win_width // 2
        y = window.winfo_screenheight() // 2 - win_height // 2
        window.geometry('{}x{}+{}+{}'.format(width, height, x, y))
        window.deiconify()


    def canvas_paint(self, new_xy):
        """ Callback for canvas painting

        Draws a line segment from the previous pointer position to the
        current one.

        Args:
            new_xy (Event): Event bound to '<B1-Motion>'; its `x` and `y`
                attributes are used as the new position
        """
        if self.old_xy:
            self.canvas.create_line(self.old_xy.x, self.old_xy.y, new_xy.x, new_xy.y, width=40, stipple='gray50', capstyle=ROUND)
        self.old_xy = new_xy


    def canvas_reset(self, new_xy):
        """ Callback for canvas reset

        Ends the current stroke and clears the image index, so that the next
        inference uses the drawing instead of a dataset image.

        Args:
            new_xy (Event): Event bound to '<ButtonRelease-1>', but not used
                in this function

        """
        self.old_xy = None
        self.txt_image_index.delete(0, 'end')


    def canvas_capture(self):
        """ Capture what's on the canvas

        Returns:
            np.uint8: 2D numpy array of shape (image_len, image_len)

        """
        # Capture the image from Canvas
        ps = self.canvas.postscript(colormode='gray')
        image = Image.open(io.BytesIO(ps.encode('utf-8')))
        image = image.convert(mode='L').resize((image_len, image_len))
        image = np.uint8(image.getdata())
        # Invert (white background becomes 0) and divide the values by 4
        image = np.floor_divide(np.invert(image), 4)
        image = np.reshape(image, (image_len, image_len))
        return image


    def clear(self):
        """ Clean the canvas and other related information

        """
        self.canvas.delete(ALL)
        self.text_golden.set('N/A')
        self.text_prediction.set('N/A')
        self.text_duration.set('N/A')
        self.txt_image_index.delete(0, 'end')


    def network_change(self, verbal=True):
        """ Change the network type

        Args:
            verbal (bool, optional): Whether to print the response or not. Defaults to True.

        """
        network = self.network_var.get()
        conf_network(network, False)
        if verbal:
            self.network_print('Network Architecture Updated')


    def network_print(self, win_title='Current Network Architecture'):
        """ Pop up a new window showing the updated network

        Args:
            win_title (str): Title of the popped up window

        """
        win_network = Toplevel()
        win_network.title(win_title)
        Label(win_network, text=DNN.nn_print(False), justify= LEFT, font='Courier 14 bold').pack(fill='both', pady=5)
        Button(win_network, text="Okay", command=win_network.destroy).pack(pady=5)
        self.window_center(win_network)


    def image_load(self):
        """ Load image index from 'txt_image_index'

        """
        index = int(self.txt_image_index.get())

        # Paint the image onto the canvas (the inverse of the scaling and
        # inversion applied in canvas_capture)
        tkimage = np.invert(4*images[index])
        tkimage = Image.fromarray(tkimage)
        tkimage = tkimage.resize((self.canvas.winfo_width(), self.canvas.winfo_height()))
        # Stored on self so the image stays referenced while it is displayed
        self.tkimage = ImageTk.PhotoImage(image=tkimage)
        self.canvas.create_image(0, 0, anchor="nw", image=self.tkimage)

        # Load golden and clear prediction/duration
        self.text_golden.set(targets[index])
        self.text_prediction.set('N/A')
        self.text_duration.set('N/A')


    def image_random(self):
        """ Choose a random image

        """
        self.txt_image_index.delete(0, 'end')
        # randrange excludes its upper bound, so the index is always valid
        # for the currently loaded image set
        self.txt_image_index.insert(0, str(random.randrange(len(images))))
        self.image_load()


    def image_inference(self):
        """ Upload the image and run inference

        Uses the dataset image named in the index field when it is filled
        in, otherwise the drawing captured from the canvas.

        """
        # Clear the prediction and duration first
        self.text_prediction.set('N/A')
        self.text_duration.set('N/A')
        self.master.update()

        # Upload the image
        if self.txt_image_index.get() != '':
            upload_image(str(self.txt_image_index.get()), False)
        else:
            image = self.canvas_capture()
            DNN.in_conf_len(str(image_len), True)
            for i in range(image_len):
                for j in range(image_len):
                    if image[i][j] != 0:
                        DNN.in_fill(str(i*image_len+j), str(image[i][j]), True)

        # Print the result (only the forward pass is timed)
        tick = time.time()
        pred = DNN.forward(str(self.sld_WL.get()), False)
        passed_time = time.time()-tick
        self.text_duration.set(f'{passed_time:.2f}')
        self.text_prediction.set(pred)


def decode(parameters):
    """ Decode the command

    The sub-command name is taken from `parameters[1]`; the following
    entries are forwarded as its arguments. Unknown sub-commands are handed
    to `PT.unknown`.

    Args:
        parameters (list): Command in List form.
    """
    command = parameters[1]

    if command == 'conf_network':
        conf_network(parameters[2], True)
    elif command == 'upload_weights':
        upload_weights(True)
    elif command == 'upload_image':
        upload_image(parameters[2], True)
    elif command == 'test_inference':
        test_inference(parameters[2], parameters[3], parameters[4], parameters[5], parameters[6])
    elif command == 'gui':
        GUI()
    else:
        PT.unknown(parameters)

Functions

def conf_network(network, verbal)

Configure MNIST network type

Args

network : str
Network type
verbal : bool
Whether to print the configured network or not.
Expand source code
def conf_network(network, verbal):
    """ Configure MNIST network type

    Loads the weights, RRAM mapping, images, targets and simulated
    predictions of the selected network from its data folder, imports the
    `model` module found in that folder, derives a per-layer description from
    `model.Net` and sends it to the board through the `DNN.nn_conf_*` calls.
    The loaded data is published through module-level globals.

    Args:
        network (str): Network type ('MLP', 'MLP2' or 'CONV')
        verbal (bool): Whether to print the configured network or not.

    Raises:
        ValueError: If `network` is not one of the known network types.
    """
    # `sys` is not imported at module level, so import it here instead of
    # relying on a wildcard import to leak it into the namespace.
    import sys

    # Some global variables
    global folder_dir
    global weights, targets, sim_preds
    global images, image_len
    global mapping
    global layers

    # Raw strings: the backslashes are path separators, not escape sequences
    data_folders = {
        'MLP':  r'D:\Dropbox (GaTech)\GaTech\ICSRL\Projects\9. RRAM\Evaluation Board\MNIST\data_256_fc32_fc10',
        'MLP2': r'D:\Dropbox (GaTech)\GaTech\ICSRL\Projects\9. RRAM\Evaluation Board\MNIST\data_484_fc64_fc10',
        'CONV': r'D:\Dropbox (GaTech)\GaTech\ICSRL\Projects\9. RRAM\Evaluation Board\MNIST\data_256_conv16_conv32_fc10',
    }

    # Validate before touching any global state
    if network not in data_folders:
        raise ValueError(f'Unknown network type: {network!r}')

    # Replace the previously configured folder on the import path (if any)
    if 'folder_dir' in globals() and folder_dir in sys.path:
        sys.path.remove(folder_dir)
    folder_dir = data_folders[network]
    sys.path.append(folder_dir)

    # Load network model
    weights = torch.load(folder_dir + '\\weights.pt')

    # Read the RRAM mapping (first line of the file)
    # NOTE(security): eval() executes whatever expression the file contains;
    # only use mapping files from a trusted source.
    with open(folder_dir + '\\mapping.txt', 'r') as f:
        mapping = eval(f.readline())

    # Load input data
    images = np.uint8(torch.load(folder_dir + '\\images.pt'))
    image_len = images.shape[1]

    # Load targets
    targets = np.uint8(torch.load(folder_dir + '\\targets.pt'))

    # Load predictions
    sim_preds = np.uint8(torch.load(folder_dir + '\\sim_preds.pt'))

    # Load the model and configure the network; reload so that the `model`
    # module of the newly selected folder replaces a previously imported one
    import model
    model = importlib.reload(model)

    layer = 0
    layers = []
    trainable_layer = 0
    input_length = image_len
    scale = 1

    for name, module in model.Net.named_modules(model.Net()):
        # Skip the quantization stubs and the top-level network container
        if isinstance(module, (torch.ao.quantization.stubs.QuantStub,
                               torch.ao.quantization.stubs.DeQuantStub,
                               model.Net)):
            continue

        info = {}
        layers.append(info)

        # If it's a trainable layer (containing weights)
        if hasattr(module, 'weight'):
            info['rrams'] = mapping[trainable_layer]
            trainable_layer += 1
            if isinstance(module, torch.nn.modules.linear.Linear):
                packed = weights[name + '._packed_params._packed_params'][0]
                info['weights'] = packed
                scale = int(weights[name + '.scale']/packed.q_scale())
            elif isinstance(module, torch.nn.modules.conv.Conv2d):
                info['weights'] = weights[name + '.weight']
                scale = int(weights[name + '.scale']/weights[name + '.weight'].q_scale())
        else:
            info['rrams'] = []
            info['weights'] = []

        if isinstance(module, torch.nn.modules.linear.Linear):
            info.update({
                'type'          : CM.CM_DNN_TYPE_LINEAR,
                'input_length'  : module.in_features,
                'input_channel' : 0,
                'kernel_length' : module.in_features,
                'kernel_channel': 0,
                'kernel_number' : module.out_features,
                'stride'        : 0,
                'output_length' : module.out_features,
                'output_channel': 0,
                'output_q_scale': 1,
                'output_q_zp'   : 0,
            })
        elif isinstance(module, torch.nn.modules.conv.Conv2d):
            info.update({
                'type'          : CM.CM_DNN_TYPE_CONV,
                'input_length'  : input_length,
                'input_channel' : module.in_channels,
                'kernel_length' : module.kernel_size[0],
                'kernel_channel': module.in_channels,
                'kernel_number' : module.out_channels,
                'stride'        : module.stride[0],
                'output_length' : int((input_length - module.kernel_size[0])/module.stride[0] + 1),
                'output_channel': module.out_channels,
                'output_q_scale': 1,
                'output_q_zp'   : 0,
            })
        elif isinstance(module, torch.nn.modules.pooling.MaxPool2d):
            prev_channels = layers[layer-1]['output_channel']
            info.update({
                'type'          : CM.CM_DNN_TYPE_MAXPOOL,
                'input_length'  : input_length,
                'input_channel' : prev_channels,
                'kernel_length' : module.kernel_size,
                'kernel_channel': 1,
                'kernel_number' : prev_channels,
                'stride'        : module.stride,
                'output_length' : int((input_length - module.kernel_size)/module.stride + 1),
                'output_channel': prev_channels,
                'output_q_scale': 1,
                'output_q_zp'   : 0,
            })
        elif isinstance(module, torch.nn.modules.activation.ReLU):
            prev_channels = layers[layer-1]['output_channel']
            info.update({
                'type'          : CM.CM_DNN_TYPE_RELU,
                'input_length'  : input_length,
                'input_channel' : prev_channels,
                'kernel_length' : 0,
                'kernel_channel': 0,
                'kernel_number' : 0,
                'stride'        : 0,
                'output_length' : input_length,
                'output_channel': prev_channels,
                # Uses the scale computed for the most recent trainable layer
                'output_q_scale': scale,
                'output_q_zp'   : 0,
            })

        # Any other module type leaves the entry without 'output_length',
        # so the lookup below raises KeyError for unsupported layers.
        input_length = info['output_length']
        layer += 1

    # The network always ends with an ARGMAX layer
    layers.append({
        'rrams'         : [],
        'weights'       : [],
        'type'          : CM.CM_DNN_TYPE_ARGMAX,
        'input_length'  : input_length,
        'input_channel' : layers[layer-1]['output_channel'],
        'kernel_length' : 0,
        'kernel_channel': 0,
        'kernel_number' : 0,
        'stride'        : 0,
        'output_length' : 0,
        'output_channel': 0,
        'output_q_scale': 1,
        'output_q_zp'   : 0,
    })

    # Push the layer descriptions to the DNN module
    DNN.nn_clear(False)
    for layer_index, layer_info in enumerate(layers):
        for index_r, row in enumerate(layer_info['rrams']):
            for index_c, col in enumerate(row):
                DNN.nn_conf_rrams   (str(layer_index), str(index_r), str(index_c), str(col), False)
        DNN.nn_conf_type    (str(layer_index), str(layer_info['type']), False)
        DNN.nn_conf_input   (str(layer_index), str(layer_info['input_length']), str(layer_info['input_channel']), False)
        DNN.nn_conf_kernel  (str(layer_index), str(layer_info['kernel_length']), str(layer_info['kernel_channel']), str(layer_info['kernel_number']), str(layer_info['stride']), False)
        DNN.nn_conf_output  (str(layer_index), str(layer_info['output_length']), str(layer_info['output_channel']), False)
        DNN.nn_conf_output_q(str(layer_index), str(layer_info['output_q_scale']), str(layer_info['output_q_zp']), False)
    if verbal:
        DNN.nn_print(True)
def decode(parameters)

Decode the command

Args

parameters : list
Command in List form.
Expand source code
def decode(parameters):
    """ Dispatch an MNIST command to the function that implements it

    Args:
        parameters (list): Command in List form. Element 1 names the
            sub-command, the elements after it are that sub-command's
            arguments. Anything unrecognised is handed to ``PT.unknown``.
    """
    command = parameters[1]

    if command == 'conf_network':
        conf_network(parameters[2], True)
        return

    if command == 'upload_weights':
        upload_weights(True)
        return

    if command == 'upload_image':
        upload_image(parameters[2], True)
        return

    if command == 'test_inference':
        # network, WL_start, WL_end, count, verbal
        test_inference(parameters[2], parameters[3], parameters[4],
                       parameters[5], parameters[6])
        return

    if command == 'gui':
        GUI()
        return

    PT.unknown(parameters)
def test_inference(network, WL_start, WL_end, count, verbal)

Test MNIST inference

Args

network : str
Network type
WL_start : str
Starting WL Scheme (1~9)
WL_end : str
Ending WL Scheme (WL_start~9)
count : str
Number of images to run inference on
verbal : bool
Whether to print the result for each image or not.
Expand source code
def test_inference(network, WL_start, WL_end, count, verbal):
    """ Test MNIST inference

    Configures ``network`` and then, for every WL scheme from ``WL_start`` to
    ``WL_end`` (inclusive), uploads the first ``count`` images one at a time,
    runs ``DNN.forward`` on each and prints a summary that compares the
    targets against both the stored simulation predictions and the values
    returned by ``DNN.forward`` (the "TC" column).

    Args:
        network (str): Network type
        WL_start (str): Starting WL Scheme (1~9)
        WL_end (str): Ending WL Scheme (WL_start~9)
        count (str): Number of images to run inference on
        verbal (bool or str): Whether to print the result for each image or
            not. A string is parsed as a Python literal ('True' / 'False').
    """
    import ast

    if isinstance(verbal, str):
        # The flag arrives as text typed on the terminal. literal_eval accepts
        # the same 'True'/'False' spellings eval() did, but cannot execute
        # arbitrary expressions.
        verbal = ast.literal_eval(verbal)
    verbal = bool(verbal)

    WL_start = int(WL_start)
    WL_end = int(WL_end)
    count = int(count)
    conf_network(network, False)

    # These only depend on the configured network and on count, so they are
    # computed once instead of once per WL scheme.
    local_targets = np.resize(targets, count)
    local_sim_preds = np.resize(sim_preds, count)

    for WL in range(WL_start, WL_end+1):
        print(f'[INFO] WL Scheme: {WL}')
        # Read the image and do the inference
        if verbal:
            print( '╔═══════╦══════╦═════╦════╗')
            print(f'║ Index ║ Gold ║ Sim ║ TC ║')
            print( '╟───────╫──────╫─────╫────╢')
        tc_preds = np.empty(count, dtype=np.uint8)
        tick = time.time()
        for index in range(count):
            if not verbal:
                # Single-line progress indicator, overwritten on every image
                print(f'\r\tImage {index}', end='')

            # Upload image
            upload_image(index, False)

            # Inference
            tc_preds[index] = int(DNN.forward(str(WL), False))

            # Print the result
            if verbal:
                print(f'║ {index:>5} ║ {local_targets[index]:>4} ║ {local_sim_preds[index]:>3} ║ {tc_preds[index]:>2} ║')
        passed_time = time.time()-tick

        # Close the per-image table, or open a fresh box for the summary
        if verbal:
            print( '╟───────╨──────╨─────╨────╢')
        else:
            print('')
            print( '╔═════════════════════════╗')
        print(f'║ Pred Acc:   {np.sum(local_sim_preds == local_targets):5d}/{count:5d} ║')
        print(f'║   TC Acc:   {np.sum(tc_preds == local_targets):5d}/{count:5d} ║')
        print(f'║ Duration: {passed_time:9.2f} sec ║')
        print( '╚═════════════════════════╝')
def upload_image(index, verbal)

Upload a MNIST image on the current network

Args

index : str
Index of the image
verbal : bool
Whether to print the uploaded image or not.
Expand source code
def upload_image(index, verbal):
    """ Upload a MNIST image on the current network

    Only the non-zero pixels are sent; each one is addressed by its
    row-major position ``row * image_len + col``.

    Args:
        index (str): Index of the image
        verbal (bool): Whether to print the uploaded image or not.
    """
    pixels = images[int(index)]
    side = image_len

    # Announce the input length, then stream the non-zero pixels
    DNN.in_conf_len(str(side), True)
    for row in range(side):
        row_base = row * side
        for col in range(side):
            value = pixels[row][col]
            if value == 0:
                continue
            DNN.in_fill(str(row_base + col), str(value), True)

    # Print the image if required
    if verbal:
        DNN.in_print(True)
def upload_weights(verbal)

Upload MNIST weights on the current network

Args

verbal : bool
Whether to print the uploading progress or not.
Expand source code
def upload_weights(verbal):
    """ Upload MNIST weights on the current network

    Walks the module-level ``layers`` list and writes each layer's quantized
    weights (``weights.int_repr()``) one byte at a time through
    ``RRAM.write_byte_iter``. Layers whose type is neither 0 (MLP) nor
    1 (CONV) are skipped.

    Args:
        verbal (bool): Whether to read every byte back after writing it and
            print a table of the cells whose read-out differs.
    """
    def print_table_header():
        print('╔════════════╦══════╦══════╦══════╗')
        print('║ (row, col) ║ Gold ║ Read ║ Diff ║')
        print('╟────────────╫──────╫──────╫──────╢')

    def print_table_footer():
        print('╚════════════╩══════╩══════╩══════╝')

    def write_and_verify(local_addr, golden, row, col):
        # Write one weight; in verbal mode read it back and report a mismatch
        RRAM.write_byte_iter(str(local_addr), str(golden), True)
        if verbal:
            readout = int(RRAM.read_byte(str(local_addr), '0', '0x1', False))
            if readout != golden:
                print(f'║ ({row:>3}, {col:>3}) ║ {golden:>4} ║ {readout:>4} ║ {golden-readout:>4} ║')

    def write_weights_to_rram(weights, layer_type, rram_indices):
        if layer_type == 0: # MLP
            np_weights = np.int8(weights.int_repr())
            (k_number, k_channel) = np_weights.shape

            # The weight matrix is cut into tiles of 256 channels x 32 kernels,
            # each tile going to the RRAM listed in rram_indices[tile_ch][tile_n]
            for tile_ch in range(math.ceil(k_channel/256)):
                for tile_n in range(math.ceil(k_number/32)):
                    ch_offset = tile_ch*256
                    n_offset = tile_n*32
                    print(f'Writing Type \'{layer_type}\' weights[{tile_ch}][{tile_n}] to RRAM #{rram_indices[tile_ch][tile_n]}')
                    RRAM.switch(str(rram_indices[tile_ch][tile_n]), False)

                    if verbal:
                        print_table_header()
                    # The last tile along either axis may be only partly filled
                    for ch in range(min(256, k_channel-ch_offset)):
                        for n in range(min(32, k_number-n_offset)):
                            golden = int(np_weights[n_offset+n][ch_offset+ch])
                            local_addr = ch * 256 + n
                            write_and_verify(local_addr, golden, ch, n)
                    if verbal:
                        print_table_footer()

        elif layer_type == 1: # CONV
            np_weights = np.int8(weights.int_repr())
            (k_number, k_channel, k_rows, k_cols) = np_weights.shape
            kernel_size = k_rows * k_cols
            blocks_per_row = 32 // kernel_size

            # NOTE(review): unlike the MLP branch no RRAM.switch() is issued
            # here and rram_indices is unused - confirm this is intended.
            if verbal:
                print_table_header()
            for n in range(k_number):
                brow, bcol = divmod(n, blocks_per_row)
                for krow in range(k_rows):
                    for kcol in range(k_cols):
                        for ch in range(k_channel):
                            row = brow * k_channel + ch
                            # Flatten (krow, kcol) row-major. The stride is the
                            # size of the kcol axis, so non-square kernels get
                            # unique columns inside their block.
                            col = bcol * kernel_size + krow * k_cols + kcol
                            local_addr = row * 256 + col
                            golden = int(np_weights[n][ch][krow][kcol])
                            write_and_verify(local_addr, golden, row, col)
            if verbal:
                print_table_footer()

    for layer_info in layers:
        write_weights_to_rram(layer_info['weights'], layer_info['type'], layer_info['rrams'])

Classes

class GUI

MNIST GUI Demo Class

Expand source code
class GUI:
    """ MNIST GUI Demo Class

    A Tk window with three panels: controls (image index, network type,
    WL scheme, operations), a drawing canvas, and results (golden label,
    duration, prediction). Constructing an instance configures the default
    network and then blocks in the Tk main loop.
    """

    def __init__(self):
        """ Build the window, configure the default network and run the main loop

        """

        # Initialize the main Panel
        self.master = Tk()
        self.master.title("MNIST GUI")

        # Change default font
        font.nametofont("TkDefaultFont").configure(family="Arial", size=12)

        # Sub panels
        frm_controls = Frame(self.master, padx=5)
        frm_canvas = Frame(self.master, padx=5)
        frm_results = Frame(self.master, padx=5)
        frm_controls.pack(side=LEFT)
        frm_canvas.pack(side=LEFT)
        frm_results.pack(side=LEFT)

        # Sub frames in Control Panel
        frm_image_index = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_network = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_wl_scheme = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_operation = Frame(frm_controls, borderwidth=5, relief='ridge')
        frm_image_index.pack(pady=5)
        frm_network.pack(pady=5)
        frm_wl_scheme.pack(pady=5)
        frm_operation.pack(pady=5)

        # Sub frames in Result Panel
        frm_golden = Frame(frm_results, borderwidth=5, relief='ridge')
        frm_duration = Frame(frm_results, borderwidth=5, relief='ridge')
        frm_prediction = Frame(frm_results, borderwidth=5, relief='ridge')
        frm_golden.pack(pady=5)
        frm_duration.pack(pady=5)
        frm_prediction.pack(pady=5)

        # In Image Index Frame
        Label(frm_image_index, text='Image Index', width=14, font='Arial 14 bold').pack()

        self.txt_image_index = Entry(frm_image_index, width=6, font='Arial 14')
        self.txt_image_index.pack(side=LEFT, padx=5)

        # The PhotoImage objects are kept on self so they are not garbage
        # collected while the buttons still display them
        self.btn_random_icon = ImageTk.PhotoImage(Image.open('Applications/btn_random.png').resize((20, 20)))
        Button(frm_image_index, image=self.btn_random_icon, command=self.image_random).pack(side=RIGHT, padx=5)

        self.btn_load_icon = ImageTk.PhotoImage(Image.open('Applications/btn_load.png').resize((20, 20)))
        Button(frm_image_index, image=self.btn_load_icon, command=self.image_load).pack(side=RIGHT, padx=5)

        # In Network Type Frame
        Label(frm_network, text='Network Type', width=14, font='Arial 14 bold').pack()

        networks = ('MLP', 'MLP2')
        self.network_var = StringVar(value=networks[0])
        OptionMenu(frm_network, self.network_var, *networks).pack(side=LEFT, padx=5)
        # Configure the default network right away, without the pop-up
        self.network_change(False)

        self.btn_display_network_icon = ImageTk.PhotoImage(Image.open('Applications/btn_display_network_icon.png').resize((20, 20)))
        Button(frm_network, image=self.btn_display_network_icon, command=self.network_print).pack(side=RIGHT, padx=5)

        self.btn_config_network_icon = ImageTk.PhotoImage(Image.open('Applications/btn_load.png').resize((20, 20)))
        Button(frm_network, image=self.btn_config_network_icon, command=self.network_change).pack(side=RIGHT, padx=5)

        # In WL Scheme Frame
        Label(frm_wl_scheme, text='WL Scheme', width=14, font='Arial 14 bold').pack()

        self.sld_WL = Scale(frm_wl_scheme, from_=1, to=9, orient=HORIZONTAL)
        self.sld_WL.pack()

        # In Operation Frame
        Label(frm_operation, text='Operation', width=14, font='Arial 14 bold').pack()

        self.btn_clear_icon = ImageTk.PhotoImage(Image.open('Applications/btn_clear.png').resize((20, 20)))
        Button(frm_operation, image=self.btn_clear_icon, command=self.clear).pack(side=LEFT, padx=5)

        Button(frm_operation, text="Inference", command=self.image_inference, font='Arial 12').pack(side=RIGHT, padx=5)

        # In Canvas Panel
        self.old_xy = None
        self.canvas = Canvas(frm_canvas, width=400, height=400, bg='white', borderwidth=10, relief='ridge')
        self.canvas.bind('<B1-Motion>', self.canvas_paint)
        self.canvas.bind('<ButtonRelease-1>', self.canvas_reset)
        self.canvas.pack()

        # In Golden Result Frame
        Label(frm_golden, text='Golden', width=12, font='Arial 14 bold').pack()

        self.text_golden = StringVar(value='N/A')
        Label(frm_golden, textvariable=self.text_golden, font=('Arial', 48)).pack()

        # In Duration Frame
        Label(frm_duration, text='Duration (s)', width=12, font='Arial 14 bold').pack()

        self.text_duration = StringVar(value='N/A')
        Label(frm_duration, textvariable=self.text_duration, font=('Arial', 32)).pack()

        # In Prediction Frame
        Label(frm_prediction, text='Prediction', width=12, font='Arial 14 bold').pack()

        self.text_prediction = StringVar(value='N/A')
        Label(frm_prediction, textvariable=self.text_prediction, font=('Arial', 48)).pack()

        # Make it not resizable and place it at center
        self.master.resizable(False, False)
        self.window_center(self.master)
        self.master.mainloop()

    def window_center(self, window):
        """ Place the window at the center of the monitor

        Args:
            window (Tk or Toplevel): The target window
        """
        # Flush pending geometry work so the winfo_* values are up to date
        window.update_idletasks()
        width = window.winfo_width()
        height = window.winfo_height()
        # Decoration sizes: offset between the outer window and its client area
        frm_width = window.winfo_rootx() - window.winfo_x()
        win_width = width + 2 * frm_width
        titlebar_height = window.winfo_rooty() - window.winfo_y()
        win_height = height + titlebar_height + frm_width
        x = window.winfo_screenwidth() // 2 - win_width // 2
        y = window.winfo_screenheight() // 2 - win_height // 2
        window.geometry('{}x{}+{}+{}'.format(width, height, x, y))
        window.deiconify()


    def canvas_paint(self, new_xy):
        """ Callback for canvas painting

        Draws a segment from the previously reported pointer position to the
        new one, then remembers the new position.

        Args:
            new_xy (tkinter.Event): Event carrying the new pointer ``x`` / ``y``
        """
        if self.old_xy:
            self.canvas.create_line(self.old_xy.x, self.old_xy.y, new_xy.x, new_xy.y, width=40, stipple='gray50', capstyle=ROUND)
        self.old_xy = new_xy


    def canvas_reset(self, new_xy):
        """ Callback for canvas reset

        Args:
            new_xy (tkinter.Event): Release event, not used in this function

        """
        # End the current stroke so the next one does not connect to it
        self.old_xy = None
        # An empty index makes image_inference use the canvas drawing
        self.txt_image_index.delete(0, 'end')


    def canvas_capture(self):
        """ Capture what's on the canvas

        Returns:
            np.uint8: 2D numpy array of shape (image_len, image_len)

        """
        # Capture the image from Canvas
        ps = self.canvas.postscript(colormode='gray')
        image = Image.open(io.BytesIO(ps.encode('utf-8')))
        image = image.convert(mode='L').resize((image_len, image_len))
        image = np.uint8(image.getdata())
        # Invert the gray levels and divide by 4 (the inverse of the
        # ``np.invert(4 * image)`` used when a stored image is displayed)
        image = np.floor_divide(np.invert(image), 4)
        image = np.reshape(image, (image_len, image_len))
        return image


    def clear(self):
        """ Clean the canvas and other related information

        """
        self.canvas.delete(ALL)
        self.text_golden.set('N/A')
        self.text_prediction.set('N/A')
        self.text_duration.set('N/A')
        self.txt_image_index.delete(0, 'end')


    def network_change(self, verbal=True):
        """ Change the network type

        Args:
            verbal (bool, optional): Whether to print the response or not. Defaults to True.

        """
        network = self.network_var.get()
        conf_network(network, False)
        if verbal:
            self.network_print('Network Architecture Updated')


    def network_print(self, win_title='Current Network Architecture'):
        """ Pop up a new window showing the updated network

        Args:
            win_title (str): Title of the popped up window

        """
        win_network = Toplevel()
        win_network.title(win_title)
        Label(win_network, text=DNN.nn_print(False), justify= LEFT, font='Courier 14 bold').pack(fill='both', pady=5)
        Button(win_network, text="Okay", command=win_network.destroy).pack(pady=5)
        self.window_center(win_network)


    def image_load(self):
        """ Load image index from 'txt_image_index'

        Does nothing when the entry is empty, is not an integer, or lies
        outside the loaded data set.

        """
        try:
            index = int(self.txt_image_index.get())
        except ValueError:
            # Empty or non-numeric entry: there is no image to load
            return
        if not 0 <= index < len(images):
            return

        # Paint the image onto the canvas
        tkimage = np.invert(4*images[index])
        tkimage = Image.fromarray(tkimage)
        tkimage = tkimage.resize((self.canvas.winfo_width(), self.canvas.winfo_height()))
        # Keep a reference on self, otherwise Tk drops the picture
        self.tkimage = ImageTk.PhotoImage(image=tkimage)
        self.canvas.create_image(0, 0, anchor="nw", image=self.tkimage)

        # Load golden and clear prediction/duration
        self.text_golden.set(targets[index])
        self.text_prediction.set('N/A')
        self.text_duration.set('N/A')


    def image_random(self):
        """ Choose a random image

        """
        self.txt_image_index.delete(0, 'end')
        # randrange excludes the upper bound, so the index is always valid
        # for whatever data set is currently loaded
        self.txt_image_index.insert(0, str(random.randrange(len(images))))
        self.image_load()


    def image_inference(self):
        """ Upload the image and run inference

        Uses the stored image when an index is entered, otherwise whatever
        is drawn on the canvas.

        """
        # Clear the prediction and duration first
        self.text_prediction.set('N/A')
        self.text_duration.set('N/A')
        self.master.update()

        # Upload the image
        if self.txt_image_index.get() != '':
            upload_image(str(self.txt_image_index.get()), False)
        else:
            image = self.canvas_capture()
            DNN.in_conf_len(str(image_len), True)
            for i in range(image_len):
                for j in range(image_len):
                    if image[i][j] != 0:
                        DNN.in_fill(str(i*image_len+j), str(image[i][j]), True)

        # Print the result (only the forward pass is timed)
        tick = time.time()
        pred = DNN.forward(str(self.sld_WL.get()), False)
        passed_time = time.time()-tick
        self.text_duration.set(f'{passed_time:.2f}')
        self.text_prediction.set(pred)

Methods

def canvas_capture(self)

Capture what's on the canvas

Returns

np.uint8
2D numpy array
Expand source code
def canvas_capture(self):
    """ Grab the current canvas drawing as a network input image

    Returns:
        np.uint8: 2D numpy array of shape (image_len, image_len)

    """
    side = image_len

    # Render the canvas to PostScript and let PIL read it back
    postscript = self.canvas.postscript(colormode='gray')
    rendered = Image.open(io.BytesIO(postscript.encode('utf-8')))

    # Gray-scale, scaled down to the network's input size
    small = rendered.convert(mode='L').resize((side, side))
    pixels = np.uint8(small.getdata())

    # Invert the gray levels, then divide them by 4
    inverted = np.invert(pixels)
    scaled = np.floor_divide(inverted, 4)
    return np.reshape(scaled, (side, side))
def canvas_paint(self, new_xy)

Callback for canvas painting

Args

new_xy : tuple
new (x, y)
Expand source code
def canvas_paint(self, new_xy):
    """ Callback for canvas painting

    Connects the previously seen pointer position with the new one and
    remembers the new one for the next call.

    Args:
        new_xy (tuple): new (x, y)
    """
    previous = self.old_xy
    if previous:
        self.canvas.create_line(
            previous.x, previous.y,
            new_xy.x, new_xy.y,
            width=40, stipple='gray50', capstyle=ROUND,
        )
    self.old_xy = new_xy
def canvas_reset(self, new_xy)

Callback for canvas reset

Args

new_xy : tuple
new (x, y), but not used in this function
Expand source code
def canvas_reset(self, new_xy):
    """ Callback for canvas reset

    Ends the current stroke and empties the image-index entry.

    Args:
        new_xy (tuple): new (x, y), but not used in this function

    """
    # Forget the last point so the next stroke starts fresh
    self.old_xy = None

    index_entry = self.txt_image_index
    index_entry.delete(0, 'end')
def clear(self)

Clean the canvas and other related information

Expand source code
def clear(self):
    """ Wipe the canvas and reset every result field and the image index

    """
    self.canvas.delete(ALL)

    # Golden, prediction and duration all fall back to the same placeholder
    for result_var in (self.text_golden, self.text_prediction, self.text_duration):
        result_var.set('N/A')

    self.txt_image_index.delete(0, 'end')
def image_inference(self)

Upload the image and run inference

Expand source code
def image_inference(self):
    """ Upload the image and run inference

    The stored image is used when an index is entered; otherwise the
    canvas drawing is captured and uploaded.

    """
    # Clear the prediction and duration first
    for result_var in (self.text_prediction, self.text_duration):
        result_var.set('N/A')
    self.master.update()

    # Upload the image
    requested = self.txt_image_index.get()
    if requested == '':
        pixels = self.canvas_capture()
        DNN.in_conf_len(str(image_len), True)
        for row in range(image_len):
            for col in range(image_len):
                if pixels[row][col] != 0:
                    DNN.in_fill(str(row*image_len+col), str(pixels[row][col]), True)
    else:
        upload_image(str(requested), False)

    # Run the forward pass, timing only that call, and show the result
    started = time.time()
    pred = DNN.forward(str(self.sld_WL.get()), False)
    elapsed = time.time() - started
    self.text_duration.set(f'{elapsed:.2f}')
    self.text_prediction.set(pred)
def image_load(self)

Load image index from 'txt_image_index'

Expand source code
def image_load(self):
    """ Load image index from 'txt_image_index'

    Paints the selected image on the canvas and shows its target label.
    Does nothing when the entry is empty, is not an integer, or lies outside
    the loaded data set.

    """
    try:
        index = int(self.txt_image_index.get())
    except ValueError:
        # Empty or non-numeric entry: there is no image to load
        return
    if not 0 <= index < len(images):
        return

    # Paint the image onto the canvas
    tkimage = np.invert(4*images[index])
    tkimage = Image.fromarray(tkimage)
    tkimage = tkimage.resize((self.canvas.winfo_width(), self.canvas.winfo_height()))
    # Keep a reference on self, otherwise Tk drops the picture
    self.tkimage = ImageTk.PhotoImage(image=tkimage)
    self.canvas.create_image(0, 0, anchor="nw", image=self.tkimage)

    # Load golden and clear prediction/duration
    self.text_golden.set(targets[index])
    self.text_prediction.set('N/A')
    self.text_duration.set('N/A')
def image_random(self)

Choose a random image

Expand source code
def image_random(self):
    """ Choose a random image

    Writes a random valid index into the image-index entry and loads it.

    """
    self.txt_image_index.delete(0, 'end')
    # randrange excludes the upper bound, so the index is always valid for
    # whatever data set is currently loaded
    self.txt_image_index.insert(0, str(random.randrange(len(images))))
    self.image_load()
def network_change(self, verbal=True)

Change the network type

Args

verbal : bool, optional
Whether to print the response or not. Defaults to True.
Expand source code
def network_change(self, verbal=True):
    """ Reconfigure the network selected in the option menu

    Args:
        verbal (bool, optional): Whether to print the response or not. Defaults to True.

    """
    conf_network(self.network_var.get(), False)

    if not verbal:
        return
    self.network_print('Network Architecture Updated')
def network_print(self, win_title='Current Network Architecture')

Pop up a new window showing the updated network

Args

win_title : str
Title of the popped up window
Expand source code
def network_print(self, win_title='Current Network Architecture'):
    """ Show the current network description in a pop-up window

    Args:
        win_title (str): Title of the popped up window

    """
    popup = Toplevel()
    popup.title(win_title)

    # Network description, as returned by DNN.nn_print
    description = Label(popup, text=DNN.nn_print(False), justify= LEFT, font='Courier 14 bold')
    description.pack(fill='both', pady=5)

    okay_button = Button(popup, text="Okay", command=popup.destroy)
    okay_button.pack(pady=5)

    self.window_center(popup)
def window_center(self, window)

Place the window at the center of the monitor

Args

window : Tk or Toplevel
The target window
Expand source code
def window_center(self, window):
    """ Move a window to the middle of the screen, keeping its size

    Args:
        window (Tk or Toplevel): The target window
    """
    # Flush pending geometry work so the winfo_* values are current
    window.update_idletasks()

    client_w, client_h = window.winfo_width(), window.winfo_height()

    # Decoration sizes: offset between the outer window and its client area
    border = window.winfo_rootx() - window.winfo_x()
    titlebar = window.winfo_rooty() - window.winfo_y()

    outer_w = client_w + 2 * border
    outer_h = client_h + titlebar + border

    left = window.winfo_screenwidth() // 2 - outer_w // 2
    top = window.winfo_screenheight() // 2 - outer_h // 2

    window.geometry(f'{client_w}x{client_h}+{left}+{top}')
    window.deiconify()