Não pode escolher mais do que 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.
 
 
 
 

112 linhas
3.2 KiB

import base64
import urllib
from uuid import uuid1
import PIL.Image
import torchvision.transforms as T
import matplotlib
matplotlib.use('Agg')
from fastai.vision import *
from fastai.vision import load_learner
from core import FeatureLoss
from typing import List
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
import uvicorn
from starlette.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import base64
# `import urllib` alone does not guarantee the `request` submodule is loaded.
import urllib.request

# Module-global fastai learner; populated by the loader that follows.
learn = None
if not os.path.exists('./ArtLine_650.pkl'):
    MODEL_URL = "https://www.dropbox.com/s/starqc9qd2e1lg1/ArtLine_650.pkl?dl=1"
    # Download to a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated .pkl behind that the
    # existence check above would accept on the next start.
    _partial_model = "ArtLine_650.pkl.part"
    try:
        urllib.request.urlretrieve(MODEL_URL, _partial_model)
        os.replace(_partial_model, "ArtLine_650.pkl")
    finally:
        if os.path.exists(_partial_model):
            os.remove(_partial_model)
# --- singleton start ---
def load_pkl(self) -> Any:
    """Load the learner stored in ./ArtLine_650.pkl into the module-global `learn`.

    Written as a plain function and attached to `PklLoader` as a method, so
    `self` is the loader instance; it is not used. Returns nothing.
    """
    global learn
    learn = load_learner(Path("."), 'ArtLine_650.pkl')


class PklLoader:
    # Sole member: the loader function defined just above, bound as a method.
    load_pkl = load_pkl


# One loader instance, invoked once at import time to populate `learn`.
pl = PklLoader()
pl.load_pkl()
# --- singleton end ---
def demo_show(the_img: Image, ax: plt.Axes = None, figsize: tuple = None, title: Optional[str] = None,
              hide_axis: bool = True,
              cmap: str = None, y: Any = None, out_file: str = None, **kwargs):
    """Render `the_img` with fastai's `show_image` and save the figure to ``result/<out_file>``.

    Uses `cmap` (falling back to fastai ``defaults.cmap``) and optionally
    overlays `y` on the same axes. The figure is always closed afterwards,
    even when drawing or saving fails, so repeated calls in a long-running
    server do not accumulate open matplotlib figures.

    Args:
        the_img: image to display; passed to ``show_image``.
        ax: axes to draw on; passed through to ``show_image``.
        figsize: figure size passed to ``show_image``.
        title: optional axes title.
        hide_axis: forwarded to ``show_image``.
        cmap: colormap; ``defaults.cmap`` when None.
        y: optional object drawn on top via ``y.show(ax=ax, **kwargs)``.
        out_file: file name appended to ``result/``; matplotlib's ``savefig``
            picks the output format from its extension.
        **kwargs: forwarded to ``y.show``.

    Raises:
        TypeError: if `out_file` is None (checked before any figure is created).
    """
    import os

    if out_file is None:
        raise TypeError("demo_show() requires out_file")
    cmap = ifnone(cmap, defaults.cmap)
    ax = show_image(the_img, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
    fig = ax.get_figure()
    try:
        if y is not None:
            y.show(ax=ax, **kwargs)
        if title is not None:
            ax.set_title(title)
        # savefig does not create missing directories.
        os.makedirs('result', exist_ok=True)
        fig.savefig('result/' + out_file)
    finally:
        # Release the figure even if anything above raised.
        plt.close(fig)
    print('close')
app = FastAPI()

# Origins allowed to call this API cross-origin; "*" means any origin.
origins = ["*"]

# Register CORS handling: every HTTP method (GET, POST, PUT, ...) and every
# request header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any website issue credentialed requests; consider restricting `origins`.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
@app.post("/uploadfiles/")
async def create_upload_files(files: List[UploadFile] = File(...)):
    """Store each uploaded file under ``tmp/`` and report the stored names.

    Only the final path component of the client-supplied filename is kept,
    so names like ``../../etc/passwd`` cannot escape ``tmp/``. All names are
    validated before anything is written. The returned names are the ones
    actually stored, i.e. what to pass to ``GET /download/{file_name}``.

    Raises:
        HTTPException: 400 if any upload has an empty or unusable name.
    """
    import os

    from fastapi import HTTPException

    names = []
    for file in files:
        # Treat backslashes as separators too, then keep only the basename.
        name = os.path.basename((file.filename or '').replace('\\', '/'))
        if name in ('', '.', '..'):
            raise HTTPException(status_code=400, detail="Invalid file name")
        names.append(name)

    os.makedirs('tmp', exist_ok=True)
    for file, name in zip(files, names):
        content = await file.read()
        with open(os.path.join('tmp', name), "wb") as f:
            f.write(content)
    return {"filenames": names}
def read_img_file_as_base64(local_file) -> str:
    """Return the bytes of `local_file` as base64 text, deleting the file afterwards.

    This consumes its input: the file is removed once it has been read.
    """
    with open(local_file, "rb") as source:
        payload = source.read()
    os.remove(local_file)
    encoded = base64.b64encode(payload)
    return encoded.decode()
@app.get("/download/{file_name}")
async def main(file_name):
    """Run the learner on ``tmp/<file_name>`` and return the rendered result.

    Reads the uploaded image, predicts with the module-global `learn`,
    renders the output with ``demo_show`` into ``result/<file_name>`` and
    returns that file. The uploaded input is deleted whether or not
    inference succeeds. If inference fails but a result for this name
    already exists (e.g. a repeated request), the existing file is served.

    Raises:
        HTTPException: 404 when no result file exists for `file_name`.
    """
    import mimetypes
    import os

    from fastapi import HTTPException

    # Route params cannot contain '/', but strip any directory part defensively.
    file_name = os.path.basename(file_name)
    local_file = 'tmp/' + file_name
    file_path = 'result/' + file_name
    try:
        # Context manager closes the source file so it can be deleted below.
        with PIL.Image.open(local_file) as src:
            img = src.convert('RGB')
        img_t = T.ToTensor()(img)
        img_fast = Image(img_t)
        p, img_hr, b = learn.predict(img_fast)
        r = Image(img_hr)
        demo_show(r, figsize=(8, 8), out_file=file_name)
    except Exception as e:
        # Best effort: report and fall through; a missing result is handled below.
        print(e)
    finally:
        # isfile (not exists) so names like '..' never reach os.remove on a directory.
        if os.path.isfile(local_file):
            os.remove(local_file)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail=f"No result for {file_name}")
    # savefig chooses the output format from the extension (e.g. .jpg -> JPEG),
    # so advertise the matching type instead of always claiming PNG.
    media_type = mimetypes.guess_type(file_path)[0] or "image/png"
    return FileResponse(file_path, media_type=media_type)
if __name__ == '__main__':
    # Pass the application object, not the import string "app:app". The string
    # makes uvicorn import this file a second time under the module name `app`,
    # which re-runs the model download/load above. It also only works when the
    # file is literally named app.py.
    uvicorn.run(app, host='0.0.0.0', port=8000)