大作业仓库
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

234 lines
11 KiB

преди 3 години
  1. import tensorflow as tf
  2. import seq2seq
  3. import bleu
  4. import reader
  5. from os import path
  6. import random
  7. class Model():
  8. def __init__(self, train_input_file, train_target_file,
  9. test_input_file, test_target_file, vocab_file,
  10. num_units, layers, dropout,
  11. batch_size, learning_rate, output_dir,
  12. save_step = 100, eval_step = 1000,
  13. param_histogram=False, restore_model=False,
  14. init_train=True, init_infer=False):
  15. self.num_units = num_units
  16. self.layers = layers
  17. self.dropout = dropout
  18. self.batch_size = batch_size
  19. self.learning_rate = learning_rate
  20. self.save_step = save_step
  21. self.eval_step = eval_step
  22. self.param_histogram = param_histogram
  23. self.restore_model = restore_model
  24. self.init_train = init_train
  25. self.init_infer = init_infer
  26. if init_train:
  27. self.train_reader = reader.SeqReader(train_input_file,
  28. train_target_file, vocab_file, batch_size)
  29. self.train_reader.start()
  30. self.train_data = self.train_reader.read()
  31. self.eval_reader = reader.SeqReader(test_input_file, test_target_file,
  32. vocab_file, batch_size)
  33. self.eval_reader.start()
  34. self.eval_data = self.eval_reader.read()
  35. self.model_file = path.join(output_dir, 'model.ckpl')
  36. self.log_writter = tf.summary.FileWriter(output_dir)
  37. if init_train:
  38. self._init_train()
  39. self._init_eval()
  40. if init_infer:
  41. self.infer_vocabs = reader.read_vocab(vocab_file)
  42. self.infer_vocab_indices = dict((c, i) for i, c in
  43. enumerate(self.infer_vocabs))
  44. self._init_infer()
  45. self.reload_infer_model()
  46. def gpu_session_config(self):
  47. config = tf.ConfigProto()
  48. config.gpu_options.allow_growth = True
  49. return config
  50. def _init_train(self):
  51. self.train_graph = tf.Graph()
  52. with self.train_graph.as_default():
  53. self.train_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
  54. self.train_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
  55. self.train_target_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
  56. self.train_target_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
  57. output = seq2seq.seq2seq(self.train_in_seq, self.train_in_seq_len,
  58. self.train_target_seq, self.train_target_seq_len,
  59. len(self.train_reader.vocabs),
  60. self.num_units, self.layers, self.dropout)
  61. self.train_output = tf.argmax(tf.nn.softmax(output), 2)
  62. self.loss = seq2seq.seq_loss(output, self.train_target_seq,
  63. self.train_target_seq_len)
  64. params = tf.trainable_variables()
  65. gradients = tf.gradients(self.loss, params)
  66. clipped_gradients, _ = tf.clip_by_global_norm(
  67. gradients, 0.5)
  68. self.train_op = tf.train.AdamOptimizer(
  69. learning_rate=self.learning_rate
  70. ).apply_gradients(zip(clipped_gradients,params))
  71. if self.param_histogram:
  72. for v in tf.trainable_variables():
  73. tf.summary.histogram('train_' + v.name, v)
  74. tf.summary.scalar('loss', self.loss)
  75. self.train_summary = tf.summary.merge_all()
  76. self.train_init = tf.global_variables_initializer()
  77. self.train_saver = tf.train.Saver()
  78. self.train_session = tf.Session(graph=self.train_graph,
  79. config=self.gpu_session_config())
  80. def _init_eval(self):
  81. self.eval_graph = tf.Graph()
  82. with self.eval_graph.as_default():
  83. self.eval_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
  84. self.eval_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
  85. self.eval_output = seq2seq.seq2seq(self.eval_in_seq,
  86. self.eval_in_seq_len, None, None,
  87. len(self.eval_reader.vocabs),
  88. self.num_units, self.layers, self.dropout)
  89. if self.param_histogram:
  90. for v in tf.trainable_variables():
  91. tf.summary.histogram('eval_' + v.name, v)
  92. self.eval_summary = tf.summary.merge_all()
  93. self.eval_saver = tf.train.Saver()
  94. self.eval_session = tf.Session(graph=self.eval_graph,
  95. config=self.gpu_session_config())
  96. def _init_infer(self):
  97. self.infer_graph = tf.Graph()
  98. with self.infer_graph.as_default():
  99. self.infer_in_seq = tf.placeholder(tf.int32, shape=[1, None])
  100. self.infer_in_seq_len = tf.placeholder(tf.int32, shape=[1])
  101. self.infer_output = seq2seq.seq2seq(self.infer_in_seq,
  102. self.infer_in_seq_len, None, None,
  103. len(self.infer_vocabs),
  104. self.num_units, self.layers, self.dropout)
  105. self.infer_saver = tf.train.Saver()
  106. self.infer_session = tf.Session(graph=self.infer_graph,
  107. config=self.gpu_session_config())
  108. def train(self, epochs, start=0):
  109. if not self.init_train:
  110. raise Exception('Train graph is not inited!')
  111. with self.train_graph.as_default():
  112. if path.isfile(self.model_file + '.meta') and self.restore_model:
  113. print("Reloading model file before training.")
  114. self.train_saver.restore(self.train_session, self.model_file)
  115. else:
  116. self.train_session.run(self.train_init)
  117. total_loss = 0
  118. for step in range(start, epochs):
  119. data = next(self.train_data)
  120. in_seq = data['in_seq']
  121. in_seq_len = data['in_seq_len']
  122. target_seq = data['target_seq']
  123. target_seq_len = data['target_seq_len']
  124. output, loss, train, summary = self.train_session.run(
  125. [self.train_output, self.loss, self.train_op, self.train_summary],
  126. feed_dict={
  127. self.train_in_seq: in_seq,
  128. self.train_in_seq_len: in_seq_len,
  129. self.train_target_seq: target_seq,
  130. self.train_target_seq_len: target_seq_len})
  131. total_loss += loss
  132. self.log_writter.add_summary(summary, step)
  133. if step % self.save_step == 0:
  134. self.train_saver.save(self.train_session, self.model_file)
  135. print("Saving model. Step: %d, loss: %f" % (step,
  136. total_loss / self.save_step))
  137. # print sample output
  138. sid = random.randint(0, self.batch_size-1)
  139. input_text = reader.decode_text(in_seq[sid],
  140. self.eval_reader.vocabs)
  141. output_text = reader.decode_text(output[sid],
  142. self.train_reader.vocabs)
  143. target_text = reader.decode_text(target_seq[sid],
  144. self.train_reader.vocabs).split(' ')[1:]
  145. target_text = ' '.join(target_text)
  146. print('******************************')
  147. print('src: ' + input_text)
  148. print('output: ' + output_text)
  149. print('target: ' + target_text)
  150. if step % self.eval_step == 0:
  151. bleu_score = self.eval(step)
  152. print("Evaluate model. Step: %d, score: %f, loss: %f" % (
  153. step, bleu_score, total_loss / self.save_step))
  154. eval_summary = tf.Summary(value=[tf.Summary.Value(
  155. tag='bleu', simple_value=bleu_score)])
  156. self.log_writter.add_summary(eval_summary, step)
  157. if step % self.save_step == 0:
  158. total_loss = 0
  159. def eval(self, train_step):
  160. with self.eval_graph.as_default():
  161. self.eval_saver.restore(self.eval_session, self.model_file)
  162. bleu_score = 0
  163. target_results = []
  164. output_results = []
  165. for step in range(0, self.eval_reader.data_size):
  166. data = next(self.eval_data)
  167. in_seq = data['in_seq']
  168. in_seq_len = data['in_seq_len']
  169. target_seq = data['target_seq']
  170. target_seq_len = data['target_seq_len']
  171. outputs = self.eval_session.run(
  172. self.eval_output,
  173. feed_dict={
  174. self.eval_in_seq: in_seq,
  175. self.eval_in_seq_len: in_seq_len})
  176. for i in range(len(outputs)):
  177. output = outputs[i]
  178. target = target_seq[i]
  179. output_text = reader.decode_text(output,
  180. self.eval_reader.vocabs).split(' ')
  181. target_text = reader.decode_text(target[1:],
  182. self.eval_reader.vocabs).split(' ')
  183. prob = int(self.eval_reader.data_size * self.batch_size / 10)
  184. target_results.append([target_text])
  185. output_results.append(output_text)
  186. if random.randint(1, prob) == 1:
  187. print('====================')
  188. input_text = reader.decode_text(in_seq[i],
  189. self.eval_reader.vocabs)
  190. print('src:' + input_text)
  191. print('output: ' + ' '.join(output_text))
  192. print('target: ' + ' '.join(target_text))
  193. return bleu.compute_bleu(target_results, output_results)[0] * 100
  194. def reload_infer_model(self):
  195. with self.infer_graph.as_default():
  196. self.infer_saver.restore(self.infer_session, self.model_file)
  197. def infer(self, text):
  198. if not self.init_infer:
  199. raise Exception('Infer graph is not inited!')
  200. with self.infer_graph.as_default():
  201. in_seq = reader.encode_text(text.split(' ') + ['</s>',],
  202. self.infer_vocab_indices)
  203. in_seq_len = len(in_seq)
  204. outputs = self.infer_session.run(self.infer_output,
  205. feed_dict={
  206. self.infer_in_seq: [in_seq],
  207. self.infer_in_seq_len: [in_seq_len]})
  208. output = outputs[0]
  209. output_text = reader.decode_text(output, self.infer_vocabs)
  210. return output_text