You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

112 lines
3.2 KiB

пре 3 година
import base64
import os
import urllib
import urllib.request  # `import urllib` alone does not load the `request` submodule
from uuid import uuid1
import PIL.Image
import torchvision.transforms as T
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: render figures to files without a display
from fastai.vision import *
from fastai.vision import load_learner
from core import FeatureLoss
from typing import List
from fastapi import FastAPI, File, UploadFile
from fastapi import HTTPException
from fastapi.responses import FileResponse
import uvicorn
from starlette.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import base64
  18. learn = None
  19. if not (os.path.exists('./ArtLine_650.pkl')):
  20. MODEL_URL = "https://www.dropbox.com/s/starqc9qd2e1lg1/ArtLine_650.pkl?dl=1"
  21. urllib.request.urlretrieve(MODEL_URL, "ArtLine_650.pkl")
  22. # singleton start
  23. def load_pkl(self) -> Any:
  24. global learn
  25. path = Path(".")
  26. learn = load_learner(path, 'ArtLine_650.pkl')
  27. PklLoader = type('PklLoader', (), {"load_pkl": load_pkl})
  28. pl = PklLoader()
  29. pl.load_pkl()
  30. # singleton end
  31. def demo_show(the_img: Image, ax: plt.Axes = None, figsize: tuple = None, title: Optional[str] = None,
  32. hide_axis: bool = True,
  33. cmap: str = None, y: Any = None, out_file: str = None, **kwargs):
  34. "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
  35. cmap = ifnone(cmap, defaults.cmap)
  36. ax = show_image(the_img, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
  37. if y is not None: y.show(ax=ax, **kwargs)
  38. if title is not None: ax.set_title(title)
  39. ax.get_figure().savefig('result/' + out_file)
  40. plt.close(ax.get_figure())
  41. print('close')
  42. app = FastAPI()
  43. #设置允许访问的域名
  44. origins = ["*"] #也可以设置为"*",即为所有。
  45. #设置跨域传参
  46. app.add_middleware(
  47. CORSMiddleware,
  48. allow_origins=origins, #设置允许的origins来源
  49. allow_credentials=True,
  50. allow_methods=["*"], # 设置允许跨域的http方法,比如 get、post、put等。
  51. allow_headers=["*"]) #允许跨域的headers,可以用来鉴别来源等作用。
  52. @app.post("/uploadfiles/")
  53. async def create_upload_files(files: List[UploadFile] = File(...)):
  54. for file in files:
  55. content = await file.read()
  56. with open(r'tmp/'+file.filename, "wb") as f:
  57. f.write(content)
  58. return {"filenames": [file.filename for file in files]}
  59. def read_img_file_as_base64(local_file) -> str:
  60. with open(local_file, "rb") as rf:
  61. base64_str = base64.b64encode(rf.read())
  62. os.remove(local_file)
  63. return base64_str.decode()
  64. @app.get("/download/{file_name}")
  65. async def main(file_name):
  66. local_file = 'tmp/' + file_name
  67. file_path = 'result/' + file_name
  68. try:
  69. img = PIL.Image.open(local_file).convert('RGB')
  70. width, height = img.size
  71. print(width,height)
  72. img_t = T.ToTensor()(img)
  73. img_fast = Image(img_t)
  74. p, img_hr, b = learn.predict(img_fast)
  75. r = Image(img_hr)
  76. demo_show(r,figsize=(8,8), out_file=file_name)
  77. except Exception as e:
  78. print(e)
  79. finally:
  80. if os.path.exists(local_file):
  81. os.remove(local_file)
  82. return FileResponse(file_path, media_type="image/png")
  83. if __name__ == '__main__':
  84. uvicorn.run("app:app",host='0.0.0.0',port=8000)