- import base64
- import urllib
- from uuid import uuid1
- import PIL.Image
- import torchvision.transforms as T
- import matplotlib
- matplotlib.use('Agg')
- from fastai.vision import *
- from fastai.vision import load_learner
- from core import FeatureLoss
- from typing import List
- from fastapi import FastAPI, File, UploadFile
- from fastapi.responses import FileResponse
- import uvicorn
- from starlette.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse
- import base64
# Global fastai learner; assigned when the model pickle is loaded.
learn = None

# Remote location of the pretrained ArtLine model and the local file it is cached in.
MODEL_URL = "https://www.dropbox.com/s/starqc9qd2e1lg1/ArtLine_650.pkl?dl=1"
MODEL_FILE = "ArtLine_650.pkl"

if not os.path.exists(MODEL_FILE):
    # `import urllib` alone does not load the `urllib.request` submodule;
    # import it explicitly instead of relying on another library having done so.
    import urllib.request

    # Download under a temporary name and rename only once the transfer has
    # finished, so an interrupted download never leaves a truncated file that
    # the existence check above would mistake for a complete model next time.
    _partial_model_file = MODEL_FILE + ".part"
    urllib.request.urlretrieve(MODEL_URL, _partial_model_file)
    os.replace(_partial_model_file, MODEL_FILE)
# Model loading: a small loader object whose load_pkl() fills the module-level
# `learn` with the fastai learner stored in ./ArtLine_650.pkl.
# NOTE(review): load_learner deserializes the file (presumably via torch.load /
# pickle in fastai v1), and the file may have been fetched from a remote URL.
# Unpickling can execute arbitrary code, so only load a model from a trusted source.
def load_pkl(self) -> Any:
    """Load the ArtLine learner from the current directory into the global `learn`."""
    global learn
    learn = load_learner(Path("."), 'ArtLine_650.pkl')


class PklLoader:
    load_pkl = load_pkl


