import base64
|
|
import urllib
|
|
from uuid import uuid1
|
|
|
|
import PIL.Image
|
|
import torchvision.transforms as T
|
|
import matplotlib
|
|
|
|
matplotlib.use('Agg')
|
|
from fastai.vision import *
|
|
from fastai.vision import load_learner
|
|
|
|
from core import FeatureLoss
|
|
from typing import List
|
|
from fastapi import FastAPI, File, UploadFile
|
|
from fastapi.responses import FileResponse
|
|
import uvicorn
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
import base64
|
|
|
|
import os
import urllib.request  # `import urllib` alone does not load the `request` submodule

# Global handle to the fastai learner; assigned by `load_pkl` below.
learn = None

# Fetch the model weights on first start only.
if not (os.path.exists('./ArtLine_650.pkl')):
    MODEL_URL = "https://www.dropbox.com/s/starqc9qd2e1lg1/ArtLine_650.pkl?dl=1"
    # Download under a temporary name and rename only after success, so an
    # interrupted download never leaves a truncated ArtLine_650.pkl that the
    # existence check above would accept on the next start.
    _partial_model = "ArtLine_650.pkl.part"
    urllib.request.urlretrieve(MODEL_URL, _partial_model)
    os.replace(_partial_model, "ArtLine_650.pkl")
|
|
|
|
# singleton start
# The learner is loaded once, at import time, into the module-level `learn`
# global that the request handlers use.


def load_pkl(self) -> Any:
    """Load the ArtLine learner from ./ArtLine_650.pkl into the global `learn`.

    ``self`` is unused; the function only exists as a method of ``PklLoader``.

    Returns:
        The loaded fastai learner (the same object stored in ``learn``).
    """
    global learn
    learn = load_learner(Path("."), 'ArtLine_650.pkl')
    return learn


# Minimal class built at runtime purely to host `load_pkl` as a method.
PklLoader = type('PklLoader', (), {"load_pkl": load_pkl})

pl = PklLoader()
pl.load_pkl()

