NoteOnMe博客平台搭建
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

277 lines
10 KiB

# coding=utf-8
import os
import shutil
import sys
import time
import cv2
import numpy as np
import tensorflow as tf
from main import preprocess
import json
import locale
locale.setlocale(locale.LC_ALL, 'C')
from scipy.misc import imread
#current_directory = os.path.dirname(os.path.abspath(__file__))
#root_path = os.path.abspath(os.path.dirname(current_directory) + os.path.sep + ".")
#sys.path.append(sys.path.append(os.getcwd()))
sys.path.append(os.getcwd())
from nets import model_train as ctpnmodel
from utils.rpn_msr.proposal_layer import proposal_layer
from utils.text_connector.detectors import TextDetector
from scipy.misc import imread
import os
from PIL import Image
from model.img2seq import Img2SeqModel
from model.utils.general import Config, run
from model.utils.text import Vocab
from model.utils.image import greyscale,predictsize
# Command-line flags. They are parsed when tf.app.run() starts the program
# and are read through the FLAGS object below.
tf.app.flags.DEFINE_string(
    'test_data_path', '/app/image/1.png',
    'Path of the input image to process.')
tf.app.flags.DEFINE_string(
    'output_path', '/app/im2latex_master/results/predict/',
    'Output directory. It is deleted and recreated on every run.')
tf.app.flags.DEFINE_string(
    'gpu', '0',
    'Value exported as CUDA_VISIBLE_DEVICES.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', '/app/im2latex_master/checkpoints_mlt/',
    'Directory containing the text-detector checkpoint.')
# The default is a real integer rather than the string '2' so that the
# declared flag type and its default agree.
# NOTE(review): no function in this module reads FLAGS.language - confirm
# whether it is still needed.
tf.app.flags.DEFINE_integer(
    'language', 2,
    'Language id.')
FLAGS = tf.app.flags.FLAGS
def get_images():
    """Collect image files found below ``FLAGS.test_data_path``.

    The directory tree is walked recursively and every file whose name ends
    with one of the accepted suffixes is kept. The number of matches is
    printed before returning.

    Returns:
        list: paths (directory joined with file name) of the matching files.
    """
    # str.endswith accepts a tuple, which replaces the per-suffix inner loop.
    suffixes = ('jpg', 'png', 'jpeg', 'JPG')
    found = []
    for folder, _subdirs, names in os.walk(FLAGS.test_data_path):
        found.extend(
            os.path.join(folder, name)
            for name in names
            if name.endswith(suffixes)
        )
    print('Find {} images'.format(len(found)))
    return found
