From e78f8e6946cced0d65c9dbc36ff765c22e18a171 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=B0=E9=92=B0=E6=9D=B0?= <10185501406@stu.ecnu.edu.cn> Date: Fri, 15 Jan 2021 22:04:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20''?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __init__.py | 35 +++++++++ app.py | 133 ++++++++++++++++++++++++++++++++++ couplet.py | 73 +++++++++++++++++++ error.py | 35 +++++++++ model.py | 234 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ seq2seq.py | 157 ++++++++++++++++++++++++++++++++++++++++ session.py | 12 ++++ user.py | 69 ++++++++++++++++++ 8 files changed, 748 insertions(+) create mode 100644 __init__.py create mode 100644 app.py create mode 100644 couplet.py create mode 100644 error.py create mode 100644 model.py create mode 100644 seq2seq.py create mode 100644 session.py create mode 100644 user.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e62ccce --- /dev/null +++ b/__init__.py @@ -0,0 +1,35 @@ +from sqlalchemy import create_engine +from sqlalchemy import Column, Integer, TEXT,VARCHAR +from sqlalchemy.ext.declarative import declarative_base + +engine = create_engine('mysql+pymysql://root:yyj0010YYJ@10.23.174.207/cloud',encoding="utf-8",echo=True) +base = declarative_base() + + +class User(base): + __tablename__ = 'users' + user_id = Column('user_id', VARCHAR(20), primary_key=True) + password = Column('password', TEXT, nullable=False) + + +class predict(base): + __tablename__= 'predictcouplet' + up = Column('up',VARCHAR(20),primary_key=True) + down= Column('down',VARCHAR(20),primary_key=True) + +class train(base): + __tablename__= 'trainingcouplet' + up = Column('up',VARCHAR(20),primary_key=True) + down= Column('down',VARCHAR(20),primary_key=True) + + +class evaluate(base): + __tablename__= 'evaluatecouplet' + up = Column('up',VARCHAR(20),primary_key=True) + down= Column('down',VARCHAR(20),primary_key=True) + popular=Column('popular',Integer) + + +base.metadata.create_all(engine) # 创建表结构 + + diff --git a/app.py b/app.py new file mode 100644 index 0000000..43515f0 --- /dev/null +++ b/app.py @@ -0,0 +1,133 @@ +from flask import Flask +from flask import render_template +from flask import request +from flask import redirect,url_for +from code1.model import Model +import session +import socket +import random +from __init__ import predict,train,evaluate +#获取本机电脑名 + + +app = Flask(__name__) +import user +import couplet + +vocab_file = 'code1/data/vocabs' +model_dir = 'code1/output' + +m = Model( + None, None, None, None, vocab_file, + num_units=1024, layers=4, dropout=0.2, + batch_size=32, learning_rate=0.0001, + output_dir=model_dir, + restore_model=True, init_train=False, init_infer=True) + + +@app.route("/") +def index(): + return render_template("homepage.html",ip=ip) + +@app.route("/homepage",methods=['GET']) +def homepage(): + return render_template("homepage.html",ip=ip) + +@app.route("/game1",methods=['GET','post']) +def game1(): + if request.method == 'GET': + return render_template("game1.html",ip=ip) + else: + upper = request.form.get("upper", "") + input=upper + lower=m.infer(' '.join(input)) + lower = ''.join(lower.split(' ')) + return render_template("game1_result.html",lower=lower,upper=upper,ip=ip) + +@app.route("/judge",methods=['GET','POST']) +def judge(): + result= request.form.get("result", "") + upper = request.form.get("upper") + lower = request.form.get("lower") + u = couplet.Couplet() + if(int(result)==1): + u.to_evaluation(up=upper,down=lower) + else: + u.to_predict(up=upper,down=lower) + return render_template("game1_back.html",ip=ip) + +@app.route("/game2",methods=['GET','post']) +def game2(): + if request.method == 'GET': + u = couplet.Couplet() + pre = u.db_session.query(predict).all() + up=[] + for i in pre: + up.append(i.up) + up=up[random.randint(0, len(up)-1)] + return render_template("game2.html",up=up,ip=ip) + else: + up = request.form.get("up") + down= request.form.get("down") + u = couplet.Couplet() + u.to_evaluation(up=up,down=down) + return render_template("game2_back.html",ip=ip) + +@app.route("/game3",methods=['GET','post']) +def game3(): + if request.method == 'GET': + u = couplet.Couplet() + pre = u.db_session.query(evaluate).all() + up=[] + down=[] + for i in pre: + up.append(i.up) + down.append(i.down) + temp=random.randint(0,len(up)-1) + up=up[temp] + down=down[temp] + return render_template("game3.html",up=up,down=down,ip=ip) + else: + result = request.form.get("result", "") + up = request.form.get("up") + down= request.form.get("down") + print(result,up,down) + if(result=='1'): + u = couplet.Couplet() + u.add_popular(up=up, down=down,popular=1) + else: + u = couplet.Couplet() + u.minus_popular(up=up, down=down, popular=1) + return render_template("game3_back.html",ip=ip) + +@app.route("/register", methods=['GET',"POST"]) +def register(): + if request.method=='GET': + return render_template("register.html",ip=ip) + else: + user_id = request.form.get("user_id", "") + password = request.form.get("password", "") + u = user.Users() + code, message = u.register(user_id=user_id, password=password) + if code==200: + return """注册成功,点击登陆 """.format(ip) + else: + return """注册失败,点击重新注册 """.format(ip) + + +@app.route("/login", methods=["GET","POST"]) +def login(): + if request.method=='GET': + return render_template("login.html") + else: + user_id = request.form.get("user_id", "") + password = request.form.get("password", "") + u = user.Users() + code, message = u.login(user_id=user_id, password=password) + if code==200: + return redirect(url_for('homepage')) + else: + return """用户名或密码错误,点击重新登陆 """.format(ip) + +if __name__ == '__main__': + app.run("127.0.0.1",port=5000) diff --git a/couplet.py b/couplet.py new file mode 100644 index 0000000..e9dfc6d --- /dev/null +++ b/couplet.py @@ -0,0 +1,73 @@ +import session +from __init__ import predict,train,evaluate + +class Couplet(session.ORMsession): + + def __init__(self): + session.ORMsession.__init__(self) + + def send_up(self,up:str): + return 'asda' + + def to_train(self,up:str,down:str): + try: + couplet=train(up=up,down=down) + self.db_session.add(couplet) + self.db_session.commit() + except BaseException as e: + return 530, "{}".format(str(e)) + return 200 + + def to_predict(self,up:str,down:str): + try: + couplet = predict(up=up,down=down) + self.db_session.add(couplet) + self.db_session.commit() + except BaseException as e: + return 530, "{}".format(str(e)) + return 200 + + def to_evaluation(self,up:str,down:str): + try: + couplet = evaluate(up=up, down=down,popular=0) + self.db_session.add(couplet) + self.db_session.commit() + self.db_session.rollback() + except BaseException as e: + return 530, "{}".format(str(e)) + return 200 + + def add_popular(self,up:str,down:str,popular:int): + try: + self.db_session.query(evaluate).filter(evaluate.up==up,evaluate.down==down).update({'evaluate.popular':evaluate.popular+popular}) + self.db_session.commit() + except BaseException as e: + return 530, "{}".format(str(e)) + return 200 + + def minus_popular(self,up:str,down:str,popular:int): + try: + couplet = evaluate() + self.db_session.query(evaluate).filter(evaluate.up == up, evaluate.down == down).update( + {'evaluate.popular': evaluate.popular - popular}) + self.db_session.commit() + except BaseException as e: + return 530, "{}".format(str(e)) + return 200 + + def change_eva(self): + try: + self.db_session.query(evaluate).filter(evaluate.popular<=-10).delete() + row=self.db_session.query(evaluate).filter(evaluate.popular>=0).all() + up=[] + down=[] + for i in row: + up.append(i.up) + down.append() + self.db_session.query(evaluate).filter(evaluate.popular >=10).delete() + self.db_session.commit() + for i in range(0,len(up)): + couplet=train(up=up[i],down=down[i]) + self.db_session.add(couplet) + except BaseException as e: + return 530, "{}".format(str(e)) diff --git a/error.py b/error.py new file mode 100644 index 0000000..1c6bbed --- /dev/null +++ b/error.py @@ -0,0 +1,35 @@ +error_code = { + 401: "authorization fail.", + 511: "non exist user id {}", + 512: "exist user id {}", + 513: "non exist store id {}", + 514: "exist store id {}", + 515: "non exist book id {}", + 516: "exist book id {}", + 517: "stock level low, book id {}", + 518: "invalid order id {}", + 519: "not sufficient funds, order id {}", + 520: "non qualified book.", + 521: "", + 522: "", + 523: "", + 524: "", + 525: "", + 526: "", + 527: "", + 528: "", +} + + +def error_non_exist_user_id(user_id): + return 511, error_code[511].format(user_id) + + +def error_exist_user_id(user_id): + return 512, error_code[512].format(user_id) + +def error_authorization_fail(): + return 401, error_code[401] + +def error_and_message(code, message): + return code, message diff --git a/model.py b/model.py new file mode 100644 index 0000000..8321008 --- /dev/null +++ b/model.py @@ -0,0 +1,234 @@ +import tensorflow as tf +import seq2seq +import bleu +import reader +from os import path +import random + + +class Model(): + + def __init__(self, train_input_file, train_target_file, + test_input_file, test_target_file, vocab_file, + num_units, layers, dropout, + batch_size, learning_rate, output_dir, + save_step = 100, eval_step = 1000, + param_histogram=False, restore_model=False, + init_train=True, init_infer=False): + self.num_units = num_units + self.layers = layers + self.dropout = dropout + self.batch_size = batch_size + self.learning_rate = learning_rate + self.save_step = save_step + self.eval_step = eval_step + self.param_histogram = param_histogram + self.restore_model = restore_model + self.init_train = init_train + self.init_infer = init_infer + + if init_train: + self.train_reader = reader.SeqReader(train_input_file, + train_target_file, vocab_file, batch_size) + self.train_reader.start() + self.train_data = self.train_reader.read() + self.eval_reader = reader.SeqReader(test_input_file, test_target_file, + vocab_file, batch_size) + self.eval_reader.start() + self.eval_data = self.eval_reader.read() + + self.model_file = path.join(output_dir, 'model.ckpl') + self.log_writter = tf.summary.FileWriter(output_dir) + + if init_train: + self._init_train() + self._init_eval() + + if init_infer: + self.infer_vocabs = reader.read_vocab(vocab_file) + self.infer_vocab_indices = dict((c, i) for i, c in + enumerate(self.infer_vocabs)) + self._init_infer() + self.reload_infer_model() + + + def gpu_session_config(self): + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + return config + + + def _init_train(self): + self.train_graph = tf.Graph() + with self.train_graph.as_default(): + self.train_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None]) + self.train_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size]) + self.train_target_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None]) + self.train_target_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size]) + output = seq2seq.seq2seq(self.train_in_seq, self.train_in_seq_len, + self.train_target_seq, self.train_target_seq_len, + len(self.train_reader.vocabs), + self.num_units, self.layers, self.dropout) + self.train_output = tf.argmax(tf.nn.softmax(output), 2) + self.loss = seq2seq.seq_loss(output, self.train_target_seq, + self.train_target_seq_len) + params = tf.trainable_variables() + gradients = tf.gradients(self.loss, params) + clipped_gradients, _ = tf.clip_by_global_norm( + gradients, 0.5) + self.train_op = tf.train.AdamOptimizer( + learning_rate=self.learning_rate + ).apply_gradients(zip(clipped_gradients,params)) + if self.param_histogram: + for v in tf.trainable_variables(): + tf.summary.histogram('train_' + v.name, v) + tf.summary.scalar('loss', self.loss) + self.train_summary = tf.summary.merge_all() + self.train_init = tf.global_variables_initializer() + self.train_saver = tf.train.Saver() + self.train_session = tf.Session(graph=self.train_graph, + config=self.gpu_session_config()) + + + def _init_eval(self): + self.eval_graph = tf.Graph() + with self.eval_graph.as_default(): + self.eval_in_seq = tf.placeholder(tf.int32, shape=[self.batch_size, None]) + self.eval_in_seq_len = tf.placeholder(tf.int32, shape=[self.batch_size]) + self.eval_output = seq2seq.seq2seq(self.eval_in_seq, + self.eval_in_seq_len, None, None, + len(self.eval_reader.vocabs), + self.num_units, self.layers, self.dropout) + if self.param_histogram: + for v in tf.trainable_variables(): + tf.summary.histogram('eval_' + v.name, v) + self.eval_summary = tf.summary.merge_all() + self.eval_saver = tf.train.Saver() + self.eval_session = tf.Session(graph=self.eval_graph, + config=self.gpu_session_config()) + + + def _init_infer(self): + self.infer_graph = tf.Graph() + with self.infer_graph.as_default(): + self.infer_in_seq = tf.placeholder(tf.int32, shape=[1, None]) + self.infer_in_seq_len = tf.placeholder(tf.int32, shape=[1]) + self.infer_output = seq2seq.seq2seq(self.infer_in_seq, + self.infer_in_seq_len, None, None, + len(self.infer_vocabs), + self.num_units, self.layers, self.dropout) + self.infer_saver = tf.train.Saver() + self.infer_session = tf.Session(graph=self.infer_graph, + config=self.gpu_session_config()) + + + + def train(self, epochs, start=0): + if not self.init_train: + raise Exception('Train graph is not inited!') + with self.train_graph.as_default(): + if path.isfile(self.model_file + '.meta') and self.restore_model: + print("Reloading model file before training.") + self.train_saver.restore(self.train_session, self.model_file) + else: + self.train_session.run(self.train_init) + total_loss = 0 + for step in range(start, epochs): + data = next(self.train_data) + in_seq = data['in_seq'] + in_seq_len = data['in_seq_len'] + target_seq = data['target_seq'] + target_seq_len = data['target_seq_len'] + output, loss, train, summary = self.train_session.run( + [self.train_output, self.loss, self.train_op, self.train_summary], + feed_dict={ + self.train_in_seq: in_seq, + self.train_in_seq_len: in_seq_len, + self.train_target_seq: target_seq, + self.train_target_seq_len: target_seq_len}) + total_loss += loss + self.log_writter.add_summary(summary, step) + if step % self.save_step == 0: + self.train_saver.save(self.train_session, self.model_file) + print("Saving model. Step: %d, loss: %f" % (step, + total_loss / self.save_step)) + # print sample output + sid = random.randint(0, self.batch_size-1) + input_text = reader.decode_text(in_seq[sid], + self.eval_reader.vocabs) + output_text = reader.decode_text(output[sid], + self.train_reader.vocabs) + target_text = reader.decode_text(target_seq[sid], + self.train_reader.vocabs).split(' ')[1:] + target_text = ' '.join(target_text) + print('******************************') + print('src: ' + input_text) + print('output: ' + output_text) + print('target: ' + target_text) + if step % self.eval_step == 0: + bleu_score = self.eval(step) + print("Evaluate model. Step: %d, score: %f, loss: %f" % ( + step, bleu_score, total_loss / self.save_step)) + eval_summary = tf.Summary(value=[tf.Summary.Value( + tag='bleu', simple_value=bleu_score)]) + self.log_writter.add_summary(eval_summary, step) + if step % self.save_step == 0: + total_loss = 0 + + + def eval(self, train_step): + with self.eval_graph.as_default(): + self.eval_saver.restore(self.eval_session, self.model_file) + bleu_score = 0 + target_results = [] + output_results = [] + for step in range(0, self.eval_reader.data_size): + data = next(self.eval_data) + in_seq = data['in_seq'] + in_seq_len = data['in_seq_len'] + target_seq = data['target_seq'] + target_seq_len = data['target_seq_len'] + outputs = self.eval_session.run( + self.eval_output, + feed_dict={ + self.eval_in_seq: in_seq, + self.eval_in_seq_len: in_seq_len}) + for i in range(len(outputs)): + output = outputs[i] + target = target_seq[i] + output_text = reader.decode_text(output, + self.eval_reader.vocabs).split(' ') + target_text = reader.decode_text(target[1:], + self.eval_reader.vocabs).split(' ') + prob = int(self.eval_reader.data_size * self.batch_size / 10) + target_results.append([target_text]) + output_results.append(output_text) + if random.randint(1, prob) == 1: + print('====================') + input_text = reader.decode_text(in_seq[i], + self.eval_reader.vocabs) + print('src:' + input_text) + print('output: ' + ' '.join(output_text)) + print('target: ' + ' '.join(target_text)) + return bleu.compute_bleu(target_results, output_results)[0] * 100 + + + def reload_infer_model(self): + with self.infer_graph.as_default(): + self.infer_saver.restore(self.infer_session, self.model_file) + + + def infer(self, text): + if not self.init_infer: + raise Exception('Infer graph is not inited!') + with self.infer_graph.as_default(): + in_seq = reader.encode_text(text.split(' ') + ['',], + self.infer_vocab_indices) + in_seq_len = len(in_seq) + outputs = self.infer_session.run(self.infer_output, + feed_dict={ + self.infer_in_seq: [in_seq], + self.infer_in_seq_len: [in_seq_len]}) + output = outputs[0] + output_text = reader.decode_text(output, self.infer_vocabs) + return output_text diff --git a/seq2seq.py b/seq2seq.py new file mode 100644 index 0000000..988edcd --- /dev/null +++ b/seq2seq.py @@ -0,0 +1,157 @@ +import tensorflow as tf +from tensorflow.contrib import rnn + +from tensorflow.python.layers import core as layers_core + + +def getLayeredCell(layer_size, num_units, input_keep_prob, + output_keep_prob=1.0): + return rnn.MultiRNNCell([rnn.DropoutWrapper(rnn.BasicLSTMCell(num_units), + input_keep_prob, output_keep_prob) for i in range(layer_size)]) + + +def bi_encoder(embed_input, in_seq_len, num_units, layer_size, input_keep_prob): + # encode input into a vector + bi_layer_size = int(layer_size / 2) + encode_cell_fw = getLayeredCell(bi_layer_size, num_units, input_keep_prob) + encode_cell_bw = getLayeredCell(bi_layer_size, num_units, input_keep_prob) + bi_encoder_output, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn( + cell_fw=encode_cell_fw, + cell_bw=encode_cell_bw, + inputs=embed_input, + sequence_length=in_seq_len, + dtype=embed_input.dtype, + time_major=False) + + # concat encode output and state + encoder_output = tf.concat(bi_encoder_output, -1) + encoder_state = [] + for layer_id in range(bi_layer_size): + encoder_state.append(bi_encoder_state[0][layer_id]) + encoder_state.append(bi_encoder_state[1][layer_id]) + encoder_state = tuple(encoder_state) + return encoder_output, encoder_state + + +def attention_decoder_cell(encoder_output, in_seq_len, num_units, layer_size, + input_keep_prob): + attention_mechanim = tf.contrib.seq2seq.BahdanauAttention(num_units, + encoder_output, in_seq_len, normalize=True) + # attention_mechanim = tf.contrib.seq2seq.LuongAttention(num_units, + # encoder_output, in_seq_len, scale = True) + cell = getLayeredCell(layer_size, num_units, input_keep_prob) + cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanim, + attention_layer_size=num_units) + return cell + + +def decoder_projection(output, output_size): + return tf.layers.dense(output, output_size, activation=None, + use_bias=False, name='output_mlp') + + +def train_decoder(encoder_output, in_seq_len, target_seq, target_seq_len, + encoder_state, num_units, layers, embedding, output_size, + input_keep_prob, projection_layer): + decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units, + layers, input_keep_prob) + batch_size = tf.shape(in_seq_len)[0] + init_state = decoder_cell.zero_state(batch_size, tf.float32).clone( + cell_state=encoder_state) + helper = tf.contrib.seq2seq.TrainingHelper( + target_seq, target_seq_len, time_major=False) + decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, + init_state, output_layer=projection_layer) + outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, + maximum_iterations=100) + return outputs.rnn_output + + +def infer_decoder(encoder_output, in_seq_len, encoder_state, num_units, layers, + embedding, output_size, input_keep_prob, projection_layer): + decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units, + layers, input_keep_prob) + + batch_size = tf.shape(in_seq_len)[0] + init_state = decoder_cell.zero_state(batch_size, tf.float32).clone( + cell_state=encoder_state) + + # TODO: start tokens and end tokens are hard code + """ + helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( + embedding, tf.fill([batch_size], 0), 1) + decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, + init_state, output_layer=projection_layer) + """ + + decoder = tf.contrib.seq2seq.BeamSearchDecoder( + cell=decoder_cell, + embedding=embedding, + start_tokens=tf.fill([batch_size], 0), + end_token=1, + initial_state=init_state, + beam_width=10, + output_layer=projection_layer, + length_penalty_weight=1.0) + + outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, + maximum_iterations=100) + return outputs.sample_id + + +def seq2seq(in_seq, in_seq_len, target_seq, target_seq_len, vocab_size, + num_units, layers, dropout): + in_shape = tf.shape(in_seq) + batch_size = in_shape[0] + + if target_seq != None: + input_keep_prob = 1 - dropout + else: + input_keep_prob = 1 + + projection_layer = layers_core.Dense(vocab_size, use_bias=False) + + # embedding input and target sequence + with tf.device('/cpu:0'): + embedding = tf.get_variable( + name='embedding', + shape=[vocab_size, num_units]) + embed_input = tf.nn.embedding_lookup(embedding, in_seq, name='embed_input') + + # encode and decode + encoder_output, encoder_state = bi_encoder(embed_input, in_seq_len, + num_units, layers, input_keep_prob) + + decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units, + layers, input_keep_prob) + batch_size = tf.shape(in_seq_len)[0] + init_state = decoder_cell.zero_state(batch_size, tf.float32).clone( + cell_state=encoder_state) + + if target_seq != None: + embed_target = tf.nn.embedding_lookup(embedding, target_seq, + name='embed_target') + helper = tf.contrib.seq2seq.TrainingHelper( + embed_target, target_seq_len, time_major=False) + else: + # TODO: start tokens and end tokens are hard code + helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( + embedding, tf.fill([batch_size], 0), 1) + decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, + init_state, output_layer=projection_layer) + outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, + maximum_iterations=100) + if target_seq != None: + return outputs.rnn_output + else: + return outputs.sample_id + + +def seq_loss(output, target, seq_len): + target = target[:, 1:] + cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, + labels=target) + batch_size = tf.shape(target)[0] + loss_mask = tf.sequence_mask(seq_len, tf.shape(output)[1]) + cost = cost * tf.to_float(loss_mask) + return tf.reduce_sum(cost) / tf.to_float(batch_size) diff --git a/session.py b/session.py new file mode 100644 index 0000000..993a80a --- /dev/null +++ b/session.py @@ -0,0 +1,12 @@ +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine +from sqlalchemy import or_ +from sqlalchemy import Column,Integer,String + +class ORMsession: + def __init__(self): + engine = create_engine('mysql+pymysql://root:yyj0010YYJ@10.23.174.207/cloud',encoding="utf-8", echo=True) + Session=sessionmaker(engine) + self.db_session = Session() diff --git a/user.py b/user.py new file mode 100644 index 0000000..d7e7cb7 --- /dev/null +++ b/user.py @@ -0,0 +1,69 @@ +from __init__ import User +import logging +import session +import error +import random +import string +import jwt +import time + +#Users类,包括注册登录 +class Users(session.ORMsession): + token_lifetime: int = 60 + + def __init__(self): + session.ORMsession.__init__(self) + + def register(self,user_id:str,password:str) -> (int,str): + try: + if(len(password)==0): + return 202,"error" + user=User(user_id=user_id,password=password) + self.db_session.add(user) + self.db_session.commit() + except: + return error.error_exist_user_id(user_id) + return 200,"ok" + + def unregister(self, user_id: str, password: str) -> (int, str): + try: + code, message = self.check_password(user_id, password) + if code != 200: + return code, message + user = self.db_session.query(User).filter(User.user_id==user_id).first() + self.db_session.delete(user) + self.db_session.commit() + except: + return error.error_authorization_fail() + return 200, "ok" + + def check_password(self, user_id: str, password: str) -> (int, str): + user = self.db_session.query(User).filter(User.user_id==user_id).first() + if user is None: + return error.error_authorization_fail() + if password != user.password: + return error.error_authorization_fail() + return 200, "ok" + + def login(self, user_id: str, password: str) -> (int, str, str): + try: + code, message = self.check_password(user_id, password) + if code != 200: + return code, message + self.db_session.commit() + except: + return error.error_authorization_fail() + return 200, "ok" + + def change_password(self, user_id: str, old_password: str, new_password: str) -> (int,str): + try: + code, message = self.check_password(user_id, old_password) + if code != 200: + return code, message + self.db_session.query(User).filter(User.user_id== user_id)\ + .update({'password': new_password}) + self.db_session.commit() + except: + return error.error_authorization_fail() + return 200, "ok" +