where pure knowledge is acquired by just reading
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

43 lines
1.3 KiB

from sympy import im
from model import NERModel, SentTokenizer
import torch as tch
# Use the first GPU when one is available, otherwise fall back to the CPU
# (a hard-coded 'cuda:0' raises on machines without CUDA).
device = tch.device('cuda:0' if tch.cuda.is_available() else 'cpu')
tokenizer = SentTokenizer().to(device)
model = NERModel().to(device)
print('loading model')
# map_location remaps tensors stored on another device (e.g. a GPU-saved
# checkpoint) onto the device selected above.
model.load_state_dict(tch.load('Saved/NER.model', map_location=device))
# This script only runs inference: switch the model to evaluation mode so
# training-only layers (dropout/batch-norm, if NERModel has any) behave
# deterministically.
model.eval()
print('ok')
# Human-readable names for the label ids produced by decode_labels().
LABEL_NAMES = ('nothing', 'concept', 'reliance')


def decode_labels(prediction, threshold=1/3):
    """Turn per-class scores into label ids 0 (nothing), 1 (concept), 2 (reliance).

    A token is labelled 'reliance' when its class-2 score exceeds
    ``threshold``, 'concept' when its class-1 score does, and 'nothing'
    otherwise. When BOTH class-1 and class-2 exceed the threshold, the larger
    of the two scores wins (ties go to 'reliance'). A plain weighted sum
    of the two masks would yield 3 here, which matches no label.

    Args:
        prediction: tensor whose last dimension holds the three class scores
            (assumed (batch, seq, 3) - TODO confirm against NERModel).
        threshold: minimum score for class 1 or class 2 to be selected.

    Returns:
        Integer (int64) tensor of label ids with the last dimension removed.
    """
    concept = prediction[..., 1] > threshold
    reliance = prediction[..., 2] > threshold
    # Resolve the concept/reliance overlap by score instead of summing to 3.
    reliance = reliance & (~concept | (prediction[..., 2] >= prediction[..., 1]))
    concept = concept & ~reliance
    return 2 * reliance.long() + concept.long()


if __name__ == '__main__':
    print('=' * 20)
    while True:
        try:
            text = input('>>>')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave the REPL cleanly instead of a traceback.
            print()
            break
        # NOTE(review): 512 is a character limit, presumably guarding the
        # model's maximum sequence length - confirm against SentTokenizer.
        if not text or len(text) >= 512:
            print('Sorry bro. I cannot do this. ')
            continue
        tokIdList, tokList, _ = tokenizer(text)
        print('\n')
        print(tokIdList)
        print(tokList)
        # Inference only: no autograd graph is needed.
        with tch.no_grad():
            prediction = model(tokIdList)
        label_ids = decode_labels(prediction)
        for i, pred in enumerate(label_ids[0].tolist()):
            # Token left-aligned in 25 columns, label name, then the score of
            # the chosen class at position i.
            print('%-25s' % (tokList[i],), LABEL_NAMES[pred],
                  prediction[..., i, pred].tolist())