import re

import jieba
import pandas as pd
import torch
import torch.nn as nn
from torchtext.legacy import data  # legacy torchtext API (torchtext <= 0.11)


class LSTMNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # LSTM followed by a fully connected classification layer
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, layer_dim,
                            batch_first=True)
        self.fcl = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x: (batch, seq_len) token ids -> embeddings of shape (batch, seq_len, embedding_dim)
        embeds = self.embedding(x)
        # passing None initialises the hidden and cell states to zeros
        r_out, (h_n, h_c) = self.lstm(embeds, None)
        # classify from the LSTM output at the last time step
        out = self.fcl(r_out[:, -1, :])
        return out
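

# A minimal shape check for the network above (my addition; the sizes are
# illustrative, not from the original script): a batch of 2 sequences of 40
# token ids should map to one logit vector of size output_dim per sequence.
# Call _shape_check() manually to verify.
def _shape_check():
    net = LSTMNet(vocab_size=100, embedding_dim=50, hidden_dim=256,
                  layer_dim=1, output_dim=4)
    x = torch.randint(0, 100, (2, 40))
    assert net(x).shape == (2, 4)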


def Chinese_pre(text_data, stopwords):
    # lower-case any letters and strip digits
    text_data = text_data.lower()
    text_data = re.sub(r"\d+", "", text_data)
    # segment with jieba in accurate (non-full) mode
    text_data = list(jieba.cut(text_data, cut_all=False))
    # drop stopwords, surrounding whitespace and empty tokens
    text_data = [word.strip() for word in text_data
                 if word not in stopwords and word.strip()]
    # join the processed tokens into a space-separated string
    text_data = " ".join(text_data)
    return text_data
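

# For illustration (my addition, not from the original script): with "的" as
# the only stopword, jieba should segment "今天的天气123" into 今天/的/天气
# after the digits are stripped, so
#     Chinese_pre("今天的天气123", {"的"})
# would return "今天 天气".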


def TexttoLable(textdata):
    # turn the input text into a prediction: clean and segment it first,
    # then wrap it in a dataset and run the saved model on one batch
    from nltk.corpus import stopwords as nltk_stopwords
    import nltk
    nltk.download('stopwords')
    words = nltk_stopwords.words('english')
    # merge the stopwords from a local file with NLTK's English list
    stop_set = set()
    with open("stop.txt", encoding="utf-8") as infile:
        for line in infile:
            line = line.rstrip('\n')
            if line:
                stop_set.add(line.lower())
    stop_set.update(words)
    textdata = Chinese_pre(textdata, stop_set)

    # replicate the single input 128 times so the BucketIterator below can
    # fill one complete batch
    data1 = [textdata] * 128
    df = pd.DataFrame({'cutword': data1})
    # to_csv also writes the row index, which doubles as a dummy labelcode
    # column for the field mapping below
    df.to_csv("tmp.csv")

    # tokenise on whitespace (Chinese_pre already joined the tokens with spaces)
    mytokenize = lambda x: x.split()
    TEXT = data.Field(sequential=True, tokenize=mytokenize,
                      include_lengths=True, use_vocab=True,
                      batch_first=True, fix_length=40)
    LABEL = data.Field(sequential=False, use_vocab=False,
                       pad_token=None, unk_token=None)
    # map each column of the CSV onto a processing field
    text_data_fields = [
        ("labelcode", LABEL),
        ("cutword", TEXT)
    ]
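    # Note (my addition): with fix_length=40 every example is padded with
    # TEXT.pad_token ("<pad>" by default) up to 40 tokens, or truncated to 40
    # if longer, so all batches share one fixed width.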

    # read the data; tmp.csv is reused for all three splits because only the
    # test split is actually consumed here
    traindata, valdata, testdata = data.TabularDataset.splits(
        path="./", format="csv", train="tmp.csv", fields=text_data_fields,
        validation="tmp.csv",
        test="tmp.csv", skip_header=True
    )

    em = testdata.examples[0]  # first parsed example, kept for inspection
    # build the vocabulary from the generated data; for the embedding lookups
    # to be meaningful these indices must match the vocabulary used when the
    # checkpoint was trained
    TEXT.build_vocab(traindata, max_size=100, vectors=None)
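    # Note (my addition): the legacy vocab reserves low indices for specials,
    # by default stoi["<unk>"] == 0 and stoi["<pad>"] == 1, and any token not
    # kept under max_size=100 falls back to the <unk> index.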

    # define an iterator that batches examples of similar length together
    BATCH_SIZE = 128
    test_iter = data.BucketIterator(testdata, batch_size=BATCH_SIZE)

    # hyperparameters of the architecture the checkpoint was trained with;
    # torch.load below restores the full pickled module, so this fresh
    # instance only documents the expected configuration
    vocab_size = len(TEXT.vocab)
    embedding_dim = 50
    hidden_dim = 256
    layer_dim = 1
    output_dim = 4
    lstmmodel = LSTMNet(vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim)
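    # If the checkpoint had been saved as a state_dict rather than a full
    # pickled module (my note; this script saves and loads the whole module),
    # it would instead be restored into the fresh instance with:
    #     lstmmodel.load_state_dict(torch.load("model.pkl", map_location="cpu"))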

    res = 0
    model = torch.load("model.pkl", map_location="cpu")
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(test_iter):
            # include_lengths=True makes batch.cutword a (tensor, lengths) pair
            textfinal = batch.cutword[0]
            out = model(textfinal)
            # the predicted class is the index of the largest logit
            pre_lab = torch.argmax(out, 1)
            res = pre_lab[0]
    print(res.numpy())


if __name__ == "__main__":
    TexttoLable("萝卜云服交流群等3个会话 ")