云计算课程实验
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

514 lines
21 KiB

3 years ago
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. # library modules
  4. from math import ceil
  5. import json
  6. import time
  7. import os
  8. import threading
  9. # External library modules
  10. import tensorflow as tf
  11. import numpy as np
  12. # local modules
  13. from data import LSVRC2010
  14. import logs
  15. class AlexNet:
  16. """
  17. A tensorflow implementation of the paper:
  18. `AlexNet <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_
  19. """
  20. def __init__(self, path, batch_size, resume):
  21. """
  22. Build the AlexNet model
  23. """
  24. self.logger = logs.get_logger()
  25. self.resume = resume
  26. self.path = path
  27. self.batch_size = batch_size
  28. self.lsvrc2010 = LSVRC2010(self.path, batch_size)
  29. self.num_classes = len(self.lsvrc2010.wnid2label)
  30. self.lr = 0.001
  31. self.momentum = 0.9
  32. self.lambd = tf.constant(0.0005, name='lambda')
  33. self.input_shape = (None, 227, 227, 3)
  34. self.output_shape = (None, self.num_classes)
  35. self.logger.info("Creating placeholders for graph...")
  36. self.create_tf_placeholders()
  37. self.logger.info("Creating variables for graph...")
  38. self.create_tf_variables()
  39. self.logger.info("Initialize hyper parameters...")
  40. self.hyper_param = {}
  41. self.init_hyper_param()
  42. def create_tf_placeholders(self):
  43. """
  44. Create placeholders for the graph.
  45. The input for these will be given while training or testing.
  46. """
  47. self.input_image = tf.placeholder(tf.float32, shape=self.input_shape,
  48. name='input_image')
  49. self.labels = tf.placeholder(tf.float32, shape=self.output_shape,
  50. name='output')
  51. self.learning_rate = tf.placeholder(tf.float32, shape=(),
  52. name='learning_rate')
  53. self.dropout = tf.placeholder(tf.float32, shape=(),
  54. name='dropout')
  55. def create_tf_variables(self):
  56. """
  57. Create variables for epoch, batch and global step
  58. """
  59. self.global_step = tf.Variable(0, name='global_step', trainable=False)
  60. self.cur_epoch = tf.Variable(0, name='epoch', trainable=False)
  61. self.cur_batch = tf.Variable(0, name='batch', trainable=False)
  62. self.increment_epoch_op = tf.assign(self.cur_epoch, self.cur_epoch+1)
  63. self.increment_batch_op = tf.assign(self.cur_batch, self.cur_batch+1)
  64. self.init_batch_op = tf.assign(self.cur_batch, 0)
  65. def init_hyper_param(self):
  66. """
  67. Store the hyper parameters.
  68. For each layer store number of filters(kernels)
  69. and filter size.
  70. If it's a fully connected layer then store the number of neurons.
  71. """
  72. with open('hparam.json') as f:
  73. self.hyper_param = json.load(f)
  74. def get_filter(self, layer_num, layer_name):
  75. """
  76. :param layer_num: Indicates the layer number in the graph
  77. :type layer_num: int
  78. :param layer_name: Name of the filter
  79. """
  80. layer = 'L' + str(layer_num)
  81. filter_height, filter_width, in_channels = self.hyper_param[layer]['filter_size']
  82. out_channels = self.hyper_param[layer]['filters']
  83. return tf.Variable(tf.truncated_normal(
  84. [filter_height, filter_width, in_channels, out_channels],
  85. dtype = tf.float32, stddev = 1e-2), name = layer_name)
  86. def get_strides(self, layer_num):
  87. """
  88. :param layer_num: Indicates the layer number in the graph
  89. :type layer_num: int
  90. """
  91. layer = 'L' + str(layer_num)
  92. stride = self.hyper_param[layer]['stride']
  93. strides = [1, stride, stride, 1]
  94. return strides
  95. def get_bias(self, layer_num, value=0.0):
  96. """
  97. Get the bias variable for current layer
  98. :param layer_num: Indicates the layer number in the graph
  99. :type layer_num: int
  100. """
  101. layer = 'L' + str(layer_num)
  102. initial = tf.constant(value,
  103. shape=[self.hyper_param[layer]['filters']],
  104. name='C' + str(layer_num))
  105. return tf.Variable(initial, name='B' + str(layer_num))
  106. @property
  107. def l2_loss(self):
  108. """
  109. Compute the l2 loss for all the weights
  110. """
  111. conv_bias_names = ['B' + str(i) for i in range(1, 6)]
  112. weights = []
  113. for v in tf.trainable_variables():
  114. if 'biases' in v.name: continue
  115. if v.name.split(':')[0] in conv_bias_names: continue
  116. weights.append(v)
  117. return self.lambd * tf.add_n([tf.nn.l2_loss(weight) for weight in weights])
  118. def build_graph(self):
  119. """
  120. Build the tensorflow graph for AlexNet.
  121. First 5 layers are Convolutional layers. Out of which
  122. first 2 and last layer will be followed by *max pooling*
  123. layers.
  124. Next 2 layers are fully connected layers.
  125. L1_conv -> L1_MP -> L2_conv -> L2_MP -> L3_conv
  126. -> L4_conv -> L5_conv -> L5_MP -> L6_FC -> L7_FC
  127. Where L1_conv -> Convolutional layer 1
  128. L5_MP -> Max pooling layer 5
  129. L7_FC -> Fully Connected layer 7
  130. Use `tf.nn.conv2d` to initialize the filters so
  131. as to reduce training time and `tf.layers.max_pooling2d`
  132. as we don't need to initialize in the pooling layer.
  133. """
  134. # Layer 1 Convolutional layer
  135. filter1 = self.get_filter(1, 'L1_filter')
  136. l1_conv = tf.nn.conv2d(self.input_image, filter1,
  137. self.get_strides(1),
  138. padding = self.hyper_param['L1']['padding'],
  139. name='L1_conv')
  140. l1_conv = tf.add(l1_conv, self.get_bias(1))
  141. l1_conv = tf.nn.local_response_normalization(l1_conv,
  142. depth_radius=5,
  143. bias=2,
  144. alpha=1e-4,
  145. beta=.75)
  146. l1_conv = tf.nn.relu(l1_conv)
  147. # Layer 1 Max Pooling layer
  148. l1_MP = tf.layers.max_pooling2d(l1_conv,
  149. self.hyper_param['L1_MP']['filter_size'],
  150. self.hyper_param['L1_MP']['stride'],
  151. name='L1_MP')
  152. # Layer 2 Convolutional layer
  153. filter2 = self.get_filter(2, 'L2_filter')
  154. l2_conv = tf.nn.conv2d(l1_MP, filter2,
  155. self.get_strides(2),
  156. padding = self.hyper_param['L2']['padding'],
  157. name='L2_conv')
  158. l2_conv = tf.add(l2_conv, self.get_bias(2, 1.0))
  159. l2_conv = tf.nn.local_response_normalization(l2_conv,
  160. depth_radius=5,
  161. bias=2,
  162. alpha=1e-4,
  163. beta=.75)
  164. l2_conv = tf.nn.relu(l2_conv)
  165. # Layer 2 Max Pooling layer
  166. l2_MP = tf.layers.max_pooling2d(l2_conv,
  167. self.hyper_param['L2_MP']['filter_size'],
  168. self.hyper_param['L2_MP']['stride'],
  169. name='L2_MP')
  170. # Layer 3 Convolutional layer
  171. filter3 = self.get_filter(3, 'L3_filter')
  172. l3_conv = tf.nn.conv2d(l2_MP, filter3,
  173. self.get_strides(3),
  174. padding = self.hyper_param['L3']['padding'],
  175. name='L3_conv')
  176. l3_conv = tf.add(l3_conv, self.get_bias(3))
  177. l3_conv = tf.nn.relu(l3_conv)
  178. # Layer 4 Convolutional layer
  179. filter4 = self.get_filter(4, 'L4_filter')
  180. l4_conv = tf.nn.conv2d(l3_conv, filter4,
  181. self.get_strides(4),
  182. padding = self.hyper_param['L4']['padding'],
  183. name='L4_conv')
  184. l4_conv = tf.add(l4_conv, self.get_bias(4, 1.0))
  185. l4_conv = tf.nn.relu(l4_conv)
  186. # Layer 5 Convolutional layer
  187. filter5 = self.get_filter(5, 'L5_filter')
  188. l5_conv = tf.nn.conv2d(l4_conv, filter5,
  189. self.get_strides(5),
  190. padding = self.hyper_param['L5']['padding'],
  191. name='L5_conv')
  192. l5_conv = tf.add(l5_conv, self.get_bias(5, 1.0))
  193. l5_conv = tf.nn.relu(l5_conv)
  194. # Layer 5 Max Pooling layer
  195. l5_MP = tf.layers.max_pooling2d(l5_conv,
  196. self.hyper_param['L5_MP']['filter_size'],
  197. self.hyper_param['L5_MP']['stride'],
  198. name='L5_MP')
  199. flatten = tf.layers.flatten(l5_MP)
  200. # Layer 6 Fully connected layer
  201. l6_FC = tf.contrib.layers.fully_connected(flatten,
  202. self.hyper_param['FC6'])
  203. # Dropout layer
  204. l6_dropout = tf.nn.dropout(l6_FC, self.dropout,
  205. name='l6_dropout')
  206. # Layer 7 Fully connected layer
  207. self.l7_FC = tf.contrib.layers.fully_connected(l6_dropout,
  208. self.hyper_param['FC7'])
  209. # Dropout layer
  210. l7_dropout = tf.nn.dropout(self.l7_FC, self.dropout,
  211. name='l7_dropout')
  212. # final layer before softmax
  213. self.logits = tf.contrib.layers.fully_connected(l7_dropout,
  214. self.num_classes, None)
  215. # loss function
  216. loss_function = tf.nn.softmax_cross_entropy_with_logits(
  217. logits = self.logits,
  218. labels = self.labels
  219. )
  220. # total loss
  221. self.loss = tf.reduce_mean(loss_function) + self.l2_loss
  222. self.optimizer = tf.train.MomentumOptimizer(self.learning_rate, momentum=self.momentum)\
  223. .minimize(self.loss, global_step=self.global_step)
  224. correct = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.labels, 1))
  225. self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
  226. self.top5_correct = tf.nn.in_top_k(self.logits, tf.argmax(self.labels, 1), 5)
  227. self.top5_accuracy = tf.reduce_mean(tf.cast(self.top5_correct, tf.float32))
  228. self.add_summaries()
  229. def add_summaries(self):
  230. """
  231. Add summaries for loss, top1 and top5 accuracies
  232. Add loss, top1 and top5 accuracies to summary files
  233. in order to visualize in tensorboard
  234. """
  235. tf.summary.scalar('loss', self.loss)
  236. tf.summary.scalar('Top-1-Acc', self.accuracy)
  237. tf.summary.scalar('Top-5-Acc', self.top5_accuracy)
  238. self.merged = tf.summary.merge_all()
  239. def save_model(self, sess, saver):
  240. """
  241. Save the current model
  242. :param sess: Session object
  243. :param saver: Saver object responsible to store
  244. """
  245. model_base_path = os.path.join(os.getcwd(), 'model')
  246. if not os.path.exists(model_base_path):
  247. os.mkdir(model_base_path)
  248. model_save_path = os.path.join(os.getcwd(), 'model', 'model.ckpt')
  249. save_path = saver.save(sess, model_save_path)
  250. self.logger.info("Model saved in path: %s", save_path)
  251. def restore_model(self, sess, saver):
  252. """
  253. Restore previously saved model
  254. :param sess: Session object
  255. :param saver: Saver object responsible to store
  256. """
  257. model_base_path = os.path.join(os.getcwd(), 'model')
  258. model_restore_path = os.path.join(os.getcwd(), 'model', 'model.ckpt')
  259. saver.restore(sess, model_restore_path)
  260. self.logger.info("Model Restored from path: %s",
  261. model_restore_path)
  262. def get_summary_writer(self, sess):
  263. """
  264. Get summary writer for training and validation
  265. Responsible for creating summary writer so it can
  266. write summaries to a file so it can be read by
  267. tensorboard later.
  268. """
  269. if not os.path.exists(os.path.join('summary', 'train')):
  270. os.makedirs(os.path.join('summary', 'train'))
  271. if not os.path.exists(os.path.join('summary', 'val')):
  272. os.makedirs(os.path.join('summary', 'val'))
  273. return (tf.summary.FileWriter(os.path.join(os.getcwd(),
  274. 'summary', 'train'),
  275. sess.graph),
  276. tf.summary.FileWriter(os.path.join(os.getcwd(),
  277. 'summary', 'val'),
  278. sess.graph))
  279. def train(self, epochs, thread='false'):
  280. """
  281. Train AlexNet.
  282. """
  283. batch_step, val_step = 10, 500
  284. self.logger.info("Building the graph...")
  285. self.build_graph()
  286. init = tf.global_variables_initializer()
  287. saver = tf.train.Saver()
  288. with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
  289. (summary_writer_train,
  290. summary_writer_val) = self.get_summary_writer(sess)
  291. if self.resume and os.path.exists(os.path.join(os.getcwd(),
  292. 'model')):
  293. self.restore_model(sess, saver)
  294. else:
  295. sess.run(init)
  296. resume_batch = True
  297. best_loss = float('inf')
  298. while sess.run(self.cur_epoch) < epochs:
  299. losses = []
  300. accuracies = []
  301. epoch = sess.run(self.cur_epoch)
  302. if not self.resume or (
  303. self.resume and not resume_batch):
  304. sess.run(self.init_batch_op)
  305. resume_batch = False
  306. start = time.time()
  307. gen_batch = self.lsvrc2010.gen_batch
  308. for images, labels in gen_batch:
  309. batch_i = sess.run(self.cur_batch)
  310. # If it's resumed from stored model,
  311. # this will save from messing up the batch number
  312. # in subsequent epoch
  313. if batch_i >= ceil(len(self.lsvrc2010.image_names) / self.batch_size):
  314. break
  315. (_, global_step,
  316. _) = sess.run([self.optimizer,
  317. self.global_step, self.increment_batch_op],
  318. feed_dict = {
  319. self.input_image: images,
  320. self.labels: labels,
  321. self.learning_rate: self.lr,
  322. self.dropout: 0.5
  323. })
  324. if global_step == 150000:
  325. self.lr = 0.0001 # Halve the learning rate
  326. if batch_i % batch_step == 0:
  327. (summary, loss, acc, top5_acc, _top5,
  328. logits, l7_FC) = sess.run([self.merged, self.loss,
  329. self.accuracy, self.top5_accuracy,
  330. self.top5_correct,
  331. self.logits, self.l7_FC],
  332. feed_dict = {
  333. self.input_image: images,
  334. self.labels: labels,
  335. self.learning_rate: self.lr,
  336. self.dropout: 1.0
  337. })
  338. losses.append(loss)
  339. accuracies.append(acc)
  340. summary_writer_train.add_summary(summary, global_step)
  341. summary_writer_train.flush()
  342. end = time.time()
  343. try:
  344. self.logger.debug("l7 no of non zeros: %d", np.count_nonzero(l7_FC))
  345. true_idx = np.where(_top5[0]==True)[0][0]
  346. self.logger.debug("logit at %d: %s", true_idx,
  347. str(logits[true_idx]))
  348. except IndexError as ie:
  349. self.logger.debug(ie)
  350. self.logger.info("Time: %f Epoch: %d Batch: %d Loss: %f "
  351. "Avg loss: %f Accuracy: %f Avg Accuracy: %f "
  352. "Top 5 Accuracy: %f",
  353. end - start, epoch, batch_i,
  354. loss, sum(losses) / len(losses),
  355. acc, sum(accuracies) / len(accuracies),
  356. top5_acc)
  357. start = time.time()
  358. if batch_i % val_step == 0:
  359. images_val, labels_val = self.lsvrc2010.get_batch_val
  360. (summary, acc, top5_acc,
  361. loss) = sess.run([self.merged,
  362. self.accuracy,
  363. self.top5_accuracy, self.loss],
  364. feed_dict = {
  365. self.input_image: images_val,
  366. self.labels: labels_val,
  367. self.learning_rate: self.lr,
  368. self.dropout: 1.0
  369. })
  370. summary_writer_val.add_summary(summary, global_step)
  371. summary_writer_val.flush()
  372. self.logger.info("Validation - Accuracy: %f Top 5 Accuracy: %f Loss: %f",
  373. acc, top5_acc, loss)
  374. cur_loss = sum(losses) / len(losses)
  375. if cur_loss < best_loss:
  376. best_loss = cur_loss
  377. self.save_model(sess, saver)
  378. # Increase epoch number
  379. sess.run(self.increment_epoch_op)
  380. def test(self):
  381. step = 10
  382. self.logger_test = logs.get_logger('AlexNetTest', file_name='logs_test.log')
  383. self.logger_test.info("In Test: Building the graph...")
  384. self.build_graph()
  385. init = tf.global_variables_initializer()
  386. saver = tf.train.Saver()
  387. top1_count, top5_count, count = 0, 0, 0
  388. with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
  389. self.restore_model(sess, saver)
  390. start = time.time()
  391. batch = self.lsvrc2010.gen_batch_test
  392. for i, (patches, labels) in enumerate(batch):
  393. count += patches[0].shape[0]
  394. avg_logits = np.zeros((patches[0].shape[0], self.num_classes))
  395. for patch in patches:
  396. logits = sess.run(self.logits,
  397. feed_dict = {
  398. self.input_image: patch,
  399. self.dropout: 1.0
  400. })
  401. avg_logits += logits
  402. avg_logits /= len(patches)
  403. top1_count += np.sum(np.argmax(avg_logits, 1) == labels)
  404. top5_count += np.sum(avg_logits.argsort()[:, -5:] == \
  405. np.repeat(labels, 5).reshape(patches[0].shape[0], 5))
  406. if i % step == 0:
  407. end = time.time()
  408. self.logger_test.info("Time: %f Step: %d "
  409. "Avg Accuracy: %f "
  410. "Avg Top 5 Accuracy: %f",
  411. end - start, i,
  412. top1_count / count,
  413. top5_count / count)
  414. start = time.time()
  415. self.logger_test.info("Final - Avg Accuracy: %f "
  416. "Avg Top 5 Accuracy: %f",
  417. top1_count / count,
  418. top5_count / count)
  419. if __name__ == '__main__':
  420. import argparse
  421. parser = argparse.ArgumentParser()
  422. parser.add_argument('image_path', metavar = 'image-path',
  423. help = 'ImageNet dataset path')
  424. parser.add_argument('--resume', metavar='resume',
  425. type=lambda x: x != 'False', default=True,
  426. required=False,
  427. help='Resume training (True or False)')
  428. parser.add_argument('--train', help='Train AlexNet')
  429. parser.add_argument('--test', help='Test AlexNet')
  430. args = parser.parse_args()
  431. alexnet = AlexNet(args.image_path, batch_size=128, resume=args.resume)
  432. if args.train == 'true':
  433. alexnet.train(50)
  434. elif args.test == 'true':
  435. alexnet.test()