松鼠的博客

导航

http上传协议之文件流实现,轻松支持大文件上传

最近在公司进行业务开发时遇到了一些问题,当需要上传一个较大的文件时,经常会遇到内存被大量占用的情况。公司之前使用的web框架是一个老前辈实现的。在实现multipart/form-data类型的post请求解析时, 是将post请求体一次性读到内存中再做解析的,从而导致内存占用过大。而我之前为公司开发的框架是基于apistar这个asgi框架的,而apistar在解析mutilpart时使用的时flask作者编写的flask和django在对待multipart报文解析使用的方案基本是一致的,通过持续的解析请求体,将解析出来的文件内容放入一个工厂类创建的类文件对象中,工厂类在django中返回uploader的子类,在flask中叫作stream_factory。可以使用基于内存的,也可以使用基于临时文件的。

但是apistar作者在借用werkzeug的FormDataParser解析时,却直接将一个BytesIO传入了!而BytesIO中存放的是全量请求体,这势必会全部存在于内存中!那么带来的问题就是,当上传大文件时,内存会被撑爆!代码如下:

   class MultiPartCodec(BaseCodec):
    media_type = 'multipart/form-data'

    def decode(self, bytestring, headers, **options):
        try:
            content_length = max(0, int(headers['content-length']))
        except (KeyError, ValueError, TypeError):
            content_length = None

        try:
            mime_type, mime_options = parse_options_header(headers['content-type'])
        except KeyError:
            mime_type, mime_options = '', {}

        body_file = BytesIO(bytestring)
        parser = FormDataParser()
        stream, form, files = parser.parse(body_file, mime_type, content_length, mime_options)
        return ImmutableMultiDict(chain(form.items(multi=True), files.items(multi=True)))

其实想必这也是不得已的事情,因为apistar支持ASGI协议,这就导致了每次请求IO都是异步的,异步read接口和同步接口调用方式肯定不一样,所以作者想偷懒不自己实现一套异步解析方案,那么只能么做。

作者想偷懒我可以理解,但是公司对我的要求让我感觉鸭梨山大,之前基于s3文件的上传服务是由我开发的,使用的框架也是我依赖apistar开发的star_builder,现在公司要求废弃掉公司之前的文件上传服务(也就是基于老前辈web框架开发的那个),将所有接口全部转移到我开发的服务上来。那么势必要求我一并解决掉大文件上传的问题。所以没有办法,只能为apistar的作者造个轮子接上先用着了。

在我简单了解了multipart/form-data协议之后,实现了一个FileStream类和File类,每个类都返回可异步迭代对象,FileStream迭代File对象,File对象迭代数据,迭代过程实时解析请求体,实时发现文件对象,实时处理得到的文件数据。以这种方式处理上传的文件,对内存不会产生任何压力。

FIleStream的实现如下:

class FileStream(object):

    def __init__(self, receive, boundary):
        self.receive = receive
        self.boundary = boundary
        self.body = b""
        self.closed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        return await File.from_boundary(self, self.receive, self.boundary)

FileStream支持异步迭代,每次返回一个File对象。同时FIleStream存储已读但未返回到应用层的请求体数据。

File的实现如下:

