python代码打包为http服务端接口 (aiohttp)

目录

一、需求
二、服务器端
三、客户端


一、需求

python端代码接受一个待处理的模型文件路径,对模型进行预测并得到相应结果。由于无法直接将python代码转换为C++,这里采用aiohttp库将python代码封装成http服务端接口,支持客户端以传入文件路径、传入文件的base64编码以及上传文件三种方式进行请求,便于其他语言调用。同时编写了python客户端代码进行测试,也可以利用postman工具进行测试。aiohttp是一个异步的库,具体介绍可以参照官网,里面介绍得很详细。

二、服务器端

这里只对"application/json"和"multipart/form-data"两种类型进行处理。

from aiohttp import web
from inference_class import InferenceClass
import json
import asyncio
import os
import base64


def image_from_base64(base64_utf8):
    """Decode base64 text into the raw bytes it represents.

    Despite its name nothing here is image specific: the server applies it to
    the ``obj`` field of JSON requests (the uploaded file's content).
    """
    raw_text = base64_utf8.encode('utf-8')
    return base64.decodebytes(raw_text)


class MeshWebServer(object):
    """aiohttp server exposing ``InferenceClass`` at ``POST /mesh/recognize``.

    Supported request bodies:

    * ``application/json``: ``file_path`` (a path the server can read), or
      ``obj`` (base64-encoded file content) with an optional ``filename``;
    * ``multipart/form-data``: the first part is taken as the uploaded file.

    Uploaded content is written into ``cache_dir`` (emptied before each write)
    and the resulting path is passed to ``InferenceClass.inference``.
    """

    def __init__(self, max_request=1, cache_dir=None):
        """Build the web application and the inference engine.

        :param max_request: accepted for backward compatibility; not used.
            Requests are serialized by an internal lock instead.
        :param cache_dir: directory for uploaded files; defaults to
            ``<cwd>/cache``.
        """
        self._app = web.Application()
        self._engine = InferenceClass()
        self._max_request = max_request
        # Created lazily from inside the running loop (see _get_lock): on
        # Python < 3.10 an asyncio.Lock built here binds to the loop returned
        # by get_event_loop(), which may not be the loop web.run_app() runs.
        self._lock = None
        if cache_dir is not None:
            self._cache_dir = cache_dir
        else:
            self._cache_dir = os.path.join(os.getcwd(), 'cache')

    def run(self, port):
        """Register the route and serve on *port*; blocks until shutdown."""
        self._app.add_routes([
            web.post('/mesh/recognize', self.__on_recognize)
        ])
        web.run_app(self._app, port=port)

    @staticmethod
    def _safe_filename(name, default="temp.obj"):
        """Reduce a client-supplied file name to its final path component.

        The result is joined onto the cache directory, so directory parts
        (``../``, absolute paths) must be dropped to prevent writing outside it.
        Backslashes are treated as separators too, so Windows-style names are
        handled on any OS. Empty, ``.`` and ``..`` names fall back to *default*.
        """
        if not name:
            return default
        base = os.path.basename(str(name).replace("\\", "/"))
        if base in ("", ".", ".."):
            return default
        return base

    def _reset_cache_dir(self):
        """Create the cache directory if missing and delete the files in it."""
        os.makedirs(self._cache_dir, exist_ok=True)
        for entry in os.listdir(self._cache_dir):
            path = os.path.join(self._cache_dir, entry)
            if os.path.isfile(path):
                os.remove(path)

    def _get_lock(self):
        """Return the request lock, creating it on first use inside the loop."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _predict(self, file, filename):
        """Run inference on *file* and build the success response dict."""
        # NOTE(review): inference() is synchronous, so it blocks the event loop
        # while it runs.
        predict_class = self._engine.inference(file)
        # The [0][1][-1] indexing follows InferenceClass's result structure,
        # which is not visible here.
        return dict(text=predict_class[0][1][-1], returnCode="Successed!", filename=filename)

    async def _recognize_json(self, request):
        """Handle an ``application/json`` request; always returns a response dict."""
        filename = ""
        try:
            data = await request.json()
            obj_data = image_from_base64(data['obj']) if 'obj' in data else None
            filename = self._safe_filename(data.get('filename'))
            file_path = data.get("file_path", "")
            if file_path:
                filename = os.path.basename(file_path)

            # The cache dir is shared by all requests, so clearing it, writing the
            # upload and running inference must not interleave with another request.
            async with self._get_lock():
                if os.path.isfile(file_path):
                    file = file_path
                elif obj_data is not None:
                    self._reset_cache_dir()
                    file = os.path.join(self._cache_dir, filename)
                    with open(file, "wb") as f:
                        f.write(obj_data)
                else:
                    file = ""
                return self._predict(file, filename)
        except Exception as e:
            return dict(text='', returnCode="Failed", filename=filename, returnMsg=repr(e))

    async def _recognize_multipart(self, request):
        """Handle a ``multipart/form-data`` upload; always returns a response dict."""
        filename = ""
        try:
            print("headers: ", request.headers)
            reader = await request.multipart()
            field = await reader.next()
            if field is None:
                raise ValueError("multipart body contains no parts")
            filename = self._safe_filename(field.filename)

            # Held across the chunked read: without it another request could empty
            # the cache dir (deleting this file) between read_chunk() awaits.
            async with self._get_lock():
                self._reset_cache_dir()
                file = os.path.join(self._cache_dir, filename)
                with open(file, 'wb') as f:
                    while True:
                        chunk = await field.read_chunk()  # 8192 bytes by default
                        if not chunk:
                            break
                        f.write(chunk)
                # For small uploads, `await request.post()` and reading
                # data["file"].file in one go is a simpler alternative.
                return self._predict(file, filename)
        except Exception as e:
            print(e)
            return dict(text='', returnCode="Failed", filename=filename, returnMsg=repr(e))

    async def __on_recognize(self, request):
        """Dispatch on Content-Type and return the prediction as JSON."""
        content_type = request.content_type
        if content_type == "application/json":
            respond = await self._recognize_json(request)
        elif content_type == "multipart/form-data":
            respond = await self._recognize_multipart(request)
        else:
            respond = dict(text="Unknown content type, just support application/json and multipart/form-data",
                           returnCode="Failed!", filename="")
        print("---** predict is {} **---".format(respond["returnCode"]))
        # NOTE: the dict is serialized twice (json.dumps, then json_response), so
        # the body is a JSON string whose content is itself JSON. Kept because
        # existing clients unwrap exactly this format; switching to
        # web.json_response(respond) would be a wire-format change.
        return web.json_response(json.dumps(respond))

三、客户端

3.1 "application/json"

可以通过json方式传递文件路径(服务器能直接访问该文件时,例如本机127.0.0.1)或者文件内容的base64编码

import aiohttp
import asyncio
import base64
import os
import json
import time


def base64_from_filename(filename):
    """Return the base64 encoding of the file's bytes as a text string."""
    with open(filename, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode('utf-8')


def parse_respond(body):
    """Decode the server's reply body into a dict.

    The server runs ``json.dumps`` on its result and then wraps that string in
    ``web.json_response``, so the body is a JSON string literal whose content is
    the real JSON object. Decoding twice handles every escape correctly
    (quotes, backslashes, ``\\uXXXX``), unlike replacing ``\\"`` and stripping
    the outer quotes by hand. A plain JSON object body is accepted as well.
    """
    result = json.loads(body)
    if isinstance(result, str):
        result = json.loads(result)
    return result


async def do_recognize(web_ip, web_port, file_path, obj_data=None):
    """Send one JSON recognition request and print the prediction.

    For 127.0.0.1, or when no *obj_data* is given, only the path is sent, so the
    server reads the file itself. Otherwise the base64 content in *obj_data* is
    sent. Errors are printed, not raised.
    """
    try:
        codename = os.path.basename(file_path)
        # Sending obj=None would always fail server-side (it base64-decodes obj).
        if web_ip == "127.0.0.1" or obj_data is None:
            request = dict(file_path=file_path, filename=codename)
        else:
            request = dict(obj=obj_data, filename=codename)
        print('filename: {}'.format(file_path))
        async with aiohttp.ClientSession() as session:
            async with session.post(url='http://{}:{}/mesh/recognize'.format(web_ip, web_port),
                                    data=json.dumps(request),
                                    headers={'Content-Type': 'application/json; charset=utf-8'}) as resp:
                result = parse_respond(await resp.text())
                data = result.get("text")
                status = result.get("returnCode")
                print("predict {}, res: {}".format(status, data))
    except Exception as e:
        print("do_recognize Error {} {!r}\n".format(file_path, e))


def run_test(web_ip, web_port, filename, loop):
    """Recognize one file on *loop*: path only for 127.0.0.1, base64 content otherwise.

    Errors are printed, not raised. Catching Exception instead of using a bare
    except lets Ctrl+C (KeyboardInterrupt) still stop the run.
    """
    try:
        if web_ip == "127.0.0.1":
            obj_data = None
        else:
            obj_data = base64_from_filename(filename)
        loop.run_until_complete(do_recognize(web_ip, web_port, filename, obj_data))
    except Exception as e:
        print("run_test Error: {!r}".format(e))


if __name__ == '__main__':
    # Send every regular file in file_dir to the server, one request at a time.
    file_dir = "E:/code/test_models/"
    files = [os.path.join(file_dir, name) for name in os.listdir(file_dir)]
    files = [path for path in files if os.path.isfile(path)]  # skip sub-directories
    ip = "192.168.107.118"  # "127.0.0.1" 192.168.107.118
    port = 8000
    start = time.time()
    # new_event_loop() rather than get_event_loop(): the latter is deprecated
    # when no loop is running.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        for filename in files:
            run_test(ip, port, filename, event_loop)
    except Exception as e:
        print("__main__ Error: {!r}".format(e))
    finally:
        event_loop.close()
    end = time.time()
    print("run time is : {}s".format(end - start))


3.2 "multipart/form-data"

直接上传文件

import aiohttp
from aiohttp import formdata
import asyncio
import os
import time
import json


def parse_respond(body):
    """Decode the server's reply body into a dict.

    The server runs ``json.dumps`` on its result and then wraps that string in
    ``web.json_response``, so the body is a JSON string literal whose content is
    the real JSON object. Decoding twice handles every escape correctly
    (quotes, backslashes, ``\\uXXXX``), unlike replacing ``\\"`` and stripping
    the outer quotes by hand. A plain JSON object body is accepted as well.
    """
    result = json.loads(body)
    if isinstance(result, str):
        result = json.loads(result)
    return result


async def do_recognize(web_ip, web_port, file_path):
    """Upload *file_path* as multipart/form-data and print the prediction.

    Errors are printed, not raised.
    """
    try:
        codename = os.path.basename(file_path)
        print('filename: {}'.format(file_path))
        # `with` guarantees the file is closed even if the request is never sent.
        with open(file_path, 'rb') as file_obj:
            file_data = formdata.FormData()
            # No manual Content-Type header: aiohttp derives it (including the
            # multipart boundary) from the FormData, and a hand-written boundary
            # would not match the body.
            file_data.add_field('file', file_obj, filename=codename)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                        url='http://{}:{}/mesh/recognize'.format(web_ip, web_port),
                        data=file_data,
                ) as resp:
                    result = parse_respond(await resp.text())
                    data = result.get("text")
                    status = result.get("returnCode")
                    print("predict {}, res: {}".format(status, data))
    except Exception as e:
        print("do_recognize Error {} {!r}\n".format(file_path, e))


