import logging
import os
import sys
from argparse import ArgumentParser
from math import ceil
from typing import Optional


class Config:
    def __init__(self):
        # Logger is created lazily; stays None until first requested.
        self.__logger: Optional[logging.Logger] = None
        self.set_defaults()

    def set_defaults(self):
        # training schedule
        self.NUM_TRAIN_EPOCHS = 100
        self.SAVE_EVERY_EPOCHS = 5
        self.TRAIN_BATCH_SIZE = 64

        # model hyper-params
        self.categories = 10
        # self.learning_rate = 0.001
        # self.decay_rate = 0.9

        # vocabulary limits: effective vocab sizes, the per-example context
        # cap, and hard upper bounds on how many entries may be loaded
        self.path_vocab_size = 27500
        self.token_vocab_size = 1500
        self.MAX_CONTEXTS = 200
        self.MAX_TOKEN_VOCAB_SIZE = 1301136
        self.MAX_PATH_VOCAB_SIZE = 911417

        # embedding dimensions; CODE_VECTOR_SIZE is derived from the token
        # and path embedding sizes via the context_vector_size property below
        self.DEFAULT_EMBEDDINGS_SIZE = 64
        self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
        self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
        self.CODE_VECTOR_SIZE = self.context_vector_size
        self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
        self.DROPOUT_KEEP_RATE = 0.5

    @property
    def context_vector_size(self) -> int:
        # The context vector is a concatenation of the embedded source and
        # target token vectors and the embedded path vector.
        return self.PATH_EMBEDDINGS_SIZE + 2 * self.TOKEN_EMBEDDINGS_SIZE
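
# A minimal usage sketch (an illustrative addition, not part of the original
# module): with the defaults above, each context vector concatenates two
# 64-dim token embeddings with one 64-dim path embedding, so the code vector
# is 64 + 2 * 64 = 192 wide.
if __name__ == '__main__':
    config = Config()
    assert config.context_vector_size == 192
    assert config.CODE_VECTOR_SIZE == config.context_vector_size
    print('context vector size:', config.context_vector_size)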