class File(object):
    mime_type_regex = re.compile(b"Content-Type: (.*)")
    disposition_regex = re.compile(
        rb"Content-Disposition: form-data;"
        rb"(?: name=\"(?P<name>[^;]*?)\")?"
        rb"(?:; filename\*?=\"?"
        rb"(?:(?P<enc>.+?)'"
        rb"(?P<lang>\w*)')?"
        rb"(?P<filename>[^\"]*)\"?)?")

    def __init__(self, stream, receive, boundary, name, filename, mimetype):
        self.mimetype = mimetype
        self.receive = receive
        self.filename = filename
        self.name = name
        self.stream = stream
        self.tmpboundary = b"\r\n--" + boundary
        self.boundary_len = len(self.tmpboundary)
        self._last = b""
        self._size = 0
        self.body_iter = self._iter_content()

    def __aiter__(self):
        return self.body_iter

    def __str__(self):
        return f"<{self.__class__.__name__} " \
               f"name={self.name} " \
               f"filename={self.filename} >"

    __repr__ = __str__

    def iter_content(self):
        return self.body_iter

    async def _iter_content(self):
        stream = self.stream
        while True:
            # 如果存在read过程中剩下的,则直接返回
            if self._last:
                yield self._last
                continue

            index = self.stream.body.find(self.tmpboundary)
            if index != -1:
                # 找到分隔线,返回分隔线前的数据
                # 并将分隔及分隔线后的数据返回给stream
                read, stream.body = stream.body[:index], stream.body[index:]
                self._size += len(read)
                yield read
                if self._last:
                    yield self._last
                break
            else:
                if self.stream.closed:
                    raise RuntimeError("Uncomplete content!")
                # 若没有找到分隔线,为了防止分隔线被读取了一半
                # 选择只返回少于分隔线长度的部分body
                read = stream.body[:-self.boundary_len]
                stream.body = stream.body[-self.boundary_len:]
                self._size += len(read)
                yield read
                await self.get_message(self.receive, stream)

    async def read(self, size=10240):
        read = b""
        assert size > 0, (999, "Read size must > 0")
        while len(read) < size:
            try:
                buffer = await self.body_iter.asend(None)
            except StopAsyncIteration:
                return read
            read = read + buffer
            read, self._last = read[:size], read[size:]
        return read

    @staticmethod
    async def get_message(receive, stream):
        message = await receive()

        if not message['type'] == 'http.request':
            raise RuntimeError(
                f"Unexpected ASGI message type: {message['type']}.")

        if not message.get('more_body', False):
            stream.closed = True
        stream.body += message.get("body", b"")

    def tell(self):
        return self._size

    @classmethod
    async def from_boundary(cls, stream, receive, boundary):
        tmp_boundary = b"--" + boundary
        while not stream.closed:
            await cls.get_message(receive, stream)

            if b"\r\n\r\n" in stream.body and tmp_boundary in stream.body or \
                    stream.closed:
                break

        return cls(stream, receive, boundary,
                   *cls.parse_headers(stream, tmp_boundary))

    @classmethod
    def parse_headers(cls, stream, tmp_boundary):
        end_boundary = tmp_boundary + b"--"
        body = stream.body
        index = body.find(tmp_boundary)
        if index == body.find(end_boundary):
            raise StopAsyncIteration
        body = body[index + len(tmp_boundary):]
        header_str = body[:body.find(b"\r\n\r\n")]
        body = body[body.find(b"\r\n\r\n") + 4:]
        groups = cls.disposition_regex.search(header_str).groupdict()
        filename = groups["filename"] and unquote(groups["filename"].decode())
        if groups["enc"]:
            filename = filename.encode().decode(groups["enc"].decode())
        name = groups["name"].decode()

        mth = cls.mime_type_regex.search(header_str)
        mimetype = mth and mth.group(1).decode()
        stream.body = body
        assert name, "FileStream iterated without File consumed. "
        return name, filename, mimetype

File实例也是一个异步可迭代对象,每次迭代从receive中实时获取数据,同时File还支持异步read,但read本质上也是对File对象的迭代。

那么正确的使用姿势是怎样的呢?

下面是star_builder构建的项目中关于FileStream在一次请求中action的demo实现。

@post("/test_upload")
    async def up(stream: FileStream):
        async for file in stream:
            if file.filename:
                with open(file.filename, "wb") as f:
                    async for chuck in file:
                        f.write(chuck)
            else:
                # 没有filename的是其它类型的form参数
                arg = await file.read()
                print(f"Form参数:{file.name}={arg.decode()}")

使用方法非常简单,不会生成临时文件,也不会占用内存来存储。实时异步从socket中读取数据,非要说有什么缺点的话,就是不全部迭代完的话,是无法知道这一次请求中一共上传了几个文件的。如果需要提前知道的话,可以通过前端配合通过url传入params参数来获取文件相关属性信息。

这种实时从socket读取的实现方案,应该是基于http协议性能最好的文件上传方案。欢迎评论区发表意见和建议。

 

参考文章:http://blog.ncmem.com/wordpress/2023/10/24/http%e4%b8%8a%e4%bc%a0%e5%8d%8f%e8%ae%ae%e4%b9%8b%e6%96%87%e4%bb%b6%e6%b5%81%e5%ae%9e%e7%8e%b0%ef%bc%8c%e8%bd%bb%e6%9d%be%e6%94%af%e6%8c%81%e5%a4%a7%e6%96%87%e4%bb%b6%e4%b8%8a%e4%bc%a0/

欢迎入群一起讨论

 

 

posted on 2023-10-24 13:48  Xproer-松鼠  阅读(568)  评论(0编辑  收藏  举报