NoteOnMe博客平台搭建
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
2.3 KiB

  1. import click
  2. from model.utils.data_generator import DataGenerator
  3. from model.img2seq import Img2SeqModel
  4. from model.utils.lr_schedule import LRSchedule
  5. from model.utils.general import Config
  6. from model.utils.text import Vocab
  7. from model.utils.image import greyscale
  8. @click.command()
  9. @click.option('--data', default="configs/data.json",
  10. help='Path to data json config')
  11. @click.option('--vocab', default="configs/vocab.json",
  12. help='Path to vocab json config')
  13. @click.option('--training', default="configs/training.json",
  14. help='Path to training json config')
  15. @click.option('--model', default="configs/model.json",
  16. help='Path to model json config')
  17. @click.option('--output', default="results/full/",
  18. help='Dir for results and model weights')
  19. def main(data, vocab, training, model, output):
  20. # Load configs
  21. dir_output = output
  22. config = Config([data, vocab, training, model])
  23. config.save(dir_output)
  24. vocab = Vocab(config)
  25. # Load datasets
  26. train_set = DataGenerator(path_formulas=config.path_formulas_train,
  27. dir_images=config.dir_images_train,
  28. max_iter=config.max_iter, bucket=config.bucket_train,
  29. path_matching=config.path_matching_train,
  30. max_len=config.max_length_formula,
  31. form_prepro=vocab.form_prepro)
  32. val_set = DataGenerator(path_formulas=config.path_formulas_val,
  33. dir_images=config.dir_images_val,
  34. max_iter=config.max_iter, bucket=config.bucket_val,
  35. path_matching=config.path_matching_val,
  36. max_len=config.max_length_formula,
  37. form_prepro=vocab.form_prepro)
  38. # Define learning rate schedule
  39. n_batches_epoch = ((len(train_set) + config.batch_size - 1) //
  40. config.batch_size)
  41. lr_schedule = LRSchedule(lr_init=config.lr_init,
  42. start_decay=config.start_decay*n_batches_epoch,
  43. end_decay=config.end_decay*n_batches_epoch,
  44. end_warm=config.end_warm*n_batches_epoch,
  45. lr_warm=config.lr_warm,
  46. lr_min=config.lr_min)
  47. # Build model and train
  48. model = Img2SeqModel(config, dir_output, vocab)
  49. model.build_train(config)
  50. #model.restore_session(dir_output + "model.weights/test-model.ckpt")
  51. model.train(config, train_set, val_set, lr_schedule)
  52. if __name__ == "__main__":
  53. main()