懒人记时 代码仓库
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

104 lines
3.2 KiB

  1. # import re
  2. # import torch
  3. # import torch.nn as nn
  4. # import jieba
  5. # import pandas as pd
  6. # # from torchtext import data
  7. #
  8. # class LSTMNet(nn.Module):
  9. # def __init__(self,vocab_size,embedding_dim,hidden_dim,layer_dim,output_dim):
  10. # super(LSTMNet,self).__init__()
  11. # self.hidden_dim= hidden_dim
  12. # self.layer_dim = layer_dim
  13. # self.embedding = nn.Embedding(vocab_size,embedding_dim)
  14. # # LSTM+全连接
  15. # self.lstm = nn.LSTM(embedding_dim,hidden_dim,layer_dim,
  16. # batch_first=True)
  17. # self.fcl= nn.Linear(hidden_dim,output_dim)
  18. # def forward(self,x):
  19. # embeds = self.embedding(x)
  20. # r_out,(h_n,h_c)=self.lstm(embeds,None)
  21. # out = self.fcl(r_out[:,-1,:])
  22. # return out
  23. #
  24. # def Chinese_pre(text_data,stopwords):
  25. # # 字母转化为小写, 去掉数字
  26. # text_data = text_data.lower()
  27. # text_data = re.sub("\d+","",text_data)
  28. # # 分词,使用精确模式
  29. # text_data = list(jieba.cut(text_data,cut_all = False))
  30. # # 去除停用词和多余空格
  31. # text_data = [word.strip() for word in text_data if word not in stopwords]
  32. # # 处理后的词语使用空格连接为字符串
  33. # text_data = " ".join(text_data)
  34. # return text_data
  35. #
  36. # def TexttoLable(textdata):
  37. # # 将输入文本转为tensor
  38. # # 首先对文本进行分词
  39. # from nltk.corpus import stopwords
  40. # import nltk
  41. # nltk.download('stopwords')
  42. # words = stopwords.words('english')
  43. # stopwords = set()
  44. # with open("stop.txt",encoding="utf-8") as infile:
  45. # for line in infile:
  46. # line = line.rstrip('\n')
  47. # if line:
  48. # stopwords.add(line.lower())
  49. # for i in words:
  50. # stopwords.add(i)
  51. # textdata=Chinese_pre(textdata,stopwords)
  52. #
  53. # data1=[]
  54. # for i in range(128):
  55. # data1.append(textdata)
  56. # df = pd.DataFrame({'cutword':data1})
  57. #
  58. # df.to_csv("tmp.csv")
  59. #
  60. # mytokenize = lambda x:x.split()
  61. # from torchtext.legacy import data
  62. # TEXT = data.Field(sequential = True,tokenize = mytokenize,
  63. # include_lengths=True,use_vocab=True,
  64. # batch_first=True,fix_length=40)
  65. #
  66. # LABEL = data.Field(sequential =False,use_vocab=False,
  67. # pad_token=None,unk_token=None)
  68. # # 对所有读取的数据集的列进行处理
  69. # text_data_fields = [
  70. # ("labelcode",LABEL),
  71. # ("cutword",TEXT)
  72. # ]
  73. # # 读取数据
  74. # # 读取数据
  75. # traindata,valdata,testdata = data.TabularDataset.splits(
  76. # path="./",format="csv",train="tmp.csv",fields = text_data_fields,
  77. # validation = "tmp.csv",
  78. # test ="tmp.csv",skip_header=True
  79. # )
  80. #
  81. # em = testdata.examples[0]
  82. # TEXT.build_vocab(traindata,max_size=100,vectors=None)
  83. #
  84. # # 定义一个迭代器,将类似长度的示例一起批处理
  85. # BATCH_SIZE=128
  86. # test_iter = data.BucketIterator(testdata,batch_size=128)
  87. #
  88. #
  89. # vocab_size=len(TEXT.vocab)
  90. # embedding_dim=50
  91. # hidden_dim=256
  92. # layer_dim=1
  93. # output_dim=4
  94. # lstmmodel = LSTMNet(vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim)
  95. #
  96. # res=0
  97. # model = torch.load("model.pkl")
  98. # for step,batch in enumerate(test_iter):
  99. # textfinal = batch.cutword[0]
  100. # out = model(textfinal)
  101. # pre_lab = torch.argmax(out,1)
  102. # res = pre_lab[0]
  103. # print(res.numpy())
  104. #
  105. # TexttoLable("萝卜云服交流群等3个会话 ")