python代码打包为http服务端接口 (aiohttp)
目录
一、需求
python端代码接受一个待处理的模型文件路径,对模型进行预测并得到相应结果。由于无法直接将python转换为C++,这里采用aiohttp库将python代码打包成http服务端接口,支持客户端以传入文件路径、传入文件的base64编码以及上传文件三种方式进行请求,便于其他语言调用。同时编写python客户端代码进行测试,也可以利用postman工具进行测试。aiohttp是一个异步的库,具体介绍可以参照官网,里面介绍得很详细。
二、服务器端
这里只对"application/json"和"multipart/form-data"两种类型进行处理。
from aiohttp import web
from inference_class import InferenceClass
import json
import asyncio
import os
import base64
def image_from_base64(base64_utf8):
    """Decode a base64 text string into the raw bytes it encodes.

    Despite the name, the server uses it for arbitrary file content
    (the ``obj`` field of a JSON request), not only images.

    :param base64_utf8: base64-encoded text (``str``).
    :return: the decoded ``bytes``.
    """
    encoded_bytes = base64_utf8.encode('utf-8')
    return base64.decodebytes(encoded_bytes)
class MeshWebServer(object):
    """aiohttp server exposing ``InferenceClass.inference`` at ``POST /mesh/recognize``.

    Files received from clients are written to a cache directory that is
    emptied before every write. ``self._lock`` is held across the whole
    clear -> write -> inference sequence, so one request cannot delete the
    cached file of another request that is still waiting for the engine.
    """

    def __init__(self, max_request=1, cache_dir=None):
        """Build the web application and the inference engine.

        :param max_request: bound of ``self._concurrency``.
            NOTE(review): that semaphore is created but never acquired in
            this class; requests are serialized by ``self._lock`` instead.
        :param cache_dir: directory for files received from clients;
            defaults to ``<cwd>/cache``.
        """
        self._app = web.Application()
        self._engine = InferenceClass()
        self._concurrency = asyncio.BoundedSemaphore(max_request)
        self._lock = asyncio.Lock()
        if cache_dir is not None:
            self._cache_dir = cache_dir
        else:
            self._cache_dir = os.path.join(os.getcwd(), 'cache')

    def run(self, port):
        """Register the route and serve on *port* (blocks until stopped)."""
        self._app.add_routes([
            web.post('/mesh/recognize', self.__on_recognize)
        ])
        web.run_app(self._app, port=port)

    def _reset_cache_path(self, filename):
        """Empty the cache dir (creating it if needed); return the path to store *filename* at.

        Only the basename of the client-supplied name is kept, so a name like
        ``../../x`` cannot place the file outside the cache directory.
        """
        os.makedirs(self._cache_dir, exist_ok=True)
        for name in os.listdir(self._cache_dir):  # drop files left by earlier requests
            old_path = os.path.join(self._cache_dir, name)
            if os.path.isfile(old_path):
                os.remove(old_path)
        return os.path.join(self._cache_dir, os.path.basename(filename) or "temp.obj")

    async def __on_recognize(self, request):
        """Handle ``POST /mesh/recognize`` and return the prediction as JSON.

        ``application/json``: the body may carry ``file_path`` (used directly
        when it names an existing file on the server) and/or ``obj`` (base64
        file content, written to the cache dir under ``filename``, default
        ``temp.obj``).
        ``multipart/form-data``: the first part is streamed to the cache dir.
        Any other content type yields ``returnCode == "Failed!"``.

        The reply is ``web.json_response(json.dumps(respond))``, i.e. the dict
        is JSON-encoded twice; existing clients decode it twice, so this is kept.
        """
        filename = ""
        content_type = request.content_type
        if content_type == "application/json":
            try:
                data = await request.json()
                obj_data, file_path = b"", ""
                if 'obj' in data:
                    obj_data = image_from_base64(data['obj'])
                    filename = data['filename'] if 'filename' in data else "temp.obj"
                if "file_path" in data:
                    file_path = data["file_path"]
                    filename = os.path.basename(file_path)
                async with self._lock:
                    if os.path.isfile(file_path):
                        file = file_path
                    elif obj_data:
                        file = self._reset_cache_path(filename)
                        with open(file, "wb") as f:
                            f.write(obj_data)
                    else:
                        # NOTE(review): nothing usable was sent; assumes
                        # inference("") raises so a "Failed" reply is produced.
                        file = ""
                    predict_class = self._engine.inference(file)
                # NOTE(review): result layout of inference() is not visible here.
                respond = dict(text=predict_class[0][1][-1], returnCode="Successed!", filename=filename)
            except Exception as e:
                respond = dict(text='', returnCode="Failed", filename=filename, returnMsg=repr(e))
        elif content_type == "multipart/form-data":
            try:
                print("headers: ", request.headers)
                reader = await request.multipart()
                field = await reader.next()
                if field is None:
                    raise ValueError("multipart body contains no parts")
                filename = field.filename if field.filename else "temp.obj"
                # Streaming keeps memory flat for large uploads; for small files
                # `(await request.post())["file"].file.read()` is a simpler option.
                async with self._lock:
                    file = self._reset_cache_path(filename)
                    with open(file, 'wb') as f:
                        while True:
                            chunk = await field.read_chunk()  # 8192 bytes by default
                            if not chunk:
                                break
                            f.write(chunk)
                    predict_class = self._engine.inference(file)
                respond = dict(text=predict_class[0][1][-1], returnCode="Successed!", filename=filename)
            except Exception as e:
                print(e)
                respond = dict(text='', returnCode="Failed", filename=filename, returnMsg=repr(e))
        else:
            respond = dict(text="Unknown content type, just support application/json and multipart/form-data",
                           returnCode="Failed!", filename=filename)
        print("---** predict is {} **---".format(respond["returnCode"]))
        return web.json_response(json.dumps(respond))
