import random
from os import path

import tensorflow as tf

import bleu
import reader
import seq2seq


class Model(object):
    def __init__(self, train_input_file, train_target_file,
                 test_input_file, test_target_file, vocab_file,
                 num_units, layers, dropout,
                 batch_size, learning_rate, output_dir,
                 save_step=100, eval_step=1000,
                 param_histogram=False, restore_model=False,
                 init_train=True, init_infer=False):
        self.num_units = num_units
        self.layers = layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.save_step = save_step
        self.eval_step = eval_step
        self.param_histogram = param_histogram
        self.restore_model = restore_model
        self.init_train = init_train
        self.init_infer = init_infer

        if init_train:
            self.train_reader = reader.SeqReader(train_input_file,
                                                 train_target_file,
                                                 vocab_file, batch_size)
            self.train_reader.start()
            self.train_data = self.train_reader.read()
            self.eval_reader = reader.SeqReader(test_input_file,
                                                test_target_file,
                                                vocab_file, batch_size)
            self.eval_reader.start()
            self.eval_data = self.eval_reader.read()

        self.model_file = path.join(output_dir, 'model.ckpt')
        self.log_writer = tf.summary.FileWriter(output_dir)

        if init_train:
            self._init_train()
            self._init_eval()
        if init_infer:
            self.infer_vocabs = reader.read_vocab(vocab_file)
            self.infer_vocab_indices = dict(
                (c, i) for i, c in enumerate(self.infer_vocabs))
            self._init_infer()
            self.reload_infer_model()

    def gpu_session_config(self):
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        return config

    def _init_train(self):
        """Build the training graph: seq2seq output, loss, and a
        gradient-clipped Adam train op."""
        self.train_graph = tf.Graph()
        with self.train_graph.as_default():
            self.train_in_seq = tf.placeholder(
                tf.int32, shape=[self.batch_size, None])
            self.train_in_seq_len = tf.placeholder(
                tf.int32, shape=[self.batch_size])
            self.train_target_seq = tf.placeholder(
                tf.int32, shape=[self.batch_size, None])
            self.train_target_seq_len = tf.placeholder(
                tf.int32, shape=[self.batch_size])
            output = seq2seq.seq2seq(
                self.train_in_seq, self.train_in_seq_len,
                self.train_target_seq, self.train_target_seq_len,
                len(self.train_reader.vocabs),
                self.num_units, self.layers, self.dropout)
            self.train_output = tf.argmax(tf.nn.softmax(output), 2)
            self.loss = seq2seq.seq_loss(output, self.train_target_seq,
                                         self.train_target_seq_len)
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            # Clip the global gradient norm to 0.5 to keep training stable.
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, 0.5)
            self.train_op = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate
            ).apply_gradients(zip(clipped_gradients, params))
            if self.param_histogram:
                for v in tf.trainable_variables():
                    tf.summary.histogram('train_' + v.name, v)
            tf.summary.scalar('loss', self.loss)
            self.train_summary = tf.summary.merge_all()
            self.train_init = tf.global_variables_initializer()
            self.train_saver = tf.train.Saver()
        self.train_session = tf.Session(graph=self.train_graph,
                                        config=self.gpu_session_config())

    def _init_eval(self):
        """Build the evaluation graph; the decoder runs without target
        inputs."""
        self.eval_graph = tf.Graph()
        with self.eval_graph.as_default():
            self.eval_in_seq = tf.placeholder(
                tf.int32, shape=[self.batch_size, None])
            self.eval_in_seq_len = tf.placeholder(
                tf.int32, shape=[self.batch_size])
            self.eval_output = seq2seq.seq2seq(
                self.eval_in_seq, self.eval_in_seq_len, None, None,
                len(self.eval_reader.vocabs),
                self.num_units, self.layers, self.dropout)
            if self.param_histogram:
                for v in tf.trainable_variables():
                    tf.summary.histogram('eval_' + v.name, v)
            self.eval_summary = tf.summary.merge_all()
            self.eval_saver = tf.train.Saver()
        self.eval_session = tf.Session(graph=self.eval_graph,
                                       config=self.gpu_session_config())

    def _init_infer(self):
        """Build the inference graph for a single input sequence."""
        self.infer_graph = tf.Graph()
        with self.infer_graph.as_default():
            self.infer_in_seq = tf.placeholder(tf.int32, shape=[1, None])
            self.infer_in_seq_len = tf.placeholder(tf.int32, shape=[1])
            self.infer_output = seq2seq.seq2seq(
                self.infer_in_seq, self.infer_in_seq_len, None, None,
                len(self.infer_vocabs),
                self.num_units, self.layers, self.dropout)
            self.infer_saver = tf.train.Saver()
        self.infer_session = tf.Session(graph=self.infer_graph,
                                        config=self.gpu_session_config())

    def train(self, epochs, start=0):
        if not self.init_train:
            raise Exception('Train graph is not initialized!')
        with self.train_graph.as_default():
            if path.isfile(self.model_file + '.meta') and self.restore_model:
                print("Reloading model file before training.")
                self.train_saver.restore(self.train_session, self.model_file)
            else:
                self.train_session.run(self.train_init)
            total_loss = 0
            # Note: each "epoch" here is one batch step, not a full pass
            # over the training data.
            for step in range(start, epochs):
                data = next(self.train_data)
                in_seq = data['in_seq']
                in_seq_len = data['in_seq_len']
                target_seq = data['target_seq']
                target_seq_len = data['target_seq_len']
                output, loss, _, summary = self.train_session.run(
                    [self.train_output, self.loss, self.train_op,
                     self.train_summary],
                    feed_dict={
                        self.train_in_seq: in_seq,
                        self.train_in_seq_len: in_seq_len,
                        self.train_target_seq: target_seq,
                        self.train_target_seq_len: target_seq_len})
                total_loss += loss
                self.log_writer.add_summary(summary, step)
                if step % self.save_step == 0:
                    self.train_saver.save(self.train_session, self.model_file)
                    print("Saving model. Step: %d, loss: %f" % (
                        step, total_loss / self.save_step))
                    # Print a randomly sampled example from the batch.
                    sid = random.randint(0, self.batch_size - 1)
                    input_text = reader.decode_text(
                        in_seq[sid], self.train_reader.vocabs)
                    output_text = reader.decode_text(
                        output[sid], self.train_reader.vocabs)
                    # Drop the leading start-of-sequence token.
                    target_text = reader.decode_text(
                        target_seq[sid],
                        self.train_reader.vocabs).split(' ')[1:]
                    target_text = ' '.join(target_text)
                    print('******************************')
                    print('src: ' + input_text)
                    print('output: ' + output_text)
                    print('target: ' + target_text)
                if step % self.eval_step == 0:
                    bleu_score = self.eval(step)
                    print("Evaluate model. Step: %d, score: %f, loss: %f" % (
                        step, bleu_score, total_loss / self.save_step))
                    eval_summary = tf.Summary(value=[tf.Summary.Value(
                        tag='bleu', simple_value=bleu_score)])
                    self.log_writer.add_summary(eval_summary, step)
                if step % self.save_step == 0:
                    total_loss = 0

    def eval(self, train_step):
        with self.eval_graph.as_default():
            # Evaluate with the latest checkpoint saved by the train graph.
            self.eval_saver.restore(self.eval_session, self.model_file)
            target_results = []
            output_results = []
            # Print roughly ten random samples per evaluation run.
            prob = int(self.eval_reader.data_size * self.batch_size / 10)
            for step in range(0, self.eval_reader.data_size):
                data = next(self.eval_data)
                in_seq = data['in_seq']
                in_seq_len = data['in_seq_len']
                target_seq = data['target_seq']
                target_seq_len = data['target_seq_len']
                outputs = self.eval_session.run(
                    self.eval_output,
                    feed_dict={
                        self.eval_in_seq: in_seq,
                        self.eval_in_seq_len: in_seq_len})
                for i in range(len(outputs)):
                    output = outputs[i]
                    target = target_seq[i]
                    output_text = reader.decode_text(
                        output, self.eval_reader.vocabs).split(' ')
                    # Skip the leading start-of-sequence token in the target.
                    target_text = reader.decode_text(
                        target[1:], self.eval_reader.vocabs).split(' ')
                    target_results.append([target_text])
                    output_results.append(output_text)
                    if random.randint(1, prob) == 1:
                        print('====================')
                        input_text = reader.decode_text(
                            in_seq[i], self.eval_reader.vocabs)
                        print('src: ' + input_text)
                        print('output: ' + ' '.join(output_text))
                        print('target: ' + ' '.join(target_text))
            return bleu.compute_bleu(target_results, output_results)[0] * 100

    def reload_infer_model(self):
        with self.infer_graph.as_default():
            self.infer_saver.restore(self.infer_session, self.model_file)

    def infer(self, text):
        if not self.init_infer:
            raise Exception('Infer graph is not initialized!')
        with self.infer_graph.as_default():
            in_seq = reader.encode_text(text.split(' ') + ['</s>'],
                                        self.infer_vocab_indices)
            in_seq_len = len(in_seq)
            outputs = self.infer_session.run(
                self.infer_output,
                feed_dict={
                    self.infer_in_seq: [in_seq],
                    self.infer_in_seq_len: [in_seq_len]})
            output = outputs[0]
            output_text = reader.decode_text(output, self.infer_vocabs)
            return output_text
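

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving this class end to end: train on parallel
# text files, then reload the checkpoint for one-off inference. The file
# paths, output directory, and hyperparameter values below are assumptions
# chosen for illustration only; the reader, seq2seq, and bleu modules must
# come from this repository.
if __name__ == '__main__':
    # Training: builds the train and eval graphs and writes checkpoints
    # and TensorBoard summaries to output_dir.
    model = Model('data/train.input', 'data/train.target',
                  'data/test.input', 'data/test.target',
                  'data/vocab.txt',
                  num_units=256, layers=2, dropout=0.2,
                  batch_size=32, learning_rate=0.001,
                  output_dir='output',
                  restore_model=False, init_train=True, init_infer=False)
    # train() counts batch steps, not passes over the data set.
    model.train(epochs=10000)

    # Inference: a separate instance that only builds the infer graph and
    # restores the checkpoint saved above (fails if none exists yet).
    infer_model = Model(None, None, None, None, 'data/vocab.txt',
                        num_units=256, layers=2, dropout=0.0,
                        batch_size=1, learning_rate=0.001,
                        output_dir='output',
                        init_train=False, init_infer=True)
    print(infer_model.infer('some space separated input tokens'))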