#!/usr/bin/python3 # -*- coding: utf-8 -*- # library modules from math import ceil import json import time import os import threading # External library modules import tensorflow as tf import numpy as np # local modules from data import LSVRC2010 import logs class AlexNet: """ A tensorflow implementation of the paper: `AlexNet `_ """ def __init__(self, path, batch_size, resume): """ Build the AlexNet model """ self.logger = logs.get_logger() self.resume = resume self.path = path self.batch_size = batch_size self.lsvrc2010 = LSVRC2010(self.path, batch_size) self.num_classes = len(self.lsvrc2010.wnid2label) self.lr = 0.001 self.momentum = 0.9 self.lambd = tf.constant(0.0005, name='lambda') self.input_shape = (None, 227, 227, 3) self.output_shape = (None, self.num_classes) self.logger.info("Creating placeholders for graph...") self.create_tf_placeholders() self.logger.info("Creating variables for graph...") self.create_tf_variables() self.logger.info("Initialize hyper parameters...") self.hyper_param = {} self.init_hyper_param() def create_tf_placeholders(self): """ Create placeholders for the graph. The input for these will be given while training or testing. """ self.input_image = tf.placeholder(tf.float32, shape=self.input_shape, name='input_image') self.labels = tf.placeholder(tf.float32, shape=self.output_shape, name='output') self.learning_rate = tf.placeholder(tf.float32, shape=(), name='learning_rate') self.dropout = tf.placeholder(tf.float32, shape=(), name='dropout') def create_tf_variables(self): """ Create variables for epoch, batch and global step """ self.global_step = tf.Variable(0, name='global_step', trainable=False) self.cur_epoch = tf.Variable(0, name='epoch', trainable=False) self.cur_batch = tf.Variable(0, name='batch', trainable=False) self.increment_epoch_op = tf.assign(self.cur_epoch, self.cur_epoch+1) self.increment_batch_op = tf.assign(self.cur_batch, self.cur_batch+1) self.init_batch_op = tf.assign(self.cur_batch, 0) def init_hyper_param(self): """ Store the hyper parameters. For each layer store number of filters(kernels) and filter size. If it's a fully connected layer then store the number of neurons. """ with open('hparam.json') as f: self.hyper_param = json.load(f) def get_filter(self, layer_num, layer_name): """ :param layer_num: Indicates the layer number in the graph :type layer_num: int :param layer_name: Name of the filter """ layer = 'L' + str(layer_num) filter_height, filter_width, in_channels = self.hyper_param[layer]['filter_size'] out_channels = self.hyper_param[layer]['filters'] return tf.Variable(tf.truncated_normal( [filter_height, filter_width, in_channels, out_channels], dtype = tf.float32, stddev = 1e-2), name = layer_name) def get_strides(self, layer_num): """ :param layer_num: Indicates the layer number in the graph :type layer_num: int """ layer = 'L' + str(layer_num) stride = self.hyper_param[layer]['stride'] strides = [1, stride, stride, 1] return strides def get_bias(self, layer_num, value=0.0): """ Get the bias variable for current layer :param layer_num: Indicates the layer number in the graph :type layer_num: int """ layer = 'L' + str(layer_num) initial = tf.constant(value, shape=[self.hyper_param[layer]['filters']], name='C' + str(layer_num)) return tf.Variable(initial, name='B' + str(layer_num)) @property def l2_loss(self): """ Compute the l2 loss for all the weights """ conv_bias_names = ['B' + str(i) for i in range(1, 6)] weights = [] for v in tf.trainable_variables(): if 'biases' in v.name: continue if v.name.split(':')[0] in conv_bias_names: continue weights.append(v) return self.lambd * tf.add_n([tf.nn.l2_loss(weight) for weight in weights]) def build_graph(self): """ Build the tensorflow graph for AlexNet. First 5 layers are Convolutional layers. Out of which first 2 and last layer will be followed by *max pooling* layers. Next 2 layers are fully connected layers. L1_conv -> L1_MP -> L2_conv -> L2_MP -> L3_conv -> L4_conv -> L5_conv -> L5_MP -> L6_FC -> L7_FC Where L1_conv -> Convolutional layer 1 L5_MP -> Max pooling layer 5 L7_FC -> Fully Connected layer 7 Use `tf.nn.conv2d` to initialize the filters so as to reduce training time and `tf.layers.max_pooling2d` as we don't need to initialize in the pooling layer. """ # Layer 1 Convolutional layer filter1 = self.get_filter(1, 'L1_filter') l1_conv = tf.nn.conv2d(self.input_image, filter1, self.get_strides(1), padding = self.hyper_param['L1']['padding'], name='L1_conv') l1_conv = tf.add(l1_conv, self.get_bias(1)) l1_conv = tf.nn.local_response_normalization(l1_conv, depth_radius=5, bias=2, alpha=1e-4, beta=.75) l1_conv = tf.nn.relu(l1_conv) # Layer 1 Max Pooling layer l1_MP = tf.layers.max_pooling2d(l1_conv, self.hyper_param['L1_MP']['filter_size'], self.hyper_param['L1_MP']['stride'], name='L1_MP') # Layer 2 Convolutional layer filter2 = self.get_filter(2, 'L2_filter') l2_conv = tf.nn.conv2d(l1_MP, filter2, self.get_strides(2), padding = self.hyper_param['L2']['padding'], name='L2_conv') l2_conv = tf.add(l2_conv, self.get_bias(2, 1.0)) l2_conv = tf.nn.local_response_normalization(l2_conv, depth_radius=5, bias=2, alpha=1e-4, beta=.75) l2_conv = tf.nn.relu(l2_conv) # Layer 2 Max Pooling layer l2_MP = tf.layers.max_pooling2d(l2_conv, self.hyper_param['L2_MP']['filter_size'], self.hyper_param['L2_MP']['stride'], name='L2_MP') # Layer 3 Convolutional layer filter3 = self.get_filter(3, 'L3_filter') l3_conv = tf.nn.conv2d(l2_MP, filter3, self.get_strides(3), padding = self.hyper_param['L3']['padding'], name='L3_conv') l3_conv = tf.add(l3_conv, self.get_bias(3)) l3_conv = tf.nn.relu(l3_conv) # Layer 4 Convolutional layer filter4 = self.get_filter(4, 'L4_filter') l4_conv = tf.nn.conv2d(l3_conv, filter4, self.get_strides(4), padding = self.hyper_param['L4']['padding'], name='L4_conv') l4_conv = tf.add(l4_conv, self.get_bias(4, 1.0)) l4_conv = tf.nn.relu(l4_conv) # Layer 5 Convolutional layer filter5 = self.get_filter(5, 'L5_filter') l5_conv = tf.nn.conv2d(l4_conv, filter5, self.get_strides(5), padding = self.hyper_param['L5']['padding'], name='L5_conv') l5_conv = tf.add(l5_conv, self.get_bias(5, 1.0)) l5_conv = tf.nn.relu(l5_conv) # Layer 5 Max Pooling layer l5_MP = tf.layers.max_pooling2d(l5_conv, self.hyper_param['L5_MP']['filter_size'], self.hyper_param['L5_MP']['stride'], name='L5_MP') flatten = tf.layers.flatten(l5_MP) # Layer 6 Fully connected layer l6_FC = tf.contrib.layers.fully_connected(flatten, self.hyper_param['FC6']) # Dropout layer l6_dropout = tf.nn.dropout(l6_FC, self.dropout, name='l6_dropout') # Layer 7 Fully connected layer self.l7_FC = tf.contrib.layers.fully_connected(l6_dropout, self.hyper_param['FC7']) # Dropout layer l7_dropout = tf.nn.dropout(self.l7_FC, self.dropout, name='l7_dropout') # final layer before softmax self.logits = tf.contrib.layers.fully_connected(l7_dropout, self.num_classes, None) # loss function loss_function = tf.nn.softmax_cross_entropy_with_logits( logits = self.logits, labels = self.labels ) # total loss self.loss = tf.reduce_mean(loss_function) + self.l2_loss self.optimizer = tf.train.MomentumOptimizer(self.learning_rate, momentum=self.momentum)\ .minimize(self.loss, global_step=self.global_step) correct = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.labels, 1)) self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) self.top5_correct = tf.nn.in_top_k(self.logits, tf.argmax(self.labels, 1), 5) self.top5_accuracy = tf.reduce_mean(tf.cast(self.top5_correct, tf.float32)) self.add_summaries() def add_summaries(self): """ Add summaries for loss, top1 and top5 accuracies Add loss, top1 and top5 accuracies to summary files in order to visualize in tensorboard """ tf.summary.scalar('loss', self.loss) tf.summary.scalar('Top-1-Acc', self.accuracy) tf.summary.scalar('Top-5-Acc', self.top5_accuracy) self.merged = tf.summary.merge_all() def save_model(self, sess, saver): """ Save the current model :param sess: Session object :param saver: Saver object responsible to store """ model_base_path = os.path.join(os.getcwd(), 'model') if not os.path.exists(model_base_path): os.mkdir(model_base_path) model_save_path = os.path.join(os.getcwd(), 'model', 'model.ckpt') save_path = saver.save(sess, model_save_path) self.logger.info("Model saved in path: %s", save_path) def restore_model(self, sess, saver): """ Restore previously saved model :param sess: Session object :param saver: Saver object responsible to store """ model_base_path = os.path.join(os.getcwd(), 'model') model_restore_path = os.path.join(os.getcwd(), 'model', 'model.ckpt') saver.restore(sess, model_restore_path) self.logger.info("Model Restored from path: %s", model_restore_path) def get_summary_writer(self, sess): """ Get summary writer for training and validation Responsible for creating summary writer so it can write summaries to a file so it can be read by tensorboard later. """ if not os.path.exists(os.path.join('summary', 'train')): os.makedirs(os.path.join('summary', 'train')) if not os.path.exists(os.path.join('summary', 'val')): os.makedirs(os.path.join('summary', 'val')) return (tf.summary.FileWriter(os.path.join(os.getcwd(), 'summary', 'train'), sess.graph), tf.summary.FileWriter(os.path.join(os.getcwd(), 'summary', 'val'), sess.graph)) def train(self, epochs, thread='false'): """ Train AlexNet. """ batch_step, val_step = 10, 500 self.logger.info("Building the graph...") self.build_graph() init = tf.global_variables_initializer() saver = tf.train.Saver() with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess: (summary_writer_train, summary_writer_val) = self.get_summary_writer(sess) if self.resume and os.path.exists(os.path.join(os.getcwd(), 'model')): self.restore_model(sess, saver) else: sess.run(init) resume_batch = True best_loss = float('inf') while sess.run(self.cur_epoch) < epochs: losses = [] accuracies = [] epoch = sess.run(self.cur_epoch) if not self.resume or ( self.resume and not resume_batch): sess.run(self.init_batch_op) resume_batch = False start = time.time() gen_batch = self.lsvrc2010.gen_batch for images, labels in gen_batch: batch_i = sess.run(self.cur_batch) # If it's resumed from stored model, # this will save from messing up the batch number # in subsequent epoch if batch_i >= ceil(len(self.lsvrc2010.image_names) / self.batch_size): break (_, global_step, _) = sess.run([self.optimizer, self.global_step, self.increment_batch_op], feed_dict = { self.input_image: images, self.labels: labels, self.learning_rate: self.lr, self.dropout: 0.5 }) if global_step == 150000: self.lr = 0.0001 # Halve the learning rate if batch_i % batch_step == 0: (summary, loss, acc, top5_acc, _top5, logits, l7_FC) = sess.run([self.merged, self.loss, self.accuracy, self.top5_accuracy, self.top5_correct, self.logits, self.l7_FC], feed_dict = { self.input_image: images, self.labels: labels, self.learning_rate: self.lr, self.dropout: 1.0 }) losses.append(loss) accuracies.append(acc) summary_writer_train.add_summary(summary, global_step) summary_writer_train.flush() end = time.time() try: self.logger.debug("l7 no of non zeros: %d", np.count_nonzero(l7_FC)) true_idx = np.where(_top5[0]==True)[0][0] self.logger.debug("logit at %d: %s", true_idx, str(logits[true_idx])) except IndexError as ie: self.logger.debug(ie) self.logger.info("Time: %f Epoch: %d Batch: %d Loss: %f " "Avg loss: %f Accuracy: %f Avg Accuracy: %f " "Top 5 Accuracy: %f", end - start, epoch, batch_i, loss, sum(losses) / len(losses), acc, sum(accuracies) / len(accuracies), top5_acc) start = time.time() if batch_i % val_step == 0: images_val, labels_val = self.lsvrc2010.get_batch_val (summary, acc, top5_acc, loss) = sess.run([self.merged, self.accuracy, self.top5_accuracy, self.loss], feed_dict = { self.input_image: images_val, self.labels: labels_val, self.learning_rate: self.lr, self.dropout: 1.0 }) summary_writer_val.add_summary(summary, global_step) summary_writer_val.flush() self.logger.info("Validation - Accuracy: %f Top 5 Accuracy: %f Loss: %f", acc, top5_acc, loss) cur_loss = sum(losses) / len(losses) if cur_loss < best_loss: best_loss = cur_loss self.save_model(sess, saver) # Increase epoch number sess.run(self.increment_epoch_op) def test(self): step = 10 self.logger_test = logs.get_logger('AlexNetTest', file_name='logs_test.log') self.logger_test.info("In Test: Building the graph...") self.build_graph() init = tf.global_variables_initializer() saver = tf.train.Saver() top1_count, top5_count, count = 0, 0, 0 with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess: self.restore_model(sess, saver) start = time.time() batch = self.lsvrc2010.gen_batch_test for i, (patches, labels) in enumerate(batch): count += patches[0].shape[0] avg_logits = np.zeros((patches[0].shape[0], self.num_classes)) for patch in patches: logits = sess.run(self.logits, feed_dict = { self.input_image: patch, self.dropout: 1.0 }) avg_logits += logits avg_logits /= len(patches) top1_count += np.sum(np.argmax(avg_logits, 1) == labels) top5_count += np.sum(avg_logits.argsort()[:, -5:] == \ np.repeat(labels, 5).reshape(patches[0].shape[0], 5)) if i % step == 0: end = time.time() self.logger_test.info("Time: %f Step: %d " "Avg Accuracy: %f " "Avg Top 5 Accuracy: %f", end - start, i, top1_count / count, top5_count / count) start = time.time() self.logger_test.info("Final - Avg Accuracy: %f " "Avg Top 5 Accuracy: %f", top1_count / count, top5_count / count) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('image_path', metavar = 'image-path', help = 'ImageNet dataset path') parser.add_argument('--resume', metavar='resume', type=lambda x: x != 'False', default=True, required=False, help='Resume training (True or False)') parser.add_argument('--train', help='Train AlexNet') parser.add_argument('--test', help='Test AlexNet') args = parser.parse_args() alexnet = AlexNet(args.image_path, batch_size=128, resume=args.resume) if args.train == 'true': alexnet.train(50) elif args.test == 'true': alexnet.test()