# Course project repository (大作业仓库)
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

# 234 lines
# 11 KiB

import tensorflow as tf
import seq2seq
import bleu
import reader
from os import path
import random
class Model():
    """Sequence-to-sequence model wrapper built on the TensorFlow 1.x graph API.

    Depending on the ``init_train`` / ``init_infer`` flags, the constructor
    builds separate graphs and sessions for training, evaluation and
    single-sentence inference.  All of them save to / restore from the same
    checkpoint path inside ``output_dir``.
    """

    def __init__(self, train_input_file, train_target_file,
                 test_input_file, test_target_file, vocab_file,
                 num_units, layers, dropout,
                 batch_size, learning_rate, output_dir,
                 save_step=100, eval_step=1000,
                 param_histogram=False, restore_model=False,
                 init_train=True, init_infer=False):
        """Store hyper-parameters and build the requested graphs.

        Args:
            train_input_file, train_target_file: files handed to
                ``reader.SeqReader`` for the training data.
            test_input_file, test_target_file: files handed to
                ``reader.SeqReader`` for the evaluation data.
            vocab_file: vocabulary file shared by both readers and by
                inference.
            num_units, layers, dropout: forwarded unchanged to
                ``seq2seq.seq2seq``.
            batch_size: first dimension of the train/eval placeholders.
            learning_rate: learning rate of the Adam optimizer.
            output_dir: directory for the checkpoint and summary events.
            save_step: a checkpoint is written whenever
                ``step % save_step == 0``.
            eval_step: evaluation runs whenever ``step % eval_step == 0``.
            param_histogram: if true, add histogram summaries for every
                trainable variable.
            restore_model: if true, ``train`` restores an existing checkpoint
                instead of initialising the variables.
            init_train: build the readers plus the train and eval graphs.
            init_infer: build the inference graph and restore the checkpoint
                into it right away.
        """
        self.num_units = num_units
        self.layers = layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.save_step = save_step
        self.eval_step = eval_step
        self.param_histogram = param_histogram
        self.restore_model = restore_model
        self.init_train = init_train
        self.init_infer = init_infer

        if init_train:
            self.train_reader = reader.SeqReader(
                train_input_file, train_target_file, vocab_file, batch_size)
            self.train_reader.start()
            self.train_data = self.train_reader.read()
            self.eval_reader = reader.SeqReader(
                test_input_file, test_target_file, vocab_file, batch_size)
            self.eval_reader.start()
            self.eval_data = self.eval_reader.read()

        # NOTE: the 'model.ckpl' file name is kept as-is; renaming it would
        # orphan checkpoints written by earlier runs.
        self.model_file = path.join(output_dir, 'model.ckpl')
        self.log_writter = tf.summary.FileWriter(output_dir)

        if init_train:
            self._init_train()
            self._init_eval()

        if init_infer:
            self.infer_vocabs = reader.read_vocab(vocab_file)
            # Reverse lookup: vocabulary entry -> index.
            self.infer_vocab_indices = {
                c: i for i, c in enumerate(self.infer_vocabs)}
            self._init_infer()
            self.reload_infer_model()

    def gpu_session_config(self):
        """Return a session config with GPU ``allow_growth`` switched on."""
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        return config

    @staticmethod
    def _sample_period(data_size, batch_size):
        """Return N so that printing each example with probability 1/N shows
        about ten examples per pass when every batch holds ``batch_size``
        examples.

        The result is clamped to at least 1 because ``random.randint(1, 0)``
        raises ``ValueError``, which used to crash evaluation on small sets.
        """
        return max(1, int(data_size * batch_size / 10))

    def _init_train(self):
        """Build the training graph, its summaries, saver and session."""
        self.train_graph = tf.Graph()
        with self.train_graph.as_default():
            self.train_in_seq = tf.placeholder(
                tf.int32, shape=[self.batch_size, None])
            self.train_in_seq_len = tf.placeholder(
                tf.int32, shape=[self.batch_size])
            self.train_target_seq = tf.placeholder(
                tf.int32, shape=[self.batch_size, None])
            self.train_target_seq_len = tf.placeholder(
                tf.int32, shape=[self.batch_size])
            output = seq2seq.seq2seq(
                self.train_in_seq, self.train_in_seq_len,
                self.train_target_seq, self.train_target_seq_len,
                len(self.train_reader.vocabs),
                self.num_units, self.layers, self.dropout)
            # Index of the highest-scoring entry along axis 2.
            self.train_output = tf.argmax(tf.nn.softmax(output), 2)
            self.loss = seq2seq.seq_loss(
                output, self.train_target_seq, self.train_target_seq_len)
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            # Clip the gradients to a global norm of 0.5 before applying them.
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, 0.5)
            self.train_op = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate
            ).apply_gradients(zip(clipped_gradients, params))
            if self.param_histogram:
                for v in tf.trainable_variables():
                    tf.summary.histogram('train_' + v.name, v)
            # The loss scalar is always registered so that merge_all() below
            # has at least one summary to merge.
            tf.summary.scalar('loss', self.loss)
            self.train_summary = tf.summary.merge_all()
            self.train_init = tf.global_variables_initializer()
            self.train_saver = tf.train.Saver()
        self.train_session = tf.Session(
            graph=self.train_graph, config=self.gpu_session_config())

    def _init_eval(self):
        """Build the evaluation graph (no target inputs), saver and session."""
        self.eval_graph = tf.Graph()
        with self.eval_graph.as_default():
            self.eval_in_seq = tf.placeholder(
                tf.int32, shape=[self.batch_size, None])
            self.eval_in_seq_len = tf.placeholder(
                tf.int32, shape=[self.batch_size])
            self.eval_output = seq2seq.seq2seq(
                self.eval_in_seq, self.eval_in_seq_len, None, None,
                len(self.eval_reader.vocabs),
                self.num_units, self.layers, self.dropout)
            if self.param_histogram:
                for v in tf.trainable_variables():
                    tf.summary.histogram('eval_' + v.name, v)
            self.eval_summary = tf.summary.merge_all()
            self.eval_saver = tf.train.Saver()
        self.eval_session = tf.Session(
            graph=self.eval_graph, config=self.gpu_session_config())

    def _init_infer(self):
        """Build the inference graph (batch of one), saver and session."""
        self.infer_graph = tf.Graph()
        with self.infer_graph.as_default():
            self.infer_in_seq = tf.placeholder(tf.int32, shape=[1, None])
            self.infer_in_seq_len = tf.placeholder(tf.int32, shape=[1])
            self.infer_output = seq2seq.seq2seq(
                self.infer_in_seq, self.infer_in_seq_len, None, None,
                len(self.infer_vocabs),
                self.num_units, self.layers, self.dropout)
            self.infer_saver = tf.train.Saver()
        self.infer_session = tf.Session(
            graph=self.infer_graph, config=self.gpu_session_config())

    def _print_train_sample(self, in_seq, output, target_seq):
        """Print one randomly chosen example of the current training batch."""
        vocabs = self.train_reader.vocabs
        sid = random.randint(0, self.batch_size - 1)
        input_text = reader.decode_text(in_seq[sid], vocabs)
        output_text = reader.decode_text(output[sid], vocabs)
        # The first decoded target token is dropped (presumably the
        # start-of-sequence marker; confirm against the reader module).
        target_text = ' '.join(
            reader.decode_text(target_seq[sid], vocabs).split(' ')[1:])
        print('******************************')
        print('src: ' + input_text)
        print('output: ' + output_text)
        print('target: ' + target_text)

    def train(self, epochs, start=0):
        """Run training steps ``start`` .. ``epochs - 1``.

        Despite its name, ``epochs`` is the exclusive upper bound of the step
        counter; every step consumes one batch from the training reader.
        A checkpoint is saved every ``save_step`` steps and the model is
        evaluated every ``eval_step`` steps.

        Raises:
            RuntimeError: if the model was constructed with
                ``init_train=False``.
        """
        if not self.init_train:
            raise RuntimeError('Train graph is not inited!')
        with self.train_graph.as_default():
            if self.restore_model and path.isfile(self.model_file + '.meta'):
                print("Reloading model file before training.")
                self.train_saver.restore(self.train_session, self.model_file)
            else:
                self.train_session.run(self.train_init)

            total_loss = 0
            # Number of losses accumulated since the last reset, so that the
            # reported mean is correct for short windows (e.g. the first step).
            window_steps = 0
            for step in range(start, epochs):
                data = next(self.train_data)
                in_seq = data['in_seq']
                in_seq_len = data['in_seq_len']
                target_seq = data['target_seq']
                target_seq_len = data['target_seq_len']
                output, loss, _, summary = self.train_session.run(
                    [self.train_output, self.loss, self.train_op,
                     self.train_summary],
                    feed_dict={
                        self.train_in_seq: in_seq,
                        self.train_in_seq_len: in_seq_len,
                        self.train_target_seq: target_seq,
                        self.train_target_seq_len: target_seq_len})
                total_loss += loss
                window_steps += 1
                self.log_writter.add_summary(summary, step)

                at_save_step = step % self.save_step == 0
                if at_save_step:
                    self.train_saver.save(self.train_session, self.model_file)
                    print("Saving model. Step: %d, loss: %f" % (
                        step, total_loss / window_steps))
                    self._print_train_sample(in_seq, output, target_seq)
                if step % self.eval_step == 0:
                    bleu_score = self.eval(step)
                    print("Evaluate model. Step: %d, score: %f, loss: %f" % (
                        step, bleu_score, total_loss / window_steps))
                    eval_summary = tf.Summary(value=[tf.Summary.Value(
                        tag='bleu', simple_value=bleu_score)])
                    self.log_writter.add_summary(eval_summary, step)
                # Reset only after the evaluation message has used the totals.
                if at_save_step:
                    total_loss = 0
                    window_steps = 0

    def eval(self, train_step):
        """Restore the latest checkpoint and score it on the evaluation data.

        Args:
            train_step: not used by this method; kept so existing callers
                keep working.

        Returns:
            The first element of ``bleu.compute_bleu``'s result, times 100.
        """
        with self.eval_graph.as_default():
            self.eval_saver.restore(self.eval_session, self.model_file)
            data_size = self.eval_reader.data_size
            sample_period = self._sample_period(data_size, self.batch_size)
            vocabs = self.eval_reader.vocabs
            target_results = []
            output_results = []
            for _ in range(data_size):
                data = next(self.eval_data)
                in_seq = data['in_seq']
                target_seq = data['target_seq']
                outputs = self.eval_session.run(
                    self.eval_output,
                    feed_dict={
                        self.eval_in_seq: in_seq,
                        self.eval_in_seq_len: data['in_seq_len']})
                for i, output in enumerate(outputs):
                    output_text = reader.decode_text(
                        output, vocabs).split(' ')
                    # The first target token is skipped before decoding.
                    target_text = reader.decode_text(
                        target_seq[i][1:], vocabs).split(' ')
                    # compute_bleu receives a list of references per example.
                    target_results.append([target_text])
                    output_results.append(output_text)
                    if random.randint(1, sample_period) == 1:
                        print('====================')
                        input_text = reader.decode_text(in_seq[i], vocabs)
                        print('src:' + input_text)
                        print('output: ' + ' '.join(output_text))
                        print('target: ' + ' '.join(target_text))
            return bleu.compute_bleu(target_results, output_results)[0] * 100

    def reload_infer_model(self):
        """Restore the checkpoint file into the inference session."""
        with self.infer_graph.as_default():
            self.infer_saver.restore(self.infer_session, self.model_file)

    def infer(self, text):
        """Decode the model output for one space-separated input sentence.

        Args:
            text: input tokens separated by single spaces; ``'</s>'`` is
                appended before encoding.

        Returns:
            The decoded output text for the single input sequence.

        Raises:
            RuntimeError: if the model was constructed with
                ``init_infer=False``.
        """
        if not self.init_infer:
            raise RuntimeError('Infer graph is not inited!')
        with self.infer_graph.as_default():
            in_seq = reader.encode_text(
                text.split(' ') + ['</s>'], self.infer_vocab_indices)
            outputs = self.infer_session.run(
                self.infer_output,
                feed_dict={
                    self.infer_in_seq: [in_seq],
                    self.infer_in_seq_len: [len(in_seq)]})
            return reader.decode_text(outputs[0], self.infer_vocabs)