|
import click
|
|
|
|
|
|
from model.utils.data_generator import DataGenerator
|
|
from model.img2seq import Img2SeqModel
|
|
from model.utils.lr_schedule import LRSchedule
|
|
from model.utils.general import Config
|
|
from model.utils.text import Vocab
|
|
from model.utils.image import greyscale
|
|
|
|
|
|
# Command-line entry point: builds the config, datasets, learning-rate
# schedule and Img2SeqModel, then launches training.
# NOTE: documented with comments rather than a docstring on purpose -- click
# uses a command function's docstring as its --help text, so adding one would
# change what the program prints.
@click.command()
@click.option('--data', default="configs/data.json",
              help='Path to data json config')
@click.option('--vocab', default="configs/vocab.json",
              help='Path to vocab json config')
@click.option('--training', default="configs/training.json",
              help='Path to training json config')
@click.option('--model', default="configs/model.json",
              help='Path to model json config')
@click.option('--output', default="results/full/",
              help='Dir for results and model weights')
def main(data, vocab, training, model, output):
    # Configuration: one Config built from the four JSON config paths, then
    # saved via config.save() with the output directory as its target.
    config = Config([data, vocab, training, model])
    config.save(output)

    # The `vocab` argument is only a config path; the vocabulary object
    # itself is constructed from the combined config.
    vocabulary = Vocab(config)

    # Datasets: the training and validation generators are set up the same
    # way and differ only in the *_train / *_val config entries they read.
    train_set = DataGenerator(
        path_formulas=config.path_formulas_train,
        dir_images=config.dir_images_train,
        max_iter=config.max_iter,
        bucket=config.bucket_train,
        path_matching=config.path_matching_train,
        max_len=config.max_length_formula,
        form_prepro=vocabulary.form_prepro,
    )
    val_set = DataGenerator(
        path_formulas=config.path_formulas_val,
        dir_images=config.dir_images_val,
        max_iter=config.max_iter,
        bucket=config.bucket_val,
        path_matching=config.path_matching_val,
        max_len=config.max_length_formula,
        form_prepro=vocabulary.form_prepro,
    )

    # Learning-rate schedule. Ceiling division of the training-set length by
    # the batch size, so a trailing partial batch still counts as a batch.
    batch_size = config.batch_size
    n_batches_epoch = (len(train_set) + batch_size - 1) // batch_size

    def in_batches(n_epochs):
        # Scale a schedule boundary from the config by the number of batches
        # per epoch (presumably the config expresses these in epochs).
        return n_epochs * n_batches_epoch

    lr_schedule = LRSchedule(
        lr_init=config.lr_init,
        start_decay=in_batches(config.start_decay),
        end_decay=in_batches(config.end_decay),
        end_warm=in_batches(config.end_warm),
        lr_warm=config.lr_warm,
        lr_min=config.lr_min,
    )

    # Model: build the training graph, then train with validation.
    network = Img2SeqModel(config, output, vocabulary)
    network.build_train(config)
    # Optional (disabled): restore previously saved weights before training:
    # network.restore_session(output + "model.weights/test-model.ckpt")
    network.train(config, train_set, val_set, lr_schedule)
|
|
|
|
|
|
# Run the click command only when this file is executed as a script; main()
# is called without arguments because click supplies them from the command
# line.
if __name__ == "__main__":
    main()
|