三、客户端
3.1 "application/json"
可以通过json方式传递文件路径(服务端可直接访问该文件时,例如本机127.0.0.1)或者文件内容的base64编码
import aiohttp
import asyncio
import base64
import os
import json
import time
def base64_from_filename(filename):
    """Read *filename* in binary mode and return its content base64-encoded as ``str``."""
    with open(filename, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode('utf-8')
async def do_recognize(web_ip, web_port, file_path, obj_data=None):
try:
codename = os.path.basename(file_path)
if web_ip == "127.0.0.1":
request = dict(file_path=file_path, filename=codename)
else:
request = dict(obj=obj_data, filename=codename)
print('filename: {}'.format(file_path))
async with aiohttp.ClientSession() as session:
async with session.post(url='http://{}:{}/mesh/recognize'.format(web_ip, web_port),
data=json.dumps(request),
headers={'Content-Type': 'application/json; charset=utf-8'}) as resp:
respond = await resp.text()
respond = respond.replace('\\"', '"')
respond = respond[1:-1] # remove " at the begin and end
result = json.loads(respond)
data = result.get("text")
status = result.get("returnCode")
print("predict {}, res: {}".format(status, data))
except:
print("do_recognize Error {} \n".format(filename))
def run_test(web_ip, web_port, filename, loop):
    """Run ``do_recognize`` for one file on *loop*, printing instead of raising on failure.

    Local requests (``127.0.0.1``) send the path; remote ones send the
    base64-encoded file content.
    """
    try:
        if web_ip == "127.0.0.1":
            loop.run_until_complete(do_recognize(web_ip, web_port, filename))
        else:
            b64_str = base64_from_filename(filename)
            loop.run_until_complete(do_recognize(web_ip, web_port, filename, b64_str))
    except Exception as e:
        # Not a bare except: Ctrl-C must still be able to stop the test run.
        print("run_test Error: {!r}".format(e))
if __name__ == '__main__':
    # Send every file in file_dir to the server, one request after another,
    # and report the total wall-clock time.
    file_dir = "E:/code/test_models/"
    files = [os.path.join(file_dir, name) for name in os.listdir(file_dir)]
    ip = "192.168.107.118"  # "127.0.0.1" 192.168.107.118
    port = 8000
    start = time.time()
    # asyncio.get_event_loop() without a running loop is deprecated; create one.
    event_loop = asyncio.new_event_loop()
    try:
        for filename in files:
            run_test(ip, port, filename, event_loop)
    except Exception as e:
        print("__main__ Error: {!r}".format(e))
    finally:
        event_loop.close()
    end = time.time()
    print("run time is : {}s".format(end - start))
3.2 "multipart/form-data"
直接上传文件
import aiohttp
from aiohttp import formdata
import asyncio
import os
import time
import json
async def do_recognize(web_ip, web_port, file_path):
try:
codename = os.path.basename(file_path)
print('filename: {}'.format(file_path))
# file_data = {"file": open(file_path, "rb")}
file_data = formdata.FormData()
file_data.add_field('file',
open(file_path, 'rb'),
# content_type="multipart/form-data; boundary=--dd7db5e4c3bd4d5187bb978aef4d85b1",
filename=codename)
async with aiohttp.ClientSession() as session:
async with session.post(
url='http://{}:{}/mesh/recognize'.format(web_ip, web_port),
data=file_data,
# headers={'Content-Type': 'multipart/form-data; boundary=--dd7db5e4c3bd4d5187bb978aef4d85b1'}
) as resp:
respond = await resp.text()
respond = respond.replace('\\"', '"')
respond = respond[1:-1] # remove " at the begin and end
result = json.loads(respond)
data = result.get("text")
status = result.get("returnCode")
print("predict {}, res: {}".format(status, data))
except:
print("do_recognize Error {} \n".format(file_path))
def run_test(web_ip, web_port, filename, loop):
    """Upload one file via ``do_recognize`` on *loop*, printing instead of raising on failure."""
    try:
        loop.run_until_complete(do_recognize(web_ip, web_port, filename))
    except Exception as e:
        # Not a bare except: Ctrl-C must still be able to stop the test run.
        print("run_test Error: {!r}".format(e))
if __name__ == '__main__':
    # Upload every file in file_dir to the server, one request after another,
    # and report the total wall-clock time.
    file_dir = "E:/code/test_models/"
    files = [os.path.join(file_dir, name) for name in os.listdir(file_dir)]
    ip = "192.168.107.118"  # "127.0.0.1"
    port = 8000
    start = time.time()
    # asyncio.get_event_loop() without a running loop is deprecated; create one.
    event_loop = asyncio.new_event_loop()
    try:
        for filename in files:
            run_test(ip, port, filename, event_loop)
    except Exception as e:
        print("__main__ Error: {!r}".format(e))
    finally:
        event_loop.close()
    end = time.time()
    print("run time is : {}s".format(end - start))
参考链接
https://docs.aiohttp.org/en/stable/
https://blog.csdn.net/weixin_39643613/article/details/109171090
https://docs.aiohttp.org/en/stable/web.html