pl = PklLoader()
pl.load_pkl()
def demo_show(the_img: Image, ax: plt.Axes = None, figsize: tuple = None, title: Optional[str] = None,
              hide_axis: bool = True,
              cmap: str = None, y: Any = None, out_file: str = None, **kwargs):
    """Render `the_img` with matplotlib and save the figure to ``result/<out_file>``.

    Shows the image on `ax` with `title`, using `cmap` if single-channel,
    overlaid with the optional `y`, then writes the figure to disk and closes it.

    Args:
        the_img: image to render; forwarded to `show_image`.
        ax: axes to draw on; forwarded to `show_image`. NOTE: the figure that
            owns the resulting axes is closed afterwards, even if the caller
            supplied `ax`.
        figsize: figure size forwarded to `show_image`.
        title: optional axes title.
        hide_axis: forwarded to `show_image`.
        cmap: colormap; `defaults.cmap` is used when None.
        y: optional object drawn on top via ``y.show(ax=ax, **kwargs)``.
        out_file: file name, relative to the ``result`` directory, to save to.
            Required despite the None default.
        **kwargs: forwarded to ``y.show``.

    Raises:
        ValueError: if `out_file` is not given.
    """
    import os

    if out_file is None:
        # Fail before any figure is created (previously this surfaced as a
        # TypeError from string concatenation after the figure was drawn).
        raise ValueError("demo_show requires out_file")

    cmap = ifnone(cmap, defaults.cmap)
    ax = show_image(the_img, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
    fig = ax.get_figure()
    try:
        if y is not None:
            y.show(ax=ax, **kwargs)
        if title is not None:
            ax.set_title(title)
        # The output directory may not exist on a fresh checkout.
        os.makedirs('result', exist_ok=True)
        fig.savefig('result/' + out_file)
    finally:
        # Always release the figure: pyplot keeps open figures alive, so a
        # failure above would otherwise leak one figure per request.
        plt.close(fig)
    print('close')
app = FastAPI()

# Origins allowed to call this API cross-origin; "*" admits every origin.
origins = ["*"]

# Cross-origin (CORS) settings applied to every endpoint.
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# site make credentialed requests. Browsers refuse a literal "*" with
# credentials, so Starlette is expected to reflect the caller's Origin instead.
# Verify against the CORSMiddleware docs and confirm this is intended.
_cors_options = dict(
    allow_origins=origins,   # origins permitted to make cross-origin requests
    allow_credentials=True,
    allow_methods=["*"],     # HTTP methods permitted cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"],     # request headers permitted cross-origin
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.post("/uploadfiles/")
async def create_upload_files(files: List[UploadFile] = File(...)):
    """Store uploaded files under ``tmp/`` for later processing by ``/download/{file_name}``.

    Each upload is written to ``tmp/<name>``, where <name> is the
    client-supplied filename reduced to its final path component.

    Returns:
        ``{"filenames": [...]}``: the names the files were stored under, i.e.
        the values to pass to ``/download/{file_name}``.

    Raises:
        HTTPException(400): if any upload lacks a usable filename. This is
            checked before anything is written.
    """
    import os
    from fastapi import HTTPException

    # The filename comes from the client. Strip any directory part (including
    # Windows-style separators) so a name like "../../x" cannot escape tmp/.
    names = [os.path.basename((file.filename or '').replace('\\', '/')) for file in files]
    if any(name in ('', '.', '..') for name in names):
        raise HTTPException(status_code=400, detail='Invalid upload filename')

    os.makedirs('tmp', exist_ok=True)
    for file, name in zip(files, names):
        content = await file.read()
        with open('tmp/' + name, "wb") as f:
            f.write(content)
    return {"filenames": names}
def read_img_file_as_base64(local_file) -> str:
    """Return the contents of `local_file` base64-encoded, then delete the file.

    Args:
        local_file: path of the file to read; it is removed after reading.

    Returns:
        The base64 encoding of the file's bytes as a ``str``.
    """
    with open(local_file, "rb") as rf:
        raw = rf.read()
    # Delete only after the handle is closed: removing a file that is still
    # open raises PermissionError on Windows.
    os.remove(local_file)
    return base64.b64encode(raw).decode()
@app.get("/download/{file_name}")
async def main(file_name):
    """Turn the previously uploaded ``tmp/<file_name>`` into line art and return it.

    The input is read from ``tmp/``, passed through the global `learn`, rendered
    by `demo_show` into ``result/<file_name>``, and returned as a file response.
    The uploaded input is deleted afterwards, whether or not processing succeeded.

    Raises:
        HTTPException(404): no uploaded file with that name exists.
        HTTPException(500): the image could not be processed.
    """
    import mimetypes
    from fastapi import HTTPException

    local_file = 'tmp/' + file_name
    file_path = 'result/' + file_name
    if not os.path.exists(local_file):
        raise HTTPException(status_code=404, detail='File not found: ' + file_name)
    try:
        # Decode inside `with` so the source file is closed before it is removed below.
        with PIL.Image.open(local_file) as src:
            img = src.convert('RGB')
        width, height = img.size
        print(width, height)
        img_fast = Image(T.ToTensor()(img))
        # predict returns a 3-tuple; only the second item is used, wrapped in
        # Image and rendered to the result file.
        _, img_hr, _ = learn.predict(img_fast)
        demo_show(Image(img_hr), figsize=(8, 8), out_file=file_name)
    except Exception as e:
        print(e)
        # Report the failure instead of falling through to serve a missing
        # file (error at send time) or a stale result from an earlier run.
        raise HTTPException(status_code=500, detail='Failed to process image') from e
    finally:
        if os.path.exists(local_file):
            os.remove(local_file)
    # savefig picks the output format from the file extension, so report the
    # matching content type instead of always claiming PNG.
    media_type = mimetypes.guess_type(file_path)[0] or "image/png"
    return FileResponse(file_path, media_type=media_type)
if __name__ == '__main__':
    # Pass the application object rather than the "app:app" import string.
    # With the string, uvicorn imports this file a second time as module `app`,
    # which re-runs the model download check and loads the learner twice. It
    # also fails outright if the file is not named app.py.
    # host='' is passed through unchanged; asyncio's create_server treats an
    # empty host as "all interfaces".
    uvicorn.run(app, host='', port=8000)