"""ArtLine web service: upload images, run them through the ArtLine model, download results.

Uploaded files are stored under ``tmp/``; rendered results are written to
``result/`` and served from there.
"""
import base64
import os
import urllib.request  # ``import urllib`` alone does not guarantee ``urllib.request`` is loaded
from pathlib import Path
from uuid import uuid1

import PIL.Image
import torchvision.transforms as T
import matplotlib

# Select the non-interactive backend before fastai (which provides ``plt`` here)
# pulls in pyplot; the server has no display.
matplotlib.use('Agg')

from fastai.vision import *
from fastai.vision import load_learner
# NOTE(review): presumably imported so the pickled learner can resolve FeatureLoss
# when unpickled — confirm before removing.
from core import FeatureLoss
from typing import Any, List, Optional
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.responses import StreamingResponse
import uvicorn
from starlette.middleware.cors import CORSMiddleware

MODEL_FILE = 'ArtLine_650.pkl'
MODEL_URL = "https://www.dropbox.com/s/starqc9qd2e1lg1/ArtLine_650.pkl?dl=1"
TMP_DIR = 'tmp'        # uploads waiting to be converted
RESULT_DIR = 'result'  # rendered outputs served by /download

# Both endpoints write into these directories; create them so the first
# request does not fail on a missing directory.
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

learn = None

if not os.path.exists(MODEL_FILE):
    # Download to a side file and rename only once complete, so an interrupted
    # download cannot leave a truncated .pkl that the check above would accept
    # on every later start.
    _partial_model = MODEL_FILE + '.part'
    urllib.request.urlretrieve(MODEL_URL, _partial_model)
    os.replace(_partial_model, MODEL_FILE)


# singleton start
def load_pkl(self) -> None:
    """Load the exported learner from the working directory into the global ``learn``.

    SECURITY NOTE: ``load_learner`` unpickles the file, and the file is fetched
    from a remote URL above — only use a model file you trust.
    """
    global learn
    learn = load_learner(Path("."), MODEL_FILE)


# The loader class/instance names are kept for compatibility; loading happens
# once, at import time.
PklLoader = type('PklLoader', (), {"load_pkl": load_pkl})
pl = PklLoader()
pl.load_pkl()
# singleton end


def demo_show(the_img: Image, ax: plt.Axes = None, figsize: tuple = None, title: Optional[str] = None,
              hide_axis: bool = True, cmap: str = None, y: Any = None, out_file: str = None, **kwargs):
    """Draw ``the_img`` and save the figure to ``result/<out_file>``.

    Show image on `ax` with `title`, using `cmap` if single-channel, overlaid
    with optional `y` (``kwargs`` are forwarded to ``y.show``). The figure is
    always closed afterwards, including when drawing or saving fails.

    Raises:
        ValueError: if ``out_file`` is not given (there is nowhere to save to).
    """
    if out_file is None:
        # Fail before creating a figure, so nothing is left open.
        raise ValueError("demo_show requires out_file")
    cmap = ifnone(cmap, defaults.cmap)
    ax = show_image(the_img, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
    fig = ax.get_figure()
    try:
        if y is not None:
            y.show(ax=ax, **kwargs)
        if title is not None:
            ax.set_title(title)
        # NOTE(review): savefig infers the format from out_file's extension,
        # while /download always labels the response image/png.
        fig.savefig(os.path.join(RESULT_DIR, out_file))
    finally:
        # Figures are otherwise kept alive by pyplot; always release them.
        plt.close(fig)
    print('close')


app = FastAPI()

# Allowed origins for cross-origin requests ("*" means any origin).
origins = ["*"]

# Cross-origin (CORS) configuration.
# NOTE(review): allowing every origin while also allowing credentials is very
# permissive; restrict `origins` if the API ever relies on cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,   # origins allowed to call the API
    allow_credentials=True,
    allow_methods=["*"],     # HTTP methods allowed cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"])     # request headers allowed cross-origin


def _safe_name(filename: Optional[str]) -> str:
    """Reduce a client-supplied file name to a bare name that stays inside our directories.

    Raises:
        HTTPException: 400 if nothing usable remains (empty, ``.`` or ``..``).
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="invalid file name")
    return name


@app.post("/uploadfiles/")
async def create_upload_files(files: List[UploadFile] = File(...)):
    """Store each uploaded file under ``tmp/`` and return the stored names.

    Client-supplied names are stripped of any directory components so an
    upload cannot write outside ``tmp/``; the returned names are the ones
    to pass to ``/download/{file_name}``.
    """
    saved = []
    for file in files:
        name = _safe_name(file.filename)
        content = await file.read()
        with open(os.path.join(TMP_DIR, name), "wb") as f:
            f.write(content)
        saved.append(name)
    return {"filenames": saved}


def read_img_file_as_base64(local_file) -> str:
    """Return the base64 text of ``local_file``'s bytes, deleting the file afterwards."""
    with open(local_file, "rb") as rf:
        base64_str = base64.b64encode(rf.read())
    os.remove(local_file)
    return base64_str.decode()


@app.get("/download/{file_name}")
async def main(file_name):
    """Convert the pending upload ``tmp/<file_name>`` (if any) and return ``result/<file_name>``.

    The pending upload is deleted after the attempt, whether or not it
    succeeds. If there is no pending upload, a previously rendered result with
    that name is served.

    Raises:
        HTTPException: 400 for an unusable name, 500 if conversion fails,
            404 if no result exists.
    """
    name = _safe_name(file_name)
    local_file = os.path.join(TMP_DIR, name)
    file_path = os.path.join(RESULT_DIR, name)
    if os.path.exists(local_file):
        try:
            img = PIL.Image.open(local_file).convert('RGB')
            width, height = img.size
            print(width, height)
            img_t = T.ToTensor()(img)
            img_fast = Image(img_t)
            # Only the second element of the prediction is rendered.
            p, img_hr, b = learn.predict(img_fast)
            r = Image(img_hr)
            demo_show(r, figsize=(8, 8), out_file=name)
        except Exception as e:
            print(e)
            # Do not fall through: that would serve a stale result from an
            # earlier upload, or fail later on a missing file.
            raise HTTPException(status_code=500, detail="image conversion failed") from e
        finally:
            if os.path.exists(local_file):
                os.remove(local_file)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(file_path, media_type="image/png")


if __name__ == '__main__':
    uvicorn.run("app:app", host='0.0.0.0', port=8000)