25'ten fazla konu seçemezsiniz Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.

112 satır
3.2 KiB

3 yıl önce
  1. import base64
  2. import urllib
  3. from uuid import uuid1
  4. import PIL.Image
  5. import torchvision.transforms as T
  6. import matplotlib
  7. matplotlib.use('Agg')
  8. from fastai.vision import *
  9. from fastai.vision import load_learner
  10. from core import FeatureLoss
  11. from typing import List
  12. from fastapi import FastAPI, File, UploadFile
  13. from fastapi.responses import FileResponse
  14. import uvicorn
  15. from starlette.middleware.cors import CORSMiddleware
  16. from fastapi.responses import StreamingResponse
  17. import base64
  18. learn = None
  19. if not (os.path.exists('./ArtLine_650.pkl')):
  20. MODEL_URL = "https://www.dropbox.com/s/starqc9qd2e1lg1/ArtLine_650.pkl?dl=1"
  21. urllib.request.urlretrieve(MODEL_URL, "ArtLine_650.pkl")
  22. # singleton start
  23. def load_pkl(self) -> Any:
  24. global learn
  25. path = Path(".")
  26. learn = load_learner(path, 'ArtLine_650.pkl')
  27. PklLoader = type('PklLoader', (), {"load_pkl": load_pkl})
  28. pl = PklLoader()
  29. pl.load_pkl()
  30. # singleton end
  31. def demo_show(the_img: Image, ax: plt.Axes = None, figsize: tuple = None, title: Optional[str] = None,
  32. hide_axis: bool = True,
  33. cmap: str = None, y: Any = None, out_file: str = None, **kwargs):
  34. "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
  35. cmap = ifnone(cmap, defaults.cmap)
  36. ax = show_image(the_img, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
  37. if y is not None: y.show(ax=ax, **kwargs)
  38. if title is not None: ax.set_title(title)
  39. ax.get_figure().savefig('result/' + out_file)
  40. plt.close(ax.get_figure())
  41. print('close')
  42. app = FastAPI()
  43. #设置允许访问的域名
  44. origins = ["*"] #也可以设置为"*",即为所有。
  45. #设置跨域传参
  46. app.add_middleware(
  47. CORSMiddleware,
  48. allow_origins=origins, #设置允许的origins来源
  49. allow_credentials=True,
  50. allow_methods=["*"], # 设置允许跨域的http方法,比如 get、post、put等。
  51. allow_headers=["*"]) #允许跨域的headers,可以用来鉴别来源等作用。
  52. @app.post("/uploadfiles/")
  53. async def create_upload_files(files: List[UploadFile] = File(...)):
  54. for file in files:
  55. content = await file.read()
  56. with open(r'tmp/'+file.filename, "wb") as f:
  57. f.write(content)
  58. return {"filenames": [file.filename for file in files]}
  59. def read_img_file_as_base64(local_file) -> str:
  60. with open(local_file, "rb") as rf:
  61. base64_str = base64.b64encode(rf.read())
  62. os.remove(local_file)
  63. return base64_str.decode()
  64. @app.get("/download/{file_name}")
  65. async def main(file_name):
  66. local_file = 'tmp/' + file_name
  67. file_path = 'result/' + file_name
  68. try:
  69. img = PIL.Image.open(local_file).convert('RGB')
  70. width, height = img.size
  71. print(width,height)
  72. img_t = T.ToTensor()(img)
  73. img_fast = Image(img_t)
  74. p, img_hr, b = learn.predict(img_fast)
  75. r = Image(img_hr)
  76. demo_show(r,figsize=(8,8), out_file=file_name)
  77. except Exception as e:
  78. print(e)
  79. finally:
  80. if os.path.exists(local_file):
  81. os.remove(local_file)
  82. return FileResponse(file_path, media_type="image/png")
  83. if __name__ == '__main__':
  84. uvicorn.run("app:app",host='0.0.0.0',port=8000)