diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e62ccce
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,35 @@
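+# SQLAlchemy engine and ORM models for the couplet app: a users table plus
+# three couplet pools (predictcouplet, trainingcouplet, evaluatecouplet).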
+from sqlalchemy import create_engine
+from sqlalchemy import Column, Integer, TEXT, VARCHAR
+from sqlalchemy.ext.declarative import declarative_base
+
+engine = create_engine('mysql+pymysql://root:yyj0010YYJ@10.23.174.207/cloud', encoding="utf-8", echo=True)
+base = declarative_base()
+
+
+class User(base):
+    __tablename__ = 'users'
+    user_id = Column('user_id', VARCHAR(20), primary_key=True)
+    password = Column('password', TEXT, nullable=False)
+
+
+class predict(base):
+    __tablename__ = 'predictcouplet'
+    up = Column('up', VARCHAR(20), primary_key=True)
+    down = Column('down', VARCHAR(20), primary_key=True)
+
+
+class train(base):
+    __tablename__ = 'trainingcouplet'
+    up = Column('up', VARCHAR(20), primary_key=True)
+    down = Column('down', VARCHAR(20), primary_key=True)
+
+
+class evaluate(base):
+    __tablename__ = 'evaluatecouplet'
+    up = Column('up', VARCHAR(20), primary_key=True)
+    down = Column('down', VARCHAR(20), primary_key=True)
+    popular = Column('popular', Integer)
+
+
+base.metadata.create_all(engine)  # create the table schema
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..43515f0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,133 @@
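+# Flask front end for the couplet games: game1 generates a lower line with
+# the seq2seq model, game2 asks players to complete stored upper lines, and
+# game3 lets players vote on stored pairs.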
+from flask import Flask
+from flask import render_template
+from flask import request
+from flask import redirect, url_for
+from code1.model import Model
+import socket
+import random
+from __init__ import predict, train, evaluate
+
+# Get the local machine's IP; the templates and redirect links are built from it.
+ip = socket.gethostbyname(socket.gethostname())
+
+app = Flask(__name__)
+import user
+import couplet
+
+vocab_file = 'code1/data/vocabs'
+model_dir = 'code1/output'
+
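+# Build the seq2seq model in inference-only mode; weights are restored from
+# the checkpoint in model_dir.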
+m = Model(
+ None, None, None, None, vocab_file,
+ num_units=1024, layers=4, dropout=0.2,
+ batch_size=32, learning_rate=0.0001,
+ output_dir=model_dir,
+ restore_model=True, init_train=False, init_infer=True)
+
+
+@app.route("/")
+def index():
+ return render_template("homepage.html",ip=ip)
+
+@app.route("/homepage",methods=['GET'])
+def homepage():
+ return render_template("homepage.html",ip=ip)
+
+@app.route("/game1",methods=['GET','post'])
+def game1():
+    if request.method == 'GET':
+        return render_template("game1.html", ip=ip)
+    else:
+        upper = request.form.get("upper", "")
+        # The model expects space-separated characters; join the output back.
+        lower = m.infer(' '.join(upper))
+        lower = ''.join(lower.split(' '))
+        return render_template("game1_result.html", lower=lower, upper=upper, ip=ip)
+
+@app.route("/judge",methods=['GET','POST'])
+def judge():
+ result= request.form.get("result", "")
+ upper = request.form.get("upper")
+ lower = request.form.get("lower")
+ u = couplet.Couplet()
+ if(int(result)==1):
+ u.to_evaluation(up=upper,down=lower)
+ else:
+ u.to_predict(up=upper,down=lower)
+ return render_template("game1_back.html",ip=ip)
+
+@app.route("/game2",methods=['GET','post'])
+def game2():
+ if request.method == 'GET':
+ u = couplet.Couplet()
+ pre = u.db_session.query(predict).all()
+ up=[]
+ for i in pre:
+ up.append(i.up)
+ up=up[random.randint(0, len(up)-1)]
+ return render_template("game2.html",up=up,ip=ip)
+ else:
+ up = request.form.get("up")
+ down= request.form.get("down")
+ u = couplet.Couplet()
+ u.to_evaluation(up=up,down=down)
+ return render_template("game2_back.html",ip=ip)
+
+@app.route("/game3",methods=['GET','post'])
+def game3():
+ if request.method == 'GET':
+ u = couplet.Couplet()
+ pre = u.db_session.query(evaluate).all()
+ up=[]
+ down=[]
+ for i in pre:
+ up.append(i.up)
+ down.append(i.down)
+ temp=random.randint(0,len(up)-1)
+ up=up[temp]
+ down=down[temp]
+ return render_template("game3.html",up=up,down=down,ip=ip)
+ else:
+ result = request.form.get("result", "")
+ up = request.form.get("up")
+ down= request.form.get("down")
+ print(result,up,down)
+ if(result=='1'):
+ u = couplet.Couplet()
+ u.add_popular(up=up, down=down,popular=1)
+ else:
+ u = couplet.Couplet()
+ u.minus_popular(up=up, down=down, popular=1)
+ return render_template("game3_back.html",ip=ip)
+
+@app.route("/register", methods=['GET',"POST"])
+def register():
+ if request.method=='GET':
+ return render_template("register.html",ip=ip)
+ else:
+ user_id = request.form.get("user_id", "")
+ password = request.form.get("password", "")
+ u = user.Users()
+ code, message = u.register(user_id=user_id, password=password)
+ if code==200:
+ return """注册成功,点击登陆 """.format(ip)
+ else:
+ return """注册失败,点击重新注册 """.format(ip)
+
+
+@app.route("/login", methods=["GET","POST"])
+def login():
+ if request.method=='GET':
+ return render_template("login.html")
+ else:
+ user_id = request.form.get("user_id", "")
+ password = request.form.get("password", "")
+ u = user.Users()
+ code, message = u.login(user_id=user_id, password=password)
+ if code==200:
+ return redirect(url_for('homepage'))
+ else:
+ return """用户名或密码错误,点击重新登陆 """.format(ip)
+
+if __name__ == '__main__':
+    app.run(host="127.0.0.1", port=5000)
diff --git a/couplet.py b/couplet.py
new file mode 100644
index 0000000..e9dfc6d
--- /dev/null
+++ b/couplet.py
@@ -0,0 +1,73 @@
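+# Database layer for couplets: routes pairs into the predict, training and
+# evaluation pools and maintains popularity votes.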
+import session
+from __init__ import predict,train,evaluate
+
+class Couplet(session.ORMsession):
+
+    def __init__(self):
+        session.ORMsession.__init__(self)
+
+    def send_up(self, up: str):
+        # Placeholder stub; not used by the app.
+        return 'asda'
+
+    def to_train(self, up: str, down: str):
+        try:
+            couplet = train(up=up, down=down)
+            self.db_session.add(couplet)
+            self.db_session.commit()
+        except BaseException as e:
+            return 530, "{}".format(str(e))
+        return 200
+
+    def to_predict(self, up: str, down: str):
+        try:
+            couplet = predict(up=up, down=down)
+            self.db_session.add(couplet)
+            self.db_session.commit()
+        except BaseException as e:
+            return 530, "{}".format(str(e))
+        return 200
+
+    def to_evaluation(self, up: str, down: str):
+        try:
+            couplet = evaluate(up=up, down=down, popular=0)
+            self.db_session.add(couplet)
+            self.db_session.commit()
+        except BaseException as e:
+            return 530, "{}".format(str(e))
+        return 200
+
+    def add_popular(self, up: str, down: str, popular: int):
+        try:
+            self.db_session.query(evaluate).filter(evaluate.up == up, evaluate.down == down) \
+                .update({'popular': evaluate.popular + popular})
+            self.db_session.commit()
+        except BaseException as e:
+            return 530, "{}".format(str(e))
+        return 200
+
+    def minus_popular(self, up: str, down: str, popular: int):
+        try:
+            self.db_session.query(evaluate).filter(evaluate.up == up, evaluate.down == down) \
+                .update({'popular': evaluate.popular - popular})
+            self.db_session.commit()
+        except BaseException as e:
+            return 530, "{}".format(str(e))
+        return 200
+
+    def change_eva(self):
+        # Prune pairs voted to -10 or below, then promote pairs voted to 10 or
+        # above from the evaluation pool into the training pool.
+        try:
+            self.db_session.query(evaluate).filter(evaluate.popular <= -10).delete()
+            row = self.db_session.query(evaluate).filter(evaluate.popular >= 10).all()
+            up = []
+            down = []
+            for i in row:
+                up.append(i.up)
+                down.append(i.down)
+            self.db_session.query(evaluate).filter(evaluate.popular >= 10).delete()
+            self.db_session.commit()
+            for i in range(0, len(up)):
+                couplet = train(up=up[i], down=down[i])
+                self.db_session.add(couplet)
+            self.db_session.commit()
+        except BaseException as e:
+            return 530, "{}".format(str(e))
+        return 200
diff --git a/error.py b/error.py
new file mode 100644
index 0000000..1c6bbed
--- /dev/null
+++ b/error.py
@@ -0,0 +1,35 @@
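+# Numeric error codes and (code, message) helpers used by the user module.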
+error_code = {
+    401: "authorization fail.",
+    511: "non exist user id {}",
+    512: "exist user id {}",
+    513: "non exist store id {}",
+    514: "exist store id {}",
+    515: "non exist book id {}",
+    516: "exist book id {}",
+    517: "stock level low, book id {}",
+    518: "invalid order id {}",
+    519: "not sufficient funds, order id {}",
+    520: "non qualified book.",
+}
+
+
+def error_non_exist_user_id(user_id):
+    return 511, error_code[511].format(user_id)
+
+
+def error_exist_user_id(user_id):
+    return 512, error_code[512].format(user_id)
+
+
+def error_authorization_fail():
+    return 401, error_code[401]
+
+
+def error_and_message(code, message):
+    return code, message
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..8321008
--- /dev/null
+++ b/model.py
@@ -0,0 +1,234 @@
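+# TensorFlow 1.x wrapper around the seq2seq graph: separate graphs and
+# sessions for training, BLEU evaluation and inference, sharing one
+# checkpoint file under output_dir.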
+import tensorflow as tf
+import seq2seq
+import bleu
+import reader
+from os import path
+import random
+
+
+class Model:
+
+ def __init__(self, train_input_file, train_target_file,
+ test_input_file, test_target_file, vocab_file,
+ num_units, layers, dropout,
+ batch_size, learning_rate, output_dir,
+ save_step = 100, eval_step = 1000,
+ param_histogram=False, restore_model=False,
+ init_train=True, init_infer=False):
+ self.num_units = num_units
+ self.layers = layers
+ self.dropout = dropout
+ self.batch_size = batch_size
+ self.learning_rate = learning_rate
+ self.save_step = save_step
+ self.eval_step = eval_step
+ self.param_histogram = param_histogram
+ self.restore_model = restore_model
+ self.init_train = init_train
+ self.init_infer = init_infer
+
+ if init_train:
+ self.train_reader = reader.SeqReader(train_input_file,
+ train_target_file, vocab_file, batch_size)
+ self.train_reader.start()
+ self.train_data = self.train_reader.read()
+ self.eval_reader = reader.SeqReader(test_input_file, test_target_file,
+ vocab_file, batch_size)
+ self.eval_reader.start()
+ self.eval_data = self.eval_reader.read()
+
+        self.model_file = path.join(output_dir, 'model.ckpl')
+        self.log_writer = tf.summary.FileWriter(output_dir)
+
+ if init_train:
+ self._init_train()
+ self._init_eval()
+
+ if init_infer:
+ self.infer_vocabs = reader.read_vocab(vocab_file)
+ self.infer_vocab_indices = dict((c, i) for i, c in
+ enumerate(self.infer_vocabs))
+ self._init_infer()
+ self.reload_infer_model()
+
+
+ def gpu_session_config(self):
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+ return config
+
+
+ def _init_train(self):
+ self.train_graph = tf.Graph()
+ with self.train_graph.as_default():
+ self.train_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
+ self.train_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
+ self.train_target_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
+ self.train_target_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
+ output = seq2seq.seq2seq(self.train_in_seq, self.train_in_seq_len,
+ self.train_target_seq, self.train_target_seq_len,
+ len(self.train_reader.vocabs),
+ self.num_units, self.layers, self.dropout)
+ self.train_output = tf.argmax(tf.nn.softmax(output), 2)
+ self.loss = seq2seq.seq_loss(output, self.train_target_seq,
+ self.train_target_seq_len)
+ params = tf.trainable_variables()
+ gradients = tf.gradients(self.loss, params)
+ clipped_gradients, _ = tf.clip_by_global_norm(
+ gradients, 0.5)
+ self.train_op = tf.train.AdamOptimizer(
+ learning_rate=self.learning_rate
+ ).apply_gradients(zip(clipped_gradients,params))
+ if self.param_histogram:
+ for v in tf.trainable_variables():
+ tf.summary.histogram('train_' + v.name, v)
+ tf.summary.scalar('loss', self.loss)
+ self.train_summary = tf.summary.merge_all()
+ self.train_init = tf.global_variables_initializer()
+ self.train_saver = tf.train.Saver()
+ self.train_session = tf.Session(graph=self.train_graph,
+ config=self.gpu_session_config())
+
+
+ def _init_eval(self):
+ self.eval_graph = tf.Graph()
+ with self.eval_graph.as_default():
+ self.eval_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None])
+ self.eval_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size])
+ self.eval_output = seq2seq.seq2seq(self.eval_in_seq,
+ self.eval_in_seq_len, None, None,
+ len(self.eval_reader.vocabs),
+ self.num_units, self.layers, self.dropout)
+ if self.param_histogram:
+ for v in tf.trainable_variables():
+ tf.summary.histogram('eval_' + v.name, v)
+ self.eval_summary = tf.summary.merge_all()
+ self.eval_saver = tf.train.Saver()
+ self.eval_session = tf.Session(graph=self.eval_graph,
+ config=self.gpu_session_config())
+
+
+ def _init_infer(self):
+ self.infer_graph = tf.Graph()
+ with self.infer_graph.as_default():
+ self.infer_in_seq = tf.placeholder(tf.int32, shape=[1, None])
+ self.infer_in_seq_len = tf.placeholder(tf.int32, shape=[1])
+ self.infer_output = seq2seq.seq2seq(self.infer_in_seq,
+ self.infer_in_seq_len, None, None,
+ len(self.infer_vocabs),
+ self.num_units, self.layers, self.dropout)
+ self.infer_saver = tf.train.Saver()
+ self.infer_session = tf.Session(graph=self.infer_graph,
+ config=self.gpu_session_config())
+
+
+
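+    # Training loop: restore the checkpoint when restore_model is set,
+    # otherwise initialize variables; checkpoint every save_step steps and
+    # run a BLEU evaluation every eval_step steps.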
+ def train(self, epochs, start=0):
+        if not self.init_train:
+            raise Exception('Train graph is not initialized!')
+ with self.train_graph.as_default():
+ if path.isfile(self.model_file + '.meta') and self.restore_model:
+ print("Reloading model file before training.")
+ self.train_saver.restore(self.train_session, self.model_file)
+ else:
+ self.train_session.run(self.train_init)
+ total_loss = 0
+ for step in range(start, epochs):
+ data = next(self.train_data)
+ in_seq = data['in_seq']
+ in_seq_len = data['in_seq_len']
+ target_seq = data['target_seq']
+ target_seq_len = data['target_seq_len']
+ output, loss, train, summary = self.train_session.run(
+ [self.train_output, self.loss, self.train_op, self.train_summary],
+ feed_dict={
+ self.train_in_seq: in_seq,
+ self.train_in_seq_len: in_seq_len,
+ self.train_target_seq: target_seq,
+ self.train_target_seq_len: target_seq_len})
+ total_loss += loss
+                self.log_writer.add_summary(summary, step)
+ if step % self.save_step == 0:
+ self.train_saver.save(self.train_session, self.model_file)
+ print("Saving model. Step: %d, loss: %f" % (step,
+ total_loss / self.save_step))
+ # print sample output
+                    sid = random.randint(0, self.batch_size - 1)
+                    input_text = reader.decode_text(in_seq[sid],
+                            self.train_reader.vocabs)
+ output_text = reader.decode_text(output[sid],
+ self.train_reader.vocabs)
+ target_text = reader.decode_text(target_seq[sid],
+ self.train_reader.vocabs).split(' ')[1:]
+ target_text = ' '.join(target_text)
+ print('******************************')
+ print('src: ' + input_text)
+ print('output: ' + output_text)
+ print('target: ' + target_text)
+ if step % self.eval_step == 0:
+ bleu_score = self.eval(step)
+ print("Evaluate model. Step: %d, score: %f, loss: %f" % (
+ step, bleu_score, total_loss / self.save_step))
+ eval_summary = tf.Summary(value=[tf.Summary.Value(
+ tag='bleu', simple_value=bleu_score)])
+                    self.log_writer.add_summary(eval_summary, step)
+ if step % self.save_step == 0:
+ total_loss = 0
+
+
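+    # Evaluation: restore the latest checkpoint, decode the whole eval set
+    # and return a corpus BLEU score (x100), printing occasional samples.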
+ def eval(self, train_step):
+ with self.eval_graph.as_default():
+ self.eval_saver.restore(self.eval_session, self.model_file)
+ bleu_score = 0
+ target_results = []
+ output_results = []
+ for step in range(0, self.eval_reader.data_size):
+ data = next(self.eval_data)
+ in_seq = data['in_seq']
+ in_seq_len = data['in_seq_len']
+ target_seq = data['target_seq']
+ target_seq_len = data['target_seq_len']
+ outputs = self.eval_session.run(
+ self.eval_output,
+ feed_dict={
+ self.eval_in_seq: in_seq,
+ self.eval_in_seq_len: in_seq_len})
+ for i in range(len(outputs)):
+ output = outputs[i]
+ target = target_seq[i]
+ output_text = reader.decode_text(output,
+ self.eval_reader.vocabs).split(' ')
+ target_text = reader.decode_text(target[1:],
+ self.eval_reader.vocabs).split(' ')
+ prob = int(self.eval_reader.data_size * self.batch_size / 10)
+ target_results.append([target_text])
+ output_results.append(output_text)
+ if random.randint(1, prob) == 1:
+ print('====================')
+ input_text = reader.decode_text(in_seq[i],
+ self.eval_reader.vocabs)
+ print('src:' + input_text)
+ print('output: ' + ' '.join(output_text))
+ print('target: ' + ' '.join(target_text))
+ return bleu.compute_bleu(target_results, output_results)[0] * 100
+
+
+ def reload_infer_model(self):
+ with self.infer_graph.as_default():
+ self.infer_saver.restore(self.infer_session, self.model_file)
+
+
+ def infer(self, text):
+        if not self.init_infer:
+            raise Exception('Infer graph is not initialized!')
+ with self.infer_graph.as_default():
+ in_seq = reader.encode_text(text.split(' ') + ['',],
+ self.infer_vocab_indices)
+ in_seq_len = len(in_seq)
+ outputs = self.infer_session.run(self.infer_output,
+ feed_dict={
+ self.infer_in_seq: [in_seq],
+ self.infer_in_seq_len: [in_seq_len]})
+ output = outputs[0]
+ output_text = reader.decode_text(output, self.infer_vocabs)
+ return output_text
diff --git a/seq2seq.py b/seq2seq.py
new file mode 100644
index 0000000..988edcd
--- /dev/null
+++ b/seq2seq.py
@@ -0,0 +1,157 @@
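+# Graph-building pieces for the couplet model: a bidirectional LSTM encoder,
+# a Bahdanau-attention decoder and the masked sequence loss.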
+import tensorflow as tf
+from tensorflow.contrib import rnn
+
+from tensorflow.python.layers import core as layers_core
+
+
+def getLayeredCell(layer_size, num_units, input_keep_prob,
+ output_keep_prob=1.0):
+ return rnn.MultiRNNCell([rnn.DropoutWrapper(rnn.BasicLSTMCell(num_units),
+ input_keep_prob, output_keep_prob) for i in range(layer_size)])
+
+
+def bi_encoder(embed_input, in_seq_len, num_units, layer_size, input_keep_prob):
+ # encode input into a vector
+ bi_layer_size = int(layer_size / 2)
+ encode_cell_fw = getLayeredCell(bi_layer_size, num_units, input_keep_prob)
+ encode_cell_bw = getLayeredCell(bi_layer_size, num_units, input_keep_prob)
+ bi_encoder_output, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
+ cell_fw=encode_cell_fw,
+ cell_bw=encode_cell_bw,
+ inputs=embed_input,
+ sequence_length=in_seq_len,
+ dtype=embed_input.dtype,
+ time_major=False)
+
+ # concat encode output and state
+ encoder_output = tf.concat(bi_encoder_output, -1)
+ encoder_state = []
+ for layer_id in range(bi_layer_size):
+ encoder_state.append(bi_encoder_state[0][layer_id])
+ encoder_state.append(bi_encoder_state[1][layer_id])
+ encoder_state = tuple(encoder_state)
+ return encoder_output, encoder_state
+
+
+def attention_decoder_cell(encoder_output, in_seq_len, num_units, layer_size,
+ input_keep_prob):
+    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units,
+            encoder_output, in_seq_len, normalize=True)
+    # attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units,
+    #         encoder_output, in_seq_len, scale=True)
+    cell = getLayeredCell(layer_size, num_units, input_keep_prob)
+    cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism,
+            attention_layer_size=num_units)
+    return cell
+
+
+def decoder_projection(output, output_size):
+ return tf.layers.dense(output, output_size, activation=None,
+ use_bias=False, name='output_mlp')
+
+
+def train_decoder(encoder_output, in_seq_len, target_seq, target_seq_len,
+ encoder_state, num_units, layers, embedding, output_size,
+ input_keep_prob, projection_layer):
+ decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units,
+ layers, input_keep_prob)
+ batch_size = tf.shape(in_seq_len)[0]
+ init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
+ cell_state=encoder_state)
+ helper = tf.contrib.seq2seq.TrainingHelper(
+ target_seq, target_seq_len, time_major=False)
+ decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
+ init_state, output_layer=projection_layer)
+ outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
+ maximum_iterations=100)
+ return outputs.rnn_output
+
+
+def infer_decoder(encoder_output, in_seq_len, encoder_state, num_units, layers,
+ embedding, output_size, input_keep_prob, projection_layer):
+ decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units,
+ layers, input_keep_prob)
+
+ batch_size = tf.shape(in_seq_len)[0]
+ init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
+ cell_state=encoder_state)
+
+    # TODO: the start and end tokens are hard-coded
+ """
+ helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
+ embedding, tf.fill([batch_size], 0), 1)
+ decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
+ init_state, output_layer=projection_layer)
+ """
+
+ decoder = tf.contrib.seq2seq.BeamSearchDecoder(
+ cell=decoder_cell,
+ embedding=embedding,
+ start_tokens=tf.fill([batch_size], 0),
+ end_token=1,
+ initial_state=init_state,
+ beam_width=10,
+ output_layer=projection_layer,
+ length_penalty_weight=1.0)
+
+ outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
+ maximum_iterations=100)
+ return outputs.sample_id
+
+
+def seq2seq(in_seq, in_seq_len, target_seq, target_seq_len, vocab_size,
+ num_units, layers, dropout):
+ in_shape = tf.shape(in_seq)
+ batch_size = in_shape[0]
+
+    if target_seq is not None:
+        input_keep_prob = 1 - dropout
+    else:
+        input_keep_prob = 1
+
+ projection_layer = layers_core.Dense(vocab_size, use_bias=False)
+
+ # embedding input and target sequence
+ with tf.device('/cpu:0'):
+ embedding = tf.get_variable(
+ name='embedding',
+ shape=[vocab_size, num_units])
+ embed_input = tf.nn.embedding_lookup(embedding, in_seq, name='embed_input')
+
+ # encode and decode
+ encoder_output, encoder_state = bi_encoder(embed_input, in_seq_len,
+ num_units, layers, input_keep_prob)
+
+ decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units,
+ layers, input_keep_prob)
+ batch_size = tf.shape(in_seq_len)[0]
+ init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
+ cell_state=encoder_state)
+
+    if target_seq is not None:
+ embed_target = tf.nn.embedding_lookup(embedding, target_seq,
+ name='embed_target')
+ helper = tf.contrib.seq2seq.TrainingHelper(
+ embed_target, target_seq_len, time_major=False)
+ else:
+        # TODO: the start and end tokens are hard-coded
+ helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
+ embedding, tf.fill([batch_size], 0), 1)
+ decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
+ init_state, output_layer=projection_layer)
+ outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
+ maximum_iterations=100)
+    if target_seq is not None:
+        return outputs.rnn_output
+    else:
+        return outputs.sample_id
+
+
+def seq_loss(output, target, seq_len):
+ target = target[:, 1:]
+ cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output,
+ labels=target)
+ batch_size = tf.shape(target)[0]
+ loss_mask = tf.sequence_mask(seq_len, tf.shape(output)[1])
+ cost = cost * tf.to_float(loss_mask)
+ return tf.reduce_sum(cost) / tf.to_float(batch_size)
diff --git a/session.py b/session.py
new file mode 100644
index 0000000..993a80a
--- /dev/null
+++ b/session.py
@@ -0,0 +1,12 @@
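+# Shared SQLAlchemy session setup; model classes inherit ORMsession to get a
+# db_session bound to the cloud database.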
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+class ORMsession:
+    def __init__(self):
+        # Note: this opens a new engine and session per instance.
+        engine = create_engine('mysql+pymysql://root:yyj0010YYJ@10.23.174.207/cloud', encoding="utf-8", echo=True)
+        Session = sessionmaker(bind=engine)
+        self.db_session = Session()
diff --git a/user.py b/user.py
new file mode 100644
index 0000000..d7e7cb7
--- /dev/null
+++ b/user.py
@@ -0,0 +1,69 @@
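+# User accounts (register / unregister / login / change password) built on
+# the shared ORM session; failures map to codes in error.py.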
+from __init__ import User
+import logging
+import session
+import error
+import random
+import string
+import jwt
+import time
+
+# Users class: registration and login.
+class Users(session.ORMsession):
+    token_lifetime: int = 60
+
+ def __init__(self):
+ session.ORMsession.__init__(self)
+
+    def register(self, user_id: str, password: str) -> (int, str):
+        try:
+            if len(password) == 0:
+                return 202, "empty password"
+            user = User(user_id=user_id, password=password)
+            self.db_session.add(user)
+            self.db_session.commit()
+        except BaseException:
+            return error.error_exist_user_id(user_id)
+        return 200, "ok"
+
+    def unregister(self, user_id: str, password: str) -> (int, str):
+        try:
+            code, message = self.check_password(user_id, password)
+            if code != 200:
+                return code, message
+            user = self.db_session.query(User).filter(User.user_id == user_id).first()
+            self.db_session.delete(user)
+            self.db_session.commit()
+        except BaseException:
+            return error.error_authorization_fail()
+        return 200, "ok"
+
+    def check_password(self, user_id: str, password: str) -> (int, str):
+        user = self.db_session.query(User).filter(User.user_id == user_id).first()
+        if user is None:
+            return error.error_authorization_fail()
+        if password != user.password:
+            return error.error_authorization_fail()
+        return 200, "ok"
+
+    def login(self, user_id: str, password: str) -> (int, str):
+        try:
+            code, message = self.check_password(user_id, password)
+            if code != 200:
+                return code, message
+            self.db_session.commit()
+        except BaseException:
+            return error.error_authorization_fail()
+        return 200, "ok"
+
+    def change_password(self, user_id: str, old_password: str, new_password: str) -> (int, str):
+        try:
+            code, message = self.check_password(user_id, old_password)
+            if code != 200:
+                return code, message
+            self.db_session.query(User).filter(User.user_id == user_id)\
+                .update({'password': new_password})
+            self.db_session.commit()
+        except BaseException:
+            return error.error_authorization_fail()
+        return 200, "ok"
+