懒人记时 代码仓库
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

105 lines
3.2 KiB

# import re
# import torch
# import torch.nn as nn
# import jieba
# import pandas as pd
# # from torchtext import data
#
# class LSTMNet(nn.Module):
# def __init__(self,vocab_size,embedding_dim,hidden_dim,layer_dim,output_dim):
# super(LSTMNet,self).__init__()
# self.hidden_dim= hidden_dim
# self.layer_dim = layer_dim
# self.embedding = nn.Embedding(vocab_size,embedding_dim)
# # LSTM+全连接
# self.lstm = nn.LSTM(embedding_dim,hidden_dim,layer_dim,
# batch_first=True)
# self.fcl= nn.Linear(hidden_dim,output_dim)
# def forward(self,x):
# embeds = self.embedding(x)
# r_out,(h_n,h_c)=self.lstm(embeds,None)
# out = self.fcl(r_out[:,-1,:])
# return out
#
# def Chinese_pre(text_data,stopwords):
# # 字母转化为小写, 去掉数字
# text_data = text_data.lower()
# text_data = re.sub("\d+","",text_data)
# # 分词,使用精确模式
# text_data = list(jieba.cut(text_data,cut_all = False))
# # 去除停用词和多余空格
# text_data = [word.strip() for word in text_data if word not in stopwords]
# # 处理后的词语使用空格连接为字符串
# text_data = " ".join(text_data)
# return text_data
#
# def TexttoLable(textdata):
# # 将输入文本转为tensor
# # 首先对文本进行分词
# from nltk.corpus import stopwords
# import nltk
# nltk.download('stopwords')
# words = stopwords.words('english')
# stopwords = set()
# with open("stop.txt",encoding="utf-8") as infile:
# for line in infile:
# line = line.rstrip('\n')
# if line:
# stopwords.add(line.lower())
# for i in words:
# stopwords.add(i)
# textdata=Chinese_pre(textdata,stopwords)
#
# data1=[]
# for i in range(128):
# data1.append(textdata)
# df = pd.DataFrame({'cutword':data1})
#
# df.to_csv("tmp.csv")
#
# mytokenize = lambda x:x.split()
# from torchtext.legacy import data
# TEXT = data.Field(sequential = True,tokenize = mytokenize,
# include_lengths=True,use_vocab=True,
# batch_first=True,fix_length=40)
#
# LABEL = data.Field(sequential =False,use_vocab=False,
# pad_token=None,unk_token=None)
# # 对所有读取的数据集的列进行处理
# text_data_fields = [
# ("labelcode",LABEL),
# ("cutword",TEXT)
# ]
# # 读取数据
# traindata,valdata,testdata = data.TabularDataset.splits(
# path="./",format="csv",train="tmp.csv",fields = text_data_fields,
# validation = "tmp.csv",
# test ="tmp.csv",skip_header=True
# )
#
# em = testdata.examples[0]
# TEXT.build_vocab(traindata,max_size=100,vectors=None)
#
# # 定义一个迭代器,将类似长度的示例一起批处理
# BATCH_SIZE=128
# test_iter = data.BucketIterator(testdata,batch_size=128)
#
#
# vocab_size=len(TEXT.vocab)
# embedding_dim=50
# hidden_dim=256
# layer_dim=1
# output_dim=4
# lstmmodel = LSTMNet(vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim)
#
# res=0
# model = torch.load("model.pkl")
# for step,batch in enumerate(test_iter):
# textfinal = batch.cutword[0]
# out = model(textfinal)
# pre_lab = torch.argmax(out,1)
# res = pre_lab[0]
# print(res.numpy())
#
# TexttoLable("萝卜云服交流群等3个会话 ")