from sympy import im
|
|
from model import NERModel, SentTokenizer
|
|
import torch as tch
|
|
|
|
# Run on the first GPU when one is present; fall back to CPU so the script
# still starts on machines without CUDA (the original hard-coded 'cuda:0').
device = tch.device('cuda:0' if tch.cuda.is_available() else 'cpu')

# Location of the trained weights passed to load_state_dict below.
MODEL_PATH = 'Saved/NER.model'

tokenizer = SentTokenizer().to(device)
model = NERModel().to(device)

print('loading model')
# map_location remaps tensors that were saved on a different device (another
# GPU index, or GPU vs. CPU) onto `device`, so loading does not fail here.
model.load_state_dict(tch.load(MODEL_PATH, map_location=device))
print('ok')
|
|
|
|
# Display name for each model output channel, indexed by class id
# (0 -> 'nothing', 1 -> 'concept', 2 -> 'reliance', as in the original branches).
LABEL_NAMES = ('nothing', 'concept', 'reliance')

# Inputs of this many characters or more are rejected before tokenization.
# NOTE(review): presumably mirrors a 512-position limit of the underlying
# model; this counts characters, not tokens — confirm against SentTokenizer.
MAX_INPUT_CHARS = 512


def decode_labels(prediction, threshold=1/3):
    """Turn per-token class scores into (class_id, score) pairs for batch item 0.

    The decision rule is the original one: an entity class (1 = concept,
    2 = reliance) is chosen when its score is strictly above ``threshold``,
    otherwise the token is class 0 (nothing). Unlike the original sum-of-masks
    encoding, when BOTH entity classes clear the threshold the higher-scoring
    one wins (ties go to class 1) instead of producing an unmatched label 3
    that silently dropped the token from the output.

    Args:
        prediction: tensor indexed as ``prediction[batch, token, class]``;
            only batch item 0 is decoded. Assumed (TODO confirm) to be
            (batch, seq_len, 3) per-class probabilities from NERModel.
        threshold: minimum (exclusive) score for an entity class to be picked.

    Returns:
        list of ``(class_id, score)`` tuples, one per token, where ``score``
        is the score of the chosen class.
    """
    decoded = []
    for scores in prediction[0].tolist():
        # Entity classes that clear the threshold, resolved by highest score.
        candidates = [cls for cls in (1, 2) if scores[cls] > threshold]
        chosen = max(candidates, key=lambda cls: scores[cls]) if candidates else 0
        decoded.append((chosen, scores[chosen]))
    return decoded


if __name__ == '__main__':
    # Inference only: put dropout/batch-norm layers (if any) into eval mode.
    model.eval()

    print('=' * 20)

    while True:
        try:
            text = input('>>>')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session instead of crashing.
            print()
            break

        if not text or len(text) >= MAX_INPUT_CHARS:
            print('Sorry bro. I cannot do this. ')
            continue

        # The third value returned by the tokenizer is not used here.
        tokIdList, tokList, _ = tokenizer(text)

        print('\n')
        print(tokIdList)
        print(tokList)

        # No gradients are needed for interactive prediction.
        with tch.no_grad():
            prediction = model(tokIdList)

        for tok, (cls, score) in zip(tokList, decode_labels(prediction)):
            print(f'{tok!s:<25}', LABEL_NAMES[cls], score)
|