def resize_image(img):
    """Rescale *img* for the text detector.

    The short side is scaled to 600 pixels unless that would push the long
    side above 1200 pixels, in which case the long side is scaled to 1200
    instead. Both target dimensions are then rounded UP to a multiple of 16.

    Args:
        img: image array whose first two dimensions are (height, width).

    Returns:
        tuple: ``(resized_image, (h_ratio, w_ratio))`` where each ratio is
        ``new_size / original_size`` for that axis.
    """
    height, width = img.shape[0], img.shape[1]
    short_side = min(height, width)
    long_side = max(height, width)

    scale = float(600) / float(short_side)
    if np.round(scale * long_side) > 1200:
        scale = float(1200) / float(long_side)

    def _round_up_16(value):
        # The original test was `value // 16 == 0`, which is only true for
        # values below 16: exact multiples of 16 were enlarged by a further
        # 16 pixels and tiny values were not rounded at all. Ceil-divide
        # instead, and never return 0 so cv2.resize gets a valid size.
        return max(16, (value + 15) // 16 * 16)

    new_h = _round_up_16(int(height * scale))
    new_w = _round_up_16(int(width * scale))

    re_im = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # float() keeps the ratios fractional under Python 2 integer division too.
    return re_im, (new_h / float(height), new_w / float(width))
def get_box():
    """Detect text boxes in ``FLAGS.test_data_path`` and display them.

    Debug helper. It recreates ``FLAGS.output_path``, restores the text
    detector from ``FLAGS.checkpoint_path``, runs it on the input image,
    draws every detected box in green and shows the result in an OpenCV
    window. The call blocks in ``cv2.waitKey(0)`` until a key is pressed.

    Returns:
        int: always 0.

    Raises:
        IOError: if no checkpoint is found under ``FLAGS.checkpoint_path``
            or the input image cannot be read.
    """
    if os.path.exists(FLAGS.output_path):
        shutil.rmtree(FLAGS.output_path)
    os.makedirs(FLAGS.output_path)
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    # Build the detector graph.
    input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
    input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    bbox_pred, _, cls_prob = ctpnmodel.model(input_image, 2)

    # Restore the moving-average versions of the variables.
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
    if ckpt_state is None:
        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail with a readable message instead of AttributeError.
        raise IOError('No checkpoint found in {}'.format(FLAGS.checkpoint_path))
    model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        saver.restore(sess, model_path)

        # NOTE(review): the recognition model is built and restored here but
        # never used by this function; kept so start-up behaviour is unchanged.
        dir_output = "/app/im2latex_master/results/full/"
        config_vocab = Config(dir_output + "vocab.json")
        config_model = Config(dir_output + "model.json")
        vocab = Vocab(config_vocab)
        model = Img2SeqModel(config_model, dir_output, vocab)
        model.build_pred()
        model.restore_session(dir_output + "model.weights4/test-model.ckpt")

        raw = cv2.imread(FLAGS.test_data_path)
        if raw is None:
            # cv2.imread signals failure by returning None.
            raise IOError('Cannot read image {}'.format(FLAGS.test_data_path))
        img = raw[:, :, ::-1]  # reverse the channel axis (BGR -> RGB)

        # NOTE(review): the indentation of this function was lost in the
        # copy under review. The `if` is assumed to guard only the
        # perspective-correction step, so that boxes are shown for images
        # of any height - confirm against the original file.
        if img.shape[0] > 121:
            approx, image, (rh, rw) = preprocess.draw_rec(img)
            img = preprocess.Perspective(image, approx)
            img = cv2.resize(img, None, None, fx=1.0 / rw, fy=1.0 / rh, interpolation=cv2.INTER_LINEAR)

        img, (rh, rw) = resize_image(img)
        h, w, c = img.shape
        im_info = np.array([h, w, c]).reshape([1, 3])
        bbox_pred_val, cls_prob_val = sess.run([bbox_pred, cls_prob],
                                               feed_dict={input_image: [img],
                                                          input_im_info: im_info})

        textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info, img)
        scores = textsegs[:, 0:2]  # changed: the first two columns are used as scores
        textsegs = textsegs[:, 2:6]  # changed: the next four columns are the segments

        textdetector = TextDetector(DETECT_MODE='H')
        boxes = textdetector.detect(textsegs, scores, img.shape[:2], img)
        # Builtin int replaces the np.int alias, which newer NumPy removed.
        boxes = np.array(boxes, dtype=int)

        # Order the boxes by the sum of elements 1 and 3, then 0 and 6
        # (presumably top-to-bottom, then left-to-right - confirm box layout).
        image_box = sorted(boxes, key=(lambda x: (x[1] + x[3], x[0] + x[6])))
        for box in image_box:
            cv2.polylines(img, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(0, 255, 0),
                          thickness=2)

        # Undo resize_image(): rh is the height ratio and rw the width ratio,
        # so fx (horizontal) takes 1/rw and fy (vertical) takes 1/rh. The
        # original call had the two swapped.
        img = cv2.resize(img, None, None, fx=1.0 / rw, fy=1.0 / rh, interpolation=cv2.INTER_LINEAR)

        # The image was flipped to RGB after loading, while cv2.imshow
        # expects BGR. Assumes preprocess keeps the channel order - TODO confirm.
        cv2.imshow("ss", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        cv2.waitKey(0)
    finally:
        sess.close()
    return 0
def _crop_box(img, box):
    """Cut the region of *img* covered by *box*, padded by one pixel.

    ``box[0:8]`` is read as four (x, y) corner pairs. The start indices are
    clamped at 0: without the clamp, a box touching the top or left border
    gives a start of -1, which Python slicing interprets as "last row/column"
    and the crop comes out empty. The channel axis of the result is reversed,
    as in the original code.
    """
    top = max(0, min(box[1], box[3]) - 1)
    left = max(0, min(box[0], box[2]) - 1)
    bottom = max(box[5], box[7]) + 1
    right = max(box[4], box[6]) + 1
    return img[top:bottom, left:right, ::-1]


def save_to_file():
    """Recognise the text in ``FLAGS.test_data_path`` and log it as JSON.

    Images taller than 40 pixels go through perspective correction and text
    detection, and every detected box is recognised separately. Shorter
    images are passed to the recogniser as a whole. Each recognised line is
    logged through ``model_en.logger``, followed by a final JSON object of
    the form ``{"res": "<line>\\n<line>\\n..."}``.

    Side effects: ``FLAGS.output_path`` is deleted and recreated, and
    ``CUDA_VISIBLE_DEVICES`` is set from ``FLAGS.gpu``.

    Returns:
        int: always 0. The recognised text is only logged, not returned.

    Raises:
        IOError: if no checkpoint is found under ``FLAGS.checkpoint_path``.
    """
    if os.path.exists(FLAGS.output_path):
        shutil.rmtree(FLAGS.output_path)
    os.makedirs(FLAGS.output_path)
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    # Build the detector graph.
    input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
    input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    bbox_pred, _, cls_prob = ctpnmodel.model(input_image, 2.0)

    # Restore the moving-average versions of the variables.
    variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
    if ckpt_state is None:
        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail with a readable message instead of AttributeError.
        raise IOError('No checkpoint found in {}'.format(FLAGS.checkpoint_path))
    model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        saver.restore(sess, model_path)

        dir_output = "/app/im2latex_master/results/full/"
        # NOTE(review): this vocabulary is loaded but not used below (the
        # model that consumed it is disabled); kept so start-up behaviour is
        # unchanged.
        config_vocab = Config(dir_output + "vocab.json")
        config_model = Config(dir_output + "model.json")
        vocab = Vocab(config_vocab)
        # English recognition model.
        config_vocab_en = Config(dir_output + "vocabe.json")
        vocab_en = Vocab(config_vocab_en)
        model_en = Img2SeqModel(config_model, dir_output, vocab_en)
        model_en.build_pred()
        model_en.restore_session(dir_output + "model.weights_en/test-model.ckpt")

        # Force three channels: the code below unpacks (h, w, c) and feeds a
        # 3-channel placeholder, which breaks on greyscale or RGBA files.
        img = imread(FLAGS.test_data_path, mode='RGB')

        lines = []
        if img.shape[0] > 40:
            approx, image, (rh, rw) = preprocess.draw_rec(img)
            img = preprocess.Perspective(image, approx)
            img = cv2.resize(img, None, None, fx=1.0 / rw, fy=1.0 / rh, interpolation=cv2.INTER_LINEAR)
            img, (rh, rw) = resize_image(img)
            h, w, c = img.shape
            im_info = np.array([h, w, c]).reshape([1, 3])
            bbox_pred_val, cls_prob_val = sess.run([bbox_pred, cls_prob],
                                                   feed_dict={input_image: [img],
                                                              input_im_info: im_info})

            textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info, img)
            scores = textsegs[:, 0:2]  # changed: the first two columns are used as scores
            textsegs = textsegs[:, 2:6]  # changed: the next four columns are the segments

            textdetector = TextDetector(DETECT_MODE='H')
            boxes = textdetector.detect(textsegs, scores, img.shape[:2], img)
            # Builtin int replaces the np.int alias, which newer NumPy removed.
            boxes = np.array(boxes, dtype=int)

            # NOTE(review): box[8] looks like a per-box language tag, but the
            # branch that dispatched on it is disabled; every box currently
            # goes through the English model.
            for box in boxes:
                patch = greyscale(predictsize(_crop_box(img, box)))
                hyp = model_en.predict(patch)
                lines.append(hyp[0])
                model_en.logger.info(hyp[0])
        else:
            # Short image: recognise it as a single line, without detection.
            # NOTE(review): this branch feeds the image in loaded channel
            # order, whereas the box crops above are channel-reversed -
            # confirm which order greyscale() expects.
            hyps = model_en.predict(greyscale(predictsize(img)))
            lines.append(hyps[0])
            model_en.logger.info(hyps[0])

        # Every recognised line is terminated by a newline, as before.
        res = json.dumps({"res": "".join(line + "\n" for line in lines)})
        model_en.logger.info(res)
    finally:
        sess.close()
    return 0
def main(argv=None):
    """Program entry point: run the recognition pipeline.

    Args:
        argv: command-line arguments; not used here.

    Returns:
        The value returned by ``save_to_file()``.
    """
    # get_box() can be called here instead to display the detected boxes.
    return save_to_file()
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then calls main and exits
    # with its return value.
    tf.app.run(main=main)