大作业仓库
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

234 rader
11 KiB

  1. import tensorflow as tf
  2. import seq2seq
  3. import bleu
  4. import reader
  5. from os import path
  6. import random
  7. class Model():
  8. def __init__(self, train_input_file, train_target_file,
  9. test_input_file, test_target_file, vocab_file,
  10. num_units, layers, dropout,
  11. batch_size, learning_rate, output_dir,
  12. save_step = 100, eval_step = 1000,
  13. param_histogram=False, restore_model=False,
  14. init_train=True, init_infer=False):
  15. self.num_units = num_units
  16. self.layers = layers
  17. self.dropout = dropout
  18. self.batch_size = batch_size
  19. self.learning_rate = learning_rate
  20. self.save_step = save_step
  21. self.eval_step = eval_step
  22. self.param_histogram = param_histogram
  23. self.restore_model = restore_model
  24. self.init_train = init_train
  25. self.init_infer = init_infer
  26. if init_train:
  27. self.train_reader = reader.SeqReader(train_input_file,
  28. train_target_file, vocab_file, batch_size)
  29. self.train_reader.start()
  30. self.train_data = self.train_reader.read()
  31. self.eval_reader = reader.SeqReader(test_input_file, test_target_file,
  32. vocab_file, batch_size)
  33. self.eval_reader.start()
  34. self.eval_data = self.eval_reader.read()
  35. self.model_file = path.join(output_dir, 'model.ckpl')
  36. self.log_writter = tf.summary.FileWriter(output_dir)
  37. if init_train:
  38. self._init_train()
  39. self._init_eval()
  40. if init_infer:
  41. self.infer_vocabs = reader.read_vocab(vocab_file)
  42. self.infer_vocab_indices = dict((c, i) for i, c in
  43. enumerate(self.infer_vocabs))
  44. self._init_infer()
  45. self.reload_infer_model()
  46. def gpu_session_config(self):
  47. config = tf.ConfigProto()
  48. config.gpu_options.allow_growth = True
  49. return config
  50. def _init_train(self):
  51. self.train_graph = tf.Graph()
  52. with self.train_graph.as_default():
  53. self.train_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
  54. self.train_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
  55. self.train_target_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
  56. self.train_target_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
  57. output = seq2seq.seq2seq(self.train_in_seq, self.train_in_seq_len,
  58. self.train_target_seq, self.train_target_seq_len,
  59. len(self.train_reader.vocabs),
  60. self.num_units, self.layers, self.dropout)
  61. self.train_output = tf.argmax(tf.nn.softmax(output), 2)
  62. self.loss = seq2seq.seq_loss(output, self.train_target_seq,
  63. self.train_target_seq_len)
  64. params = tf.trainable_variables()
  65. gradients = tf.gradients(self.loss, params)
  66. clipped_gradients, _ = tf.clip_by_global_norm(
  67. gradients, 0.5)
  68. self.train_op = tf.train.AdamOptimizer(
  69. learning_rate=self.learning_rate
  70. ).apply_gradients(zip(clipped_gradients,params))
  71. if self.param_histogram:
  72. for v in tf.trainable_variables():
  73. tf.summary.histogram('train_' + v.name, v)
  74. tf.summary.scalar('loss', self.loss)
  75. self.train_summary = tf.summary.merge_all()
  76. self.train_init = tf.global_variables_initializer()
  77. self.train_saver = tf.train.Saver()
  78. self.train_session = tf.Session(graph=self.train_graph,
  79. config=self.gpu_session_config())
  80. def _init_eval(self):
  81. self.eval_graph = tf.Graph()
  82. with self.eval_graph.as_default():
  83. self.eval_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
  84. self.eval_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
  85. self.eval_output = seq2seq.seq2seq(self.eval_in_seq,
  86. self.eval_in_seq_len, None, None,
  87. len(self.eval_reader.vocabs),
  88. self.num_units, self.layers, self.dropout)
  89. if self.param_histogram:
  90. for v in tf.trainable_variables():
  91. tf.summary.histogram('eval_' + v.name, v)
  92. self.eval_summary = tf.summary.merge_all()
  93. self.eval_saver = tf.train.Saver()
  94. self.eval_session = tf.Session(graph=self.eval_graph,
  95. config=self.gpu_session_config())
  96. def _init_infer(self):
  97. self.infer_graph = tf.Graph()
  98. with self.infer_graph.as_default():
  99. self.infer_in_seq = tf.placeholder(tf.int32, shape=[1, None])
  100. self.infer_in_seq_len = tf.placeholder(tf.int32, shape=[1])
  101. self.infer_output = seq2seq.seq2seq(self.infer_in_seq,
  102. self.infer_in_seq_len, None, None,
  103. len(self.infer_vocabs),
  104. self.num_units, self.layers, self.dropout)
  105. self.infer_saver = tf.train.Saver()
  106. self.infer_session = tf.Session(graph=self.infer_graph,
  107. config=self.gpu_session_config())
  108. def train(self, epochs, start=0):
  109. if not self.init_train:
  110. raise Exception('Train graph is not inited!')
  111. with self.train_graph.as_default():
  112. if path.isfile(self.model_file + '.meta') and self.restore_model:
  113. print("Reloading model file before training.")
  114. self.train_saver.restore(self.train_session, self.model_file)
  115. else:
  116. self.train_session.run(self.train_init)
  117. total_loss = 0
  118. for step in range(start, epochs):
  119. data = next(self.train_data)
  120. in_seq = data['in_seq']
  121. in_seq_len = data['in_seq_len']
  122. target_seq = data['target_seq']
  123. target_seq_len = data['target_seq_len']
  124. output, loss, train, summary = self.train_session.run(
  125. [self.train_output, self.loss, self.train_op, self.train_summary],
  126. feed_dict={
  127. self.train_in_seq: in_seq,
  128. self.train_in_seq_len: in_seq_len,
  129. self.train_target_seq: target_seq,
  130. self.train_target_seq_len: target_seq_len})
  131. total_loss += loss
  132. self.log_writter.add_summary(summary, step)
  133. if step % self.save_step == 0:
  134. self.train_saver.save(self.train_session, self.model_file)
  135. print("Saving model. Step: %d, loss: %f" % (step,
  136. total_loss / self.save_step))
  137. # print sample output
  138. sid = random.randint(0, self.batch_size-1)
  139. input_text = reader.decode_text(in_seq[sid],
  140. self.eval_reader.vocabs)
  141. output_text = reader.decode_text(output[sid],
  142. self.train_reader.vocabs)
  143. target_text = reader.decode_text(target_seq[sid],
  144. self.train_reader.vocabs).split(' ')[1:]
  145. target_text = ' '.join(target_text)
  146. print('******************************')
  147. print('src: ' + input_text)
  148. print('output: ' + output_text)
  149. print('target: ' + target_text)
  150. if step % self.eval_step == 0:
  151. bleu_score = self.eval(step)
  152. print("Evaluate model. Step: %d, score: %f, loss: %f" % (
  153. step, bleu_score, total_loss / self.save_step))
  154. eval_summary = tf.Summary(value=[tf.Summary.Value(
  155. tag='bleu', simple_value=bleu_score)])
  156. self.log_writter.add_summary(eval_summary, step)
  157. if step % self.save_step == 0:
  158. total_loss = 0
  159. def eval(self, train_step):
  160. with self.eval_graph.as_default():
  161. self.eval_saver.restore(self.eval_session, self.model_file)
  162. bleu_score = 0
  163. target_results = []
  164. output_results = []
  165. for step in range(0, self.eval_reader.data_size):
  166. data = next(self.eval_data)
  167. in_seq = data['in_seq']
  168. in_seq_len = data['in_seq_len']
  169. target_seq = data['target_seq']
  170. target_seq_len = data['target_seq_len']
  171. outputs = self.eval_session.run(
  172. self.eval_output,
  173. feed_dict={
  174. self.eval_in_seq: in_seq,
  175. self.eval_in_seq_len: in_seq_len})
  176. for i in range(len(outputs)):
  177. output = outputs[i]
  178. target = target_seq[i]
  179. output_text = reader.decode_text(output,
  180. self.eval_reader.vocabs).split(' ')
  181. target_text = reader.decode_text(target[1:],
  182. self.eval_reader.vocabs).split(' ')
  183. prob = int(self.eval_reader.data_size * self.batch_size / 10)
  184. target_results.append([target_text])
  185. output_results.append(output_text)
  186. if random.randint(1, prob) == 1:
  187. print('====================')
  188. input_text = reader.decode_text(in_seq[i],
  189. self.eval_reader.vocabs)
  190. print('src:' + input_text)
  191. print('output: ' + ' '.join(output_text))
  192. print('target: ' + ' '.join(target_text))
  193. return bleu.compute_bleu(target_results, output_results)[0] * 100
    def reload_infer_model(self):
        # Restore the latest checkpoint into the inference session so that
        # subsequent infer() calls use the most recently trained weights.
        with self.infer_graph.as_default():
            self.infer_saver.restore(self.infer_session, self.model_file)
  197. def infer(self, text):
  198. if not self.init_infer:
  199. raise Exception('Infer graph is not inited!')
  200. with self.infer_graph.as_default():
  201. in_seq = reader.encode_text(text.split(' ') + ['</s>',],
  202. self.infer_vocab_indices)
  203. in_seq_len = len(in_seq)
  204. outputs = self.infer_session.run(self.infer_output,
  205. feed_dict={
  206. self.infer_in_seq: [in_seq],
  207. self.infer_in_seq_len: [in_seq_len]})
  208. output = outputs[0]
  209. output_text = reader.decode_text(output, self.infer_vocabs)
  210. return output_text