def run_test(web_ip, web_port, filename, loop):
    """Run one upload request to completion on *loop*; print errors instead of raising.

    Catching Exception instead of using a bare except lets Ctrl+C
    (KeyboardInterrupt) still stop the run.
    """
    try:
        loop.run_until_complete(do_recognize(web_ip, web_port, filename))
    except Exception as e:
        print("run_test Error: {!r}".format(e))


if __name__ == '__main__':
    # Upload every regular file in file_dir to the server, one request at a time.
    file_dir = "E:/code/test_models/"
    files = [os.path.join(file_dir, name) for name in os.listdir(file_dir)]
    files = [path for path in files if os.path.isfile(path)]  # skip sub-directories
    ip = "192.168.107.118"  # "127.0.0.1"
    port = 8000
    start = time.time()
    # new_event_loop() rather than get_event_loop(): the latter is deprecated
    # when no loop is running.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        for filename in files:
            run_test(ip, port, filename, event_loop)
    except Exception as e:
        print("__main__ Error: {!r}".format(e))
    finally:
        event_loop.close()
    end = time.time()
    print("run time is : {}s".format(end - start))


参考链接

https://docs.aiohttp.org/en/stable/
https://blog.csdn.net/weixin_39643613/article/details/109171090
https://docs.aiohttp.org/en/stable/web.html

posted @ 2021-03-25 19:37  半夜打老虎  阅读(2459)  评论(0)  编辑  收藏  举报