|
|
@ -0,0 +1,514 @@ |
|
|
|
#!/usr/bin/python3 |
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
|
|
|
|
# library modules |
|
|
|
from math import ceil |
|
|
|
import json |
|
|
|
import time |
|
|
|
import os |
|
|
|
import threading |
|
|
|
|
|
|
|
# External library modules |
|
|
|
import tensorflow as tf |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
# local modules |
|
|
|
from data import LSVRC2010 |
|
|
|
import logs |
|
|
|
|
|
|
|
class AlexNet: |
|
|
|
""" |
|
|
|
A tensorflow implementation of the paper: |
|
|
|
`AlexNet <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_ |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, path, batch_size, resume): |
|
|
|
""" |
|
|
|
Build the AlexNet model |
|
|
|
""" |
|
|
|
self.logger = logs.get_logger() |
|
|
|
|
|
|
|
self.resume = resume |
|
|
|
self.path = path |
|
|
|
self.batch_size = batch_size |
|
|
|
self.lsvrc2010 = LSVRC2010(self.path, batch_size) |
|
|
|
self.num_classes = len(self.lsvrc2010.wnid2label) |
|
|
|
|
|
|
|
self.lr = 0.001 |
|
|
|
self.momentum = 0.9 |
|
|
|
self.lambd = tf.constant(0.0005, name='lambda') |
|
|
|
self.input_shape = (None, 227, 227, 3) |
|
|
|
self.output_shape = (None, self.num_classes) |
|
|
|
|
|
|
|
self.logger.info("Creating placeholders for graph...") |
|
|
|
self.create_tf_placeholders() |
|
|
|
|
|
|
|
self.logger.info("Creating variables for graph...") |
|
|
|
self.create_tf_variables() |
|
|
|
|
|
|
|
self.logger.info("Initialize hyper parameters...") |
|
|
|
self.hyper_param = {} |
|
|
|
self.init_hyper_param() |
|
|
|
|
|
|
|
def create_tf_placeholders(self): |
|
|
|
""" |
|
|
|
Create placeholders for the graph. |
|
|
|
The input for these will be given while training or testing. |
|
|
|
""" |
|
|
|
self.input_image = tf.placeholder(tf.float32, shape=self.input_shape, |
|
|
|
name='input_image') |
|
|
|
self.labels = tf.placeholder(tf.float32, shape=self.output_shape, |
|
|
|
name='output') |
|
|
|
self.learning_rate = tf.placeholder(tf.float32, shape=(), |
|
|
|
name='learning_rate') |
|
|
|
self.dropout = tf.placeholder(tf.float32, shape=(), |
|
|
|
name='dropout') |
|
|
|
|
|
|
|
def create_tf_variables(self): |
|
|
|
""" |
|
|
|
Create variables for epoch, batch and global step |
|
|
|
""" |
|
|
|
self.global_step = tf.Variable(0, name='global_step', trainable=False) |
|
|
|
self.cur_epoch = tf.Variable(0, name='epoch', trainable=False) |
|
|
|
self.cur_batch = tf.Variable(0, name='batch', trainable=False) |
|
|
|
|
|
|
|
self.increment_epoch_op = tf.assign(self.cur_epoch, self.cur_epoch+1) |
|
|
|
self.increment_batch_op = tf.assign(self.cur_batch, self.cur_batch+1) |
|
|
|
self.init_batch_op = tf.assign(self.cur_batch, 0) |
|
|
|
|
|
|
|
def init_hyper_param(self): |
|
|
|
""" |
|
|
|
Store the hyper parameters. |
|
|
|
For each layer store number of filters(kernels) |
|
|
|
and filter size. |
|
|
|
If it's a fully connected layer then store the number of neurons. |
|
|
|
""" |
|
|
|
with open('hparam.json') as f: |
|
|
|
self.hyper_param = json.load(f) |
|
|
|
|
|
|
|
def get_filter(self, layer_num, layer_name): |
|
|
|
""" |
|
|
|
:param layer_num: Indicates the layer number in the graph |
|
|
|
:type layer_num: int |
|
|
|
:param layer_name: Name of the filter |
|
|
|
""" |
|
|
|
layer = 'L' + str(layer_num) |
|
|
|
|
|
|
|
filter_height, filter_width, in_channels = self.hyper_param[layer]['filter_size'] |
|
|
|
out_channels = self.hyper_param[layer]['filters'] |
|
|
|
|
|
|
|
return tf.Variable(tf.truncated_normal( |
|
|
|
[filter_height, filter_width, in_channels, out_channels], |
|
|
|
dtype = tf.float32, stddev = 1e-2), name = layer_name) |
|
|
|
|
|
|
|
def get_strides(self, layer_num): |
|
|
|
""" |
|
|
|
:param layer_num: Indicates the layer number in the graph |
|
|
|
:type layer_num: int |
|
|
|
""" |
|
|
|
layer = 'L' + str(layer_num) |
|
|
|
|
|
|
|
stride = self.hyper_param[layer]['stride'] |
|
|
|
strides = [1, stride, stride, 1] |
|
|
|
|
|
|
|
return strides |
|
|
|
|
|
|
|
def get_bias(self, layer_num, value=0.0): |
|
|
|
""" |
|
|
|
Get the bias variable for current layer |
|
|
|
|
|
|
|
:param layer_num: Indicates the layer number in the graph |
|
|
|
:type layer_num: int |
|
|
|
""" |
|
|
|
layer = 'L' + str(layer_num) |
|
|
|
initial = tf.constant(value, |
|
|
|
shape=[self.hyper_param[layer]['filters']], |
|
|
|
name='C' + str(layer_num)) |
|
|
|
return tf.Variable(initial, name='B' + str(layer_num)) |
|
|
|
|
|
|
|
@property |
|
|
|
def l2_loss(self): |
|
|
|
""" |
|
|
|
Compute the l2 loss for all the weights |
|
|
|
""" |
|
|
|
conv_bias_names = ['B' + str(i) for i in range(1, 6)] |
|
|
|
weights = [] |
|
|
|
for v in tf.trainable_variables(): |
|
|
|
if 'biases' in v.name: continue |
|
|
|
if v.name.split(':')[0] in conv_bias_names: continue |
|
|
|
weights.append(v) |
|
|
|
|
|
|
|
return self.lambd * tf.add_n([tf.nn.l2_loss(weight) for weight in weights]) |
|
|
|
|
|
|
|
def build_graph(self): |
|
|
|
""" |
|
|
|
Build the tensorflow graph for AlexNet. |
|
|
|
|
|
|
|
First 5 layers are Convolutional layers. Out of which |
|
|
|
first 2 and last layer will be followed by *max pooling* |
|
|
|
layers. |
|
|
|
|
|
|
|
Next 2 layers are fully connected layers. |
|
|
|
|
|
|
|
L1_conv -> L1_MP -> L2_conv -> L2_MP -> L3_conv |
|
|
|
-> L4_conv -> L5_conv -> L5_MP -> L6_FC -> L7_FC |
|
|
|
|
|
|
|
Where L1_conv -> Convolutional layer 1 |
|
|
|
L5_MP -> Max pooling layer 5 |
|
|
|
L7_FC -> Fully Connected layer 7 |
|
|
|
|
|
|
|
Use `tf.nn.conv2d` to initialize the filters so |
|
|
|
as to reduce training time and `tf.layers.max_pooling2d` |
|
|
|
as we don't need to initialize in the pooling layer. |
|
|
|
""" |
|
|
|
# Layer 1 Convolutional layer |
|
|
|
filter1 = self.get_filter(1, 'L1_filter') |
|
|
|
l1_conv = tf.nn.conv2d(self.input_image, filter1, |
|
|
|
self.get_strides(1), |
|
|
|
padding = self.hyper_param['L1']['padding'], |
|
|
|
name='L1_conv') |
|
|
|
l1_conv = tf.add(l1_conv, self.get_bias(1)) |
|
|
|
l1_conv = tf.nn.local_response_normalization(l1_conv, |
|
|
|
depth_radius=5, |
|
|
|
bias=2, |
|
|
|
alpha=1e-4, |
|
|
|
beta=.75) |
|
|
|
l1_conv = tf.nn.relu(l1_conv) |
|
|
|
|
|
|
|
# Layer 1 Max Pooling layer |
|
|
|
l1_MP = tf.layers.max_pooling2d(l1_conv, |
|
|
|
self.hyper_param['L1_MP']['filter_size'], |
|
|
|
self.hyper_param['L1_MP']['stride'], |
|
|
|
name='L1_MP') |
|
|
|
|
|
|
|
# Layer 2 Convolutional layer |
|
|
|
filter2 = self.get_filter(2, 'L2_filter') |
|
|
|
l2_conv = tf.nn.conv2d(l1_MP, filter2, |
|
|
|
self.get_strides(2), |
|
|
|
padding = self.hyper_param['L2']['padding'], |
|
|
|
name='L2_conv') |
|
|
|
l2_conv = tf.add(l2_conv, self.get_bias(2, 1.0)) |
|
|
|
l2_conv = tf.nn.local_response_normalization(l2_conv, |
|
|
|
depth_radius=5, |
|
|
|
bias=2, |
|
|
|
alpha=1e-4, |
|
|
|
beta=.75) |
|
|
|
l2_conv = tf.nn.relu(l2_conv) |
|
|
|
|
|
|
|
# Layer 2 Max Pooling layer |
|
|
|
l2_MP = tf.layers.max_pooling2d(l2_conv, |
|
|
|
self.hyper_param['L2_MP']['filter_size'], |
|
|
|
self.hyper_param['L2_MP']['stride'], |
|
|
|
name='L2_MP') |
|
|
|
|
|
|
|
# Layer 3 Convolutional layer |
|
|
|
filter3 = self.get_filter(3, 'L3_filter') |
|
|
|
l3_conv = tf.nn.conv2d(l2_MP, filter3, |
|
|
|
self.get_strides(3), |
|
|
|
padding = self.hyper_param['L3']['padding'], |
|
|
|
name='L3_conv') |
|
|
|
l3_conv = tf.add(l3_conv, self.get_bias(3)) |
|
|
|
l3_conv = tf.nn.relu(l3_conv) |
|
|
|
|
|
|
|
# Layer 4 Convolutional layer |
|
|
|
filter4 = self.get_filter(4, 'L4_filter') |
|
|
|
l4_conv = tf.nn.conv2d(l3_conv, filter4, |
|
|
|
self.get_strides(4), |
|
|
|
padding = self.hyper_param['L4']['padding'], |
|
|
|
name='L4_conv') |
|
|
|
l4_conv = tf.add(l4_conv, self.get_bias(4, 1.0)) |
|
|
|
l4_conv = tf.nn.relu(l4_conv) |
|
|
|
|
|
|
|
# Layer 5 Convolutional layer |
|
|
|
filter5 = self.get_filter(5, 'L5_filter') |
|
|
|
l5_conv = tf.nn.conv2d(l4_conv, filter5, |
|
|
|
self.get_strides(5), |
|
|
|
padding = self.hyper_param['L5']['padding'], |
|
|
|
name='L5_conv') |
|
|
|
l5_conv = tf.add(l5_conv, self.get_bias(5, 1.0)) |
|
|
|
l5_conv = tf.nn.relu(l5_conv) |
|
|
|
|
|
|
|
# Layer 5 Max Pooling layer |
|
|
|
l5_MP = tf.layers.max_pooling2d(l5_conv, |
|
|
|
self.hyper_param['L5_MP']['filter_size'], |
|
|
|
self.hyper_param['L5_MP']['stride'], |
|
|
|
name='L5_MP') |
|
|
|
|
|
|
|
flatten = tf.layers.flatten(l5_MP) |
|
|
|
|
|
|
|
# Layer 6 Fully connected layer |
|
|
|
l6_FC = tf.contrib.layers.fully_connected(flatten, |
|
|
|
self.hyper_param['FC6']) |
|
|
|
|
|
|
|
# Dropout layer |
|
|
|
l6_dropout = tf.nn.dropout(l6_FC, self.dropout, |
|
|
|
name='l6_dropout') |
|
|
|
|
|
|
|
# Layer 7 Fully connected layer |
|
|
|
self.l7_FC = tf.contrib.layers.fully_connected(l6_dropout, |
|
|
|
self.hyper_param['FC7']) |
|
|
|
|
|
|
|
# Dropout layer |
|
|
|
l7_dropout = tf.nn.dropout(self.l7_FC, self.dropout, |
|
|
|
name='l7_dropout') |
|
|
|
|
|
|
|
# final layer before softmax |
|
|
|
self.logits = tf.contrib.layers.fully_connected(l7_dropout, |
|
|
|
self.num_classes, None) |
|
|
|
|
|
|
|
# loss function |
|
|
|
loss_function = tf.nn.softmax_cross_entropy_with_logits( |
|
|
|
logits = self.logits, |
|
|
|
labels = self.labels |
|
|
|
) |
|
|
|
|
|
|
|
# total loss |
|
|
|
self.loss = tf.reduce_mean(loss_function) + self.l2_loss |
|
|
|
|
|
|
|
self.optimizer = tf.train.MomentumOptimizer(self.learning_rate, momentum=self.momentum)\ |
|
|
|
.minimize(self.loss, global_step=self.global_step) |
|
|
|
|
|
|
|
correct = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.labels, 1)) |
|
|
|
self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) |
|
|
|
|
|
|
|
self.top5_correct = tf.nn.in_top_k(self.logits, tf.argmax(self.labels, 1), 5) |
|
|
|
self.top5_accuracy = tf.reduce_mean(tf.cast(self.top5_correct, tf.float32)) |
|
|
|
|
|
|
|
self.add_summaries() |
|
|
|
|
|
|
|
def add_summaries(self): |
|
|
|
""" |
|
|
|
Add summaries for loss, top1 and top5 accuracies |
|
|
|
|
|
|
|
Add loss, top1 and top5 accuracies to summary files |
|
|
|
in order to visualize in tensorboard |
|
|
|
""" |
|
|
|
tf.summary.scalar('loss', self.loss) |
|
|
|
tf.summary.scalar('Top-1-Acc', self.accuracy) |
|
|
|
tf.summary.scalar('Top-5-Acc', self.top5_accuracy) |
|
|
|
|
|
|
|
self.merged = tf.summary.merge_all() |
|
|
|
|
|
|
|
def save_model(self, sess, saver): |
|
|
|
""" |
|
|
|
Save the current model |
|
|
|
|
|
|
|
:param sess: Session object |
|
|
|
:param saver: Saver object responsible to store |
|
|
|
""" |
|
|
|
model_base_path = os.path.join(os.getcwd(), 'model') |
|
|
|
if not os.path.exists(model_base_path): |
|
|
|
os.mkdir(model_base_path) |
|
|
|
model_save_path = os.path.join(os.getcwd(), 'model', 'model.ckpt') |
|
|
|
save_path = saver.save(sess, model_save_path) |
|
|
|
self.logger.info("Model saved in path: %s", save_path) |
|
|
|
|
|
|
|
def restore_model(self, sess, saver): |
|
|
|
""" |
|
|
|
Restore previously saved model |
|
|
|
|
|
|
|
:param sess: Session object |
|
|
|
:param saver: Saver object responsible to store |
|
|
|
""" |
|
|
|
model_base_path = os.path.join(os.getcwd(), 'model') |
|
|
|
model_restore_path = os.path.join(os.getcwd(), 'model', 'model.ckpt') |
|
|
|
saver.restore(sess, model_restore_path) |
|
|
|
self.logger.info("Model Restored from path: %s", |
|
|
|
model_restore_path) |
|
|
|
|
|
|
|
def get_summary_writer(self, sess): |
|
|
|
""" |
|
|
|
Get summary writer for training and validation |
|
|
|
|
|
|
|
Responsible for creating summary writer so it can |
|
|
|
write summaries to a file so it can be read by |
|
|
|
tensorboard later. |
|
|
|
""" |
|
|
|
if not os.path.exists(os.path.join('summary', 'train')): |
|
|
|
os.makedirs(os.path.join('summary', 'train')) |
|
|
|
if not os.path.exists(os.path.join('summary', 'val')): |
|
|
|
os.makedirs(os.path.join('summary', 'val')) |
|
|
|
return (tf.summary.FileWriter(os.path.join(os.getcwd(), |
|
|
|
'summary', 'train'), |
|
|
|
sess.graph), |
|
|
|
tf.summary.FileWriter(os.path.join(os.getcwd(), |
|
|
|
'summary', 'val'), |
|
|
|
sess.graph)) |
|
|
|
|
|
|
|
def train(self, epochs, thread='false'): |
|
|
|
""" |
|
|
|
Train AlexNet. |
|
|
|
""" |
|
|
|
batch_step, val_step = 10, 500 |
|
|
|
|
|
|
|
self.logger.info("Building the graph...") |
|
|
|
self.build_graph() |
|
|
|
|
|
|
|
init = tf.global_variables_initializer() |
|
|
|
|
|
|
|
saver = tf.train.Saver() |
|
|
|
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess: |
|
|
|
(summary_writer_train, |
|
|
|
summary_writer_val) = self.get_summary_writer(sess) |
|
|
|
if self.resume and os.path.exists(os.path.join(os.getcwd(), |
|
|
|
'model')): |
|
|
|
self.restore_model(sess, saver) |
|
|
|
else: |
|
|
|
sess.run(init) |
|
|
|
|
|
|
|
resume_batch = True |
|
|
|
best_loss = float('inf') |
|
|
|
while sess.run(self.cur_epoch) < epochs: |
|
|
|
losses = [] |
|
|
|
accuracies = [] |
|
|
|
|
|
|
|
epoch = sess.run(self.cur_epoch) |
|
|
|
if not self.resume or ( |
|
|
|
self.resume and not resume_batch): |
|
|
|
sess.run(self.init_batch_op) |
|
|
|
resume_batch = False |
|
|
|
start = time.time() |
|
|
|
gen_batch = self.lsvrc2010.gen_batch |
|
|
|
for images, labels in gen_batch: |
|
|
|
batch_i = sess.run(self.cur_batch) |
|
|
|
# If it's resumed from stored model, |
|
|
|
# this will save from messing up the batch number |
|
|
|
# in subsequent epoch |
|
|
|
if batch_i >= ceil(len(self.lsvrc2010.image_names) / self.batch_size): |
|
|
|
break |
|
|
|
(_, global_step, |
|
|
|
_) = sess.run([self.optimizer, |
|
|
|
self.global_step, self.increment_batch_op], |
|
|
|
feed_dict = { |
|
|
|
self.input_image: images, |
|
|
|
self.labels: labels, |
|
|
|
self.learning_rate: self.lr, |
|
|
|
self.dropout: 0.5 |
|
|
|
}) |
|
|
|
|
|
|
|
if global_step == 150000: |
|
|
|
self.lr = 0.0001 # Halve the learning rate |
|
|
|
|
|
|
|
if batch_i % batch_step == 0: |
|
|
|
(summary, loss, acc, top5_acc, _top5, |
|
|
|
logits, l7_FC) = sess.run([self.merged, self.loss, |
|
|
|
self.accuracy, self.top5_accuracy, |
|
|
|
self.top5_correct, |
|
|
|
self.logits, self.l7_FC], |
|
|
|
feed_dict = { |
|
|
|
self.input_image: images, |
|
|
|
self.labels: labels, |
|
|
|
self.learning_rate: self.lr, |
|
|
|
self.dropout: 1.0 |
|
|
|
}) |
|
|
|
losses.append(loss) |
|
|
|
accuracies.append(acc) |
|
|
|
summary_writer_train.add_summary(summary, global_step) |
|
|
|
summary_writer_train.flush() |
|
|
|
end = time.time() |
|
|
|
try: |
|
|
|
self.logger.debug("l7 no of non zeros: %d", np.count_nonzero(l7_FC)) |
|
|
|
true_idx = np.where(_top5[0]==True)[0][0] |
|
|
|
self.logger.debug("logit at %d: %s", true_idx, |
|
|
|
str(logits[true_idx])) |
|
|
|
except IndexError as ie: |
|
|
|
self.logger.debug(ie) |
|
|
|
self.logger.info("Time: %f Epoch: %d Batch: %d Loss: %f " |
|
|
|
"Avg loss: %f Accuracy: %f Avg Accuracy: %f " |
|
|
|
"Top 5 Accuracy: %f", |
|
|
|
end - start, epoch, batch_i, |
|
|
|
loss, sum(losses) / len(losses), |
|
|
|
acc, sum(accuracies) / len(accuracies), |
|
|
|
top5_acc) |
|
|
|
start = time.time() |
|
|
|
|
|
|
|
if batch_i % val_step == 0: |
|
|
|
images_val, labels_val = self.lsvrc2010.get_batch_val |
|
|
|
(summary, acc, top5_acc, |
|
|
|
loss) = sess.run([self.merged, |
|
|
|
self.accuracy, |
|
|
|
self.top5_accuracy, self.loss], |
|
|
|
feed_dict = { |
|
|
|
self.input_image: images_val, |
|
|
|
self.labels: labels_val, |
|
|
|
self.learning_rate: self.lr, |
|
|
|
self.dropout: 1.0 |
|
|
|
}) |
|
|
|
summary_writer_val.add_summary(summary, global_step) |
|
|
|
summary_writer_val.flush() |
|
|
|
self.logger.info("Validation - Accuracy: %f Top 5 Accuracy: %f Loss: %f", |
|
|
|
acc, top5_acc, loss) |
|
|
|
|
|
|
|
cur_loss = sum(losses) / len(losses) |
|
|
|
if cur_loss < best_loss: |
|
|
|
best_loss = cur_loss |
|
|
|
self.save_model(sess, saver) |
|
|
|
|
|
|
|
# Increase epoch number |
|
|
|
sess.run(self.increment_epoch_op) |
|
|
|
|
|
|
|
def test(self): |
|
|
|
step = 10 |
|
|
|
|
|
|
|
self.logger_test = logs.get_logger('AlexNetTest', file_name='logs_test.log') |
|
|
|
self.logger_test.info("In Test: Building the graph...") |
|
|
|
self.build_graph() |
|
|
|
|
|
|
|
init = tf.global_variables_initializer() |
|
|
|
|
|
|
|
saver = tf.train.Saver() |
|
|
|
top1_count, top5_count, count = 0, 0, 0 |
|
|
|
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess: |
|
|
|
self.restore_model(sess, saver) |
|
|
|
|
|
|
|
start = time.time() |
|
|
|
batch = self.lsvrc2010.gen_batch_test |
|
|
|
for i, (patches, labels) in enumerate(batch): |
|
|
|
count += patches[0].shape[0] |
|
|
|
avg_logits = np.zeros((patches[0].shape[0], self.num_classes)) |
|
|
|
for patch in patches: |
|
|
|
logits = sess.run(self.logits, |
|
|
|
feed_dict = { |
|
|
|
self.input_image: patch, |
|
|
|
self.dropout: 1.0 |
|
|
|
}) |
|
|
|
avg_logits += logits |
|
|
|
avg_logits /= len(patches) |
|
|
|
top1_count += np.sum(np.argmax(avg_logits, 1) == labels) |
|
|
|
top5_count += np.sum(avg_logits.argsort()[:, -5:] == \ |
|
|
|
np.repeat(labels, 5).reshape(patches[0].shape[0], 5)) |
|
|
|
|
|
|
|
if i % step == 0: |
|
|
|
end = time.time() |
|
|
|
self.logger_test.info("Time: %f Step: %d " |
|
|
|
"Avg Accuracy: %f " |
|
|
|
"Avg Top 5 Accuracy: %f", |
|
|
|
end - start, i, |
|
|
|
top1_count / count, |
|
|
|
top5_count / count) |
|
|
|
start = time.time() |
|
|
|
|
|
|
|
self.logger_test.info("Final - Avg Accuracy: %f " |
|
|
|
"Avg Top 5 Accuracy: %f", |
|
|
|
top1_count / count, |
|
|
|
top5_count / count) |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
import argparse |
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument('image_path', metavar = 'image-path', |
|
|
|
help = 'ImageNet dataset path') |
|
|
|
parser.add_argument('--resume', metavar='resume', |
|
|
|
type=lambda x: x != 'False', default=True, |
|
|
|
required=False, |
|
|
|
help='Resume training (True or False)') |
|
|
|
parser.add_argument('--train', help='Train AlexNet') |
|
|
|
parser.add_argument('--test', help='Test AlexNet') |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
alexnet = AlexNet(args.image_path, batch_size=128, resume=args.resume) |
|
|
|
|
|
|
|
if args.train == 'true': |
|
|
|
alexnet.train(50) |
|
|
|
elif args.test == 'true': |
|
|
|
alexnet.test() |