NoteOnMe博客平台搭建
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

277 строк
10 KiB

  1. # coding=utf-8
  2. import os
  3. import shutil
  4. import sys
  5. import time
  6. import cv2
  7. import numpy as np
  8. import tensorflow as tf
  9. from main import preprocess
  10. import json
  11. import locale
  12. locale.setlocale(locale.LC_ALL, 'C')
  13. from scipy.misc import imread
  14. #current_directory = os.path.dirname(os.path.abspath(__file__))
  15. #root_path = os.path.abspath(os.path.dirname(current_directory) + os.path.sep + ".")
  16. #sys.path.append(sys.path.append(os.getcwd()))
  17. sys.path.append(os.getcwd())
  18. from nets import model_train as ctpnmodel
  19. from utils.rpn_msr.proposal_layer import proposal_layer
  20. from utils.text_connector.detectors import TextDetector
  21. from scipy.misc import imread
  22. import os
  23. from PIL import Image
  24. from model.img2seq import Img2SeqModel
  25. from model.utils.general import Config, run
  26. from model.utils.text import Vocab
  27. from model.utils.image import greyscale,predictsize
  28. tf.app.flags.DEFINE_string('test_data_path', '/app/image/1.png', '')
  29. tf.app.flags.DEFINE_string('output_path', '/app/im2latex_master/results/predict/', '')
  30. tf.app.flags.DEFINE_string('gpu', '0', '')
  31. tf.app.flags.DEFINE_string('checkpoint_path', '/app/im2latex_master/checkpoints_mlt/', '')
  32. FLAGS = tf.app.flags.FLAGS
  33. tf.app.flags.DEFINE_integer('language', '2', '')
  34. def get_images():
  35. files = []
  36. exts = ['jpg', 'png', 'jpeg', 'JPG']
  37. for parent, dirnames, filenames in os.walk(FLAGS.test_data_path):
  38. for filename in filenames:
  39. for ext in exts:
  40. if filename.endswith(ext):
  41. files.append(os.path.join(parent, filename))
  42. break
  43. print('Find {} images'.format(len(files)))
  44. return files
  45. def resize_image(img):
  46. img_size = img.shape
  47. im_size_min = np.min(img_size[0:2])
  48. im_size_max = np.max(img_size[0:2])
  49. im_scale = float(600) / float(im_size_min)
  50. if np.round(im_scale * im_size_max) > 1200:
  51. im_scale = float(1200) / float(im_size_max)
  52. new_h = int(img_size[0] * im_scale)
  53. new_w = int(img_size[1] * im_scale)
  54. new_h = new_h if new_h // 16 == 0 else (new_h // 16 + 1) * 16
  55. new_w = new_w if new_w // 16 == 0 else (new_w // 16 + 1) * 16
  56. re_im = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
  57. #cv2.imshow("ss",img)
  58. #cv2.waitKey(0)
  59. return re_im, (new_h / img_size[0], new_w / img_size[1])
  60. def get_box():
  61. if os.path.exists(FLAGS.output_path):
  62. shutil.rmtree(FLAGS.output_path)
  63. os.makedirs(FLAGS.output_path)
  64. os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
  65. input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
  66. input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
  67. global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
  68. bbox_pred, cls_pred, cls_prob = ctpnmodel.model(input_image,2)
  69. variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
  70. saver = tf.train.Saver(variable_averages.variables_to_restore())
  71. ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
  72. model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
  73. # print('Restore from {}'.format(model_path))
  74. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  75. saver.restore(sess, model_path)
  76. dir_output = "/app/im2latex_master/results/full/"
  77. config_vocab = Config(dir_output + "vocab.json")
  78. config_model = Config(dir_output + "model.json")
  79. vocab = Vocab(config_vocab)
  80. model = Img2SeqModel(config_model, dir_output, vocab)
  81. model.build_pred()
  82. model.restore_session(dir_output + "model.weights4/test-model.ckpt")
  83. # print(FLAGS.test_data_path)
  84. img = cv2.imread(FLAGS.test_data_path)[:, :, ::-1]
  85. h, w, c = img.shape
  86. if h > 121:
  87. approx, image, (rh, rw) = preprocess.draw_rec(img)
  88. img = preprocess.Perspective(image, approx)
  89. img = cv2.resize(img, None, None, fx=1.0 / rw, fy=1.0 / rh, interpolation=cv2.INTER_LINEAR)
  90. #cv2.imshow("Dd",img)
  91. #cv2.waitKey(0)
  92. img, (rh, rw) = resize_image(img)
  93. h, w, c = img.shape
  94. im_info = np.array([h, w, c]).reshape([1, 3])
  95. bbox_pred_val, cls_prob_val = sess.run([bbox_pred, cls_prob],
  96. feed_dict={input_image: [img],
  97. input_im_info: im_info})
  98. textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info, img)
  99. scores = textsegs[:, 0:2] # 改
  100. textsegs = textsegs[:, 2:6] # 改
  101. textdetector = TextDetector(DETECT_MODE='H')
  102. boxes = textdetector.detect(textsegs, scores, img.shape[:2], img)
  103. boxes = np.array(boxes, dtype=np.int)
  104. image_box = sorted(boxes, key=(lambda x: (x[1] + x[3], x[0] + x[6])))
  105. for i, box in enumerate(image_box):
  106. cv2.polylines(img, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(0, 255, 0),
  107. thickness=2)
  108. img = cv2.resize(img, None, None, fx=1.0 / rh, fy=1.0 / rw, interpolation=cv2.INTER_LINEAR)
  109. cv2.imshow("ss",img)
  110. cv2.waitKey(0)
  111. return 0
  112. def save_to_file():
  113. if os.path.exists(FLAGS.output_path):
  114. shutil.rmtree(FLAGS.output_path)
  115. os.makedirs(FLAGS.output_path)
  116. os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
  117. input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
  118. input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
  119. global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
  120. bbox_pred, cls_pred, cls_prob = ctpnmodel.model(input_image,2.0)
  121. variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
  122. saver = tf.train.Saver(variable_averages.variables_to_restore())
  123. ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
  124. model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
  125. sess=tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  126. saver.restore(sess, model_path)
  127. dir_output = "/app/im2latex_master/results/full/"
  128. config_vocab = Config(dir_output + "vocab.json")
  129. config_model = Config(dir_output + "model.json")
  130. vocab = Vocab(config_vocab)
  131. #英文
  132. config_vocab_en = Config(dir_output + "vocabe.json")
  133. vocab_en = Vocab(config_vocab_en)
  134. model_en = Img2SeqModel(config_model, dir_output, vocab_en)
  135. model_en.build_pred()
  136. model_en.restore_session(dir_output + "model.weights_en/test-model.ckpt")
  137. #print(FLAGS.test_data_path)
  138. img = imread(FLAGS.test_data_path)
  139. h, w, c = img.shape
  140. res = ""
  141. if h>40:
  142. approx, image, (rh, rw) = preprocess.draw_rec(img)
  143. img = preprocess.Perspective(image, approx)
  144. img = cv2.resize(img, None, None, fx=1.0 / rw, fy=1.0 / rh, interpolation=cv2.INTER_LINEAR)
  145. img, (rh, rw) = resize_image(img)
  146. h, w, c = img.shape
  147. im_info = np.array([h, w, c]).reshape([1, 3])
  148. bbox_pred_val, cls_prob_val = sess.run([bbox_pred, cls_prob],
  149. feed_dict={input_image: [img],
  150. input_im_info: im_info})
  151. textsegs, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info,img)
  152. scores = textsegs[:, 0:2] # 改
  153. textsegs = textsegs[:, 2:6] # 改
  154. textdetector = TextDetector(DETECT_MODE='H')
  155. boxes = textdetector.detect(textsegs, scores, img.shape[:2],img)
  156. boxes = np.array(boxes, dtype=np.int)
  157. img2=img.copy()
  158. for i, box in enumerate(boxes):
  159. if box[8]==1:
  160. cv2.polylines(img2, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(0, 255, 0),
  161. thickness=2)
  162. else:
  163. cv2.polylines(img2, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 0, 0),
  164. thickness=2)
  165. img2 = cv2.resize(img2, None, None, fx=1.0 / rh, fy=1.0 / rw, interpolation=cv2.INTER_LINEAR)
  166. #cv2.imshow("ss", img2)
  167. #cv2.waitKey(0)
  168. for i,b in enumerate(boxes):
  169. lan=b[8]
  170. box = boxes[i]
  171. img0 = img[min(box[1], box[3]) - 1:max(box[5], box[7]) + 1, min(box[0], box[2]) - 1:max(box[4], box[6]) + 1,
  172. ::-1]
  173. #cv2.imshow("ss",img0)
  174. #cv2.waitKey(0)
  175. """
  176. if lan == 2:
  177. img0 = predictsize(img0)
  178. #cv2.imshow("ss",img0)
  179. #cv2.waitKey(0)
  180. img0 = greyscale(img0)
  181. hyp = model.predict(img0)
  182. res = res + hyp[0] + "\n"
  183. model.logger.info(hyp[0])
  184. else:
  185. """
  186. img0 = predictsize(img0)
  187. #cv2.imshow("ss",img0)
  188. #cv2.waitKey(0)
  189. img0 = greyscale(img0)
  190. hyp = model_en.predict(img0)
  191. res = res + hyp[0] + "\n"
  192. model_en.logger.info(hyp[0])
  193. #hyp=pytesseract.image_to_string(img0)
  194. #res = res + hyp + "\n"
  195. #model.logger.info(hyp)
  196. res = json.dumps({"res": res})
  197. model_en.logger.info(res)
  198. else:
  199. #print(0)
  200. img = predictsize(img)
  201. img0 = greyscale(img)
  202. #cv2.imshow("ss", img0)
  203. #cv2.waitKey(0)
  204. hyps = model_en.predict(img0)
  205. res = res + hyps[0] + "\n"
  206. model_en.logger.info(hyps[0])
  207. res = json.dumps({"res": res})
  208. model_en.logger.info(res)
  209. return 0
  210. '''
  211. cv2.imwrite(os.path.join(FLAGS.output_path, str(i) +'.png'),img[min(box[1],box[3]):max(box[5],box[7]),min(box[0],box[2]) :max(box[4],box[6]), ::-1])
  212. cv2.polylines(img, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(0, 255, 0),
  213. thickness=2)
  214. img = cv2.resize(img, None, None, fx=1.0 / rh, fy=1.0 / rw, interpolation=cv2.INTER_LINEAR)
  215. cv2.imwrite(os.path.join(FLAGS.output_path, os.path.basename(im_fn)), img[:, :, ::-1])
  216. with open(os.path.join(FLAGS.output_path, os.path.splitext(os.path.basename(im_fn))[0]) + ".txt",
  217. "w") as f:
  218. for i, box in enumerate(boxes):
  219. line = ",".join(str(box[k]) for k in range(8))
  220. line += "," + str(scores[i]) + "\r\n"
  221. f.writelines(line)
  222. '''
  223. def main(argv=None):
  224. res=save_to_file()
  225. #res=get_box()
  226. return res
  227. if __name__ == '__main__':
  228. tf.app.run()