# singleton end
|
|
def demo_show(the_img: Image, ax: plt.Axes = None, figsize: tuple = None, title: Optional[str] = None,
              hide_axis: bool = True,
              cmap: str = None, y: Any = None, out_file: str = None, **kwargs):
    """Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`,
    and save the figure as ``result/<out_file>``.

    Args:
        the_img: Image passed to fastai ``show_image``.
        ax: Axes to draw on; ``show_image`` creates one when None.
        figsize: Figure size forwarded to ``show_image``.
        title: Optional axes title.
        hide_axis: Forwarded to ``show_image``.
        cmap: Colour map; defaults to fastai ``defaults.cmap``.
        y: Optional object drawn on top via ``y.show(ax=ax, **kwargs)``.
        out_file: File name (required) written under the ``result/`` directory.

    Raises:
        ValueError: if `out_file` is missing or empty.

    The figure is always closed afterwards, even if drawing or saving fails,
    so repeated calls in a long-running server do not accumulate figures.
    """
    import os

    # Validate before creating a figure, otherwise the failure would leak it.
    if not out_file:
        raise ValueError("demo_show requires out_file (name of the file to write under result/)")
    cmap = ifnone(cmap, defaults.cmap)
    ax = show_image(the_img, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
    fig = ax.get_figure()
    try:
        if y is not None:
            y.show(ax=ax, **kwargs)
        if title is not None:
            ax.set_title(title)
        os.makedirs('result', exist_ok=True)
        fig.savefig(os.path.join('result', out_file))
    finally:
        plt.close(fig)
|
|
|
|
|
|
app = FastAPI()

# Origins permitted to make cross-origin requests; "*" allows any origin.
origins = ["*"]

# Cross-origin (CORS) policy applied to every route.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests if the API ever relies on cookies/auth —
# confirm this is intended, or list explicit origins.
_cors_options = dict(
    allow_origins=origins,    # origins allowed to call the API
    allow_credentials=True,   # allow cookies / auth headers on cross-origin requests
    allow_methods=["*"],      # HTTP methods allowed cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"],      # request headers allowed cross-origin
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
|
|
@app.post("/uploadfiles/")
async def create_upload_files(files: List[UploadFile] = File(...)):
    """Save each uploaded file under ``tmp/`` for later processing by ``/download/``.

    The client-supplied filename is untrusted: only its final path component
    is used, so a name such as ``../../app.py`` cannot write outside ``tmp/``.

    Returns:
        ``{"filenames": [...]}`` with the names the files were stored under,
        i.e. the names to pass to ``/download/{file_name}``.

    Raises:
        HTTPException: 400 if a filename is empty or reduces to ``.``/``..``.
    """
    import os
    from fastapi import HTTPException

    # Validate every name before writing anything, so a bad request stores nothing.
    names = []
    for upload in files:
        # Treat backslashes as separators too (Windows-style client paths).
        name = os.path.basename((upload.filename or '').replace('\\', '/'))
        if name in ('', '.', '..'):
            raise HTTPException(status_code=400, detail=f"invalid filename: {upload.filename!r}")
        names.append(name)

    os.makedirs('tmp', exist_ok=True)
    for upload, name in zip(files, names):
        content = await upload.read()
        with open(os.path.join('tmp', name), "wb") as f:
            f.write(content)
    return {"filenames": names}
|
|
|
|
|
|
|
|
def read_img_file_as_base64(local_file) -> str:
    """Return the contents of `local_file` base64-encoded as text, then delete the file.

    Args:
        local_file: Path of the file to read; it is removed after reading.

    Returns:
        The base64 encoding of the file's bytes as an ASCII ``str``.
    """
    import os

    with open(local_file, "rb") as rf:
        raw = rf.read()
    # Delete only after the handle is closed: removing a file that is still
    # open fails on Windows with PermissionError.
    os.remove(local_file)
    return base64.b64encode(raw).decode()
|
|
|
|
|
|
|
|
@app.get("/download/{file_name}")
async def main(file_name):
    """Convert an uploaded image to line art and return the result as a PNG.

    Reads ``tmp/<file_name>`` (written by ``/uploadfiles/``), runs the global
    learner on it, saves the rendering to ``result/<file_name>`` via
    ``demo_show``, and deletes the upload. If processing fails but a result
    from an earlier request already exists, that result is served.

    Raises:
        HTTPException: 400 for a file name that is not a plain name; 404 when
            the upload is missing and no earlier result exists; 500 when
            processing failed and no result is available.
    """
    import os
    from fastapi import HTTPException

    # file_name comes from the URL; refuse anything that is not a plain file name.
    if file_name in ('', '.', '..') or os.path.basename(file_name) != file_name:
        raise HTTPException(status_code=400, detail="invalid file name")

    local_file = os.path.join('tmp', file_name)
    file_path = os.path.join('result', file_name)
    error = None
    try:
        # Close the source file promptly so it can be deleted below.
        with PIL.Image.open(local_file) as src:
            img = src.convert('RGB')
        img_fast = Image(T.ToTensor()(img))
        # learn.predict returns a 3-tuple; only the second element is rendered.
        _, img_hr, _ = learn.predict(img_fast)
        demo_show(Image(img_hr), figsize=(8, 8), out_file=file_name)
    except Exception as exc:
        # Best effort: an earlier result for this name may still be served below.
        error = exc
        print(exc)
    finally:
        if os.path.isfile(local_file):
            os.remove(local_file)

    # Decided outside `finally`: a `return` there would swallow in-flight
    # exceptions, including task cancellation.
    if not os.path.isfile(file_path):
        status = 404 if isinstance(error, FileNotFoundError) else 500
        raise HTTPException(status_code=status, detail=f"no result for {file_name}")
    return FileResponse(file_path, media_type="image/png")
|
|
|
|
|
|
if __name__ == '__main__':
    # Pass the application object instead of the "app:app" import string. The
    # string makes uvicorn import this file a second time under the module
    # name `app`, which re-runs the model download check and load_pkl (a second
    # copy of the model in memory). It also breaks if the file is renamed.
    uvicorn.run(app, host='0.0.0.0', port=8000)
|