Python 3.8 multiprocessing shared memory: NumPy arrays

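When processing data in Python, multiprocessing is one way to make full use of the CPU. If each child process needs a lot of memory, though, it is hard to run many of them at once. When the children only need to read the data, letting them share a single copy in memory, instead of each holding its own, cuts memory use, so more child processes can run and processing gets faster.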

The notes below record how to share NumPy arrays between processes in Python 3.8.

Main references:

Python 3.8 multiprocessing: shared memory - Zhihu (zhihu.com)

[Python multiprocessing programming 《*》: the shared_memory module - pudn.com](https://www.pudn.com/news/630c29cb88df2007aaec6b43.html)

multiprocessing.shared_memory: Shared memory for direct access across processes (Python 3.11.0 documentation)

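The point of shared memory is to make communication between processes more efficient. In practice, each process should create its own SharedMemory instance, and all of these instances should refer to the same shared memory block. Because some child processes may finish earlier than others, each process should call close() on its own SharedMemory instance before it exits, which detaches that instance from the shared memory block. Once no process needs the shared memory any more, unlink() should be called exactly once, on the SharedMemory instance in the last process to finish, to free the shared memory block itself.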
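The close()/unlink() lifecycle above can be sketched in a single process, with a second handle opened by name standing in for a child process. This is a minimal sketch, not the author's code, and `lifecycle_demo` is a hypothetical name. The ndarray views are deleted before close() because CPython refuses to close a block while an array still exports its buffer.

```python
import numpy as np
from multiprocessing.shared_memory import SharedMemory


def lifecycle_demo():
    """Create, share, close and unlink one block (hypothetical demo helper)."""
    # Owner (normally the parent): a block large enough for 4 float64 values
    owner = SharedMemory(create=True, size=4 * np.dtype(np.float64).itemsize)
    a = np.ndarray((4,), dtype=np.float64, buffer=owner.buf)
    a[:] = [1.0, 2.0, 3.0, 4.0]

    # Attacher (normally a child process): open the same block by its name
    worker = SharedMemory(name=owner.name, create=False)
    b = np.ndarray((4,), dtype=np.float64, buffer=worker.buf)
    b[0] = 10.0              # write through the second handle
    seen_by_owner = a[0]     # the owner's view sees the same memory

    del b                    # drop the view first, or close() raises BufferError
    worker.close()           # detach this handle; the block itself still exists

    result = a.copy()        # copy the data out before the block is freed
    del a
    owner.close()
    owner.unlink()           # free the block; call once, after every user is done
    return seen_by_owner, result
```

Calling `lifecycle_demo()` returns the value the owner saw after the attacher's write, plus an ordinary (non-shared) copy of the final data. In a real program the attach, write and close steps run in the child, and only the owner calls unlink().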

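With the multiprocessing.shared_memory module, the basic idea is to allocate a shared memory block in the parent process, then have each child process locate the block by its name and read or modify it.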
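Because a child only needs the block's name plus the array's shape and dtype to rebuild the array, the pattern can be wrapped in two small helpers. The names `share_array` and `attach_array` are hypothetical, and the sketch assumes a non-empty array (SharedMemory rejects a size of 0).

```python
import numpy as np
from multiprocessing.shared_memory import SharedMemory


def share_array(arr):
    """Parent side: copy arr into a new block, return the block and a picklable descriptor."""
    shm = SharedMemory(create=True, size=arr.nbytes)
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    np.copyto(view, arr)
    del view  # no lingering export, so shm.close() works later
    # Only this small tuple has to be sent to the workers
    return shm, (shm.name, arr.shape, arr.dtype)


def attach_array(desc):
    """Child side: reopen the block by name and wrap it as an ndarray (no copy)."""
    name, shape, dtype = desc
    shm = SharedMemory(name=name, create=False)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)
```

The caller keeps the returned SharedMemory objects alive while the array is in use, deletes the array before calling close(), and calls unlink() once on the parent's handle.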

The full code example:


```python
# -*- coding: utf-8 -*-
"""
Created on Wed Nov  9 14:06:37 2022

@author: pan
"""

import time
import tracemalloc
import multiprocessing as mul
from multiprocessing.shared_memory import SharedMemory
from multiprocessing.managers import SharedMemoryManager

import numpy as np


def main_samp_extract(pars):
    shm_tmp_name, shm_rad_name, shm_vpd_name = pars[0], pars[1], pars[2]
    shape_tmp, shape_rad, shape_vpd = pars[3], pars[4], pars[5]
    dtype_tmp, dtype_rad, dtype_vpd = pars[6], pars[7], pars[8]
    i, j = pars[9], pars[10]
    # Attach to the existing shared memory blocks by name (no new allocation)
    shm_tmp_child = SharedMemory(name=shm_tmp_name, create=False)
    shm_rad_child = SharedMemory(name=shm_rad_name, create=False)
    shm_vpd_child = SharedMemory(name=shm_vpd_name, create=False)
    # Wrap each block's buffer in an np.ndarray (a view, not a copy)
    tmp_ts = np.ndarray(shape_tmp, dtype_tmp, buffer=shm_tmp_child.buf)
    rad_ts = np.ndarray(shape_rad, dtype_rad, buffer=shm_rad_child.buf)
    vpd_ts = np.ndarray(shape_vpd, dtype_vpd, buffer=shm_vpd_child.buf)
    # Each task writes only its own 4x4 tile, so no two tasks touch the same cells
    vpd_ts[:, i:i+4, j:j+4] = tmp_ts[:, i:i+4, j:j+4] + rad_ts[:, i:i+4, j:j+4]
    # Drop the views first: close() raises BufferError while they still exist
    del tmp_ts, rad_ts, vpd_ts
    # Detach this process from the blocks; unlinking is left to the parent's manager
    shm_tmp_child.close()
    shm_rad_child.close()
    shm_vpd_child.close()
    return 1


if __name__ == '__main__':
    tmp_ts = np.random.rand(365, 16, 16)
    rad_ts = np.random.rand(365, 16, 16)
    vpd_ts = np.zeros((365, 16, 16))
    # Start tracking memory usage (tracemalloc only traces this parent process)
    tracemalloc.start()
    start_time = time.time()
    # Put the arrays above into shared memory so the child processes can use them
    with SharedMemoryManager() as smm:
        # Create shared memory blocks of size ndarray.nbytes
        shm_tmp = smm.SharedMemory(size=tmp_ts.nbytes)
        shm_rad = smm.SharedMemory(size=rad_ts.nbytes)
        shm_vpd = smm.SharedMemory(size=vpd_ts.nbytes)
        # Create np.ndarray views backed by the shared memory buffers
        shm_tmp_ts = np.ndarray(tmp_ts.shape, dtype=tmp_ts.dtype, buffer=shm_tmp.buf)
        shm_rad_ts = np.ndarray(rad_ts.shape, dtype=rad_ts.dtype, buffer=shm_rad.buf)
        shm_vpd_ts = np.ndarray(vpd_ts.shape, dtype=vpd_ts.dtype, buffer=shm_vpd.buf)
        # Copy the data into the shared memory
        np.copyto(shm_tmp_ts, tmp_ts)
        np.copyto(shm_rad_ts, rad_ts)
        np.copyto(shm_vpd_ts, vpd_ts)

        # One task per 4x4 tile: only names, shapes, dtypes and offsets are pickled
        parameters = []
        for i in range(0, 16, 4):
            for j in range(0, 16, 4):
                parameters.append([shm_tmp.name, shm_rad.name, shm_vpd.name,
                                   tmp_ts.shape, rad_ts.shape, vpd_ts.shape,
                                   tmp_ts.dtype, rad_ts.dtype, vpd_ts.dtype,
                                   i, j])

        # Open a pool of 10 processes and run the tasks
        pool = mul.Pool(10)
        rel = pool.map(main_samp_extract, parameters)
        pool.close()  # stop accepting new tasks
        pool.join()   # block until all worker processes have exited
        print(rel)

        # Copy the result out: leaving the with-block unlinks the shared memory
        vpd_ts[:] = shm_vpd_ts
        # Release the parent's views and handles too
        del shm_tmp_ts, shm_rad_ts, shm_vpd_ts
        shm_tmp.close()
        shm_rad.close()
        shm_vpd.close()

    # Check memory usage
    current, peak = tracemalloc.get_traced_memory()
    print(f"Current memory usage {current/1e6}MB; Peak: {peak/1e6}MB")
    print(f'Time elapsed: {time.time()-start_time:.2f}s')
    tracemalloc.stop()
    # Sanity check: every tile was written by some worker
    print(np.allclose(vpd_ts, tmp_ts + rad_ts))  # expected: True
```

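That is the overall approach: create the shared blocks in the parent, pass only their names, shapes and dtypes to the workers, and rebuild the arrays inside each worker.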
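The introduction assumes the children do not write, yet the example writes to vpd_ts from several processes at once. That is safe without a lock only because the 4x4 tiles are disjoint. A minimal sketch, with hypothetical helper names, that checks the (i, j) tiles built by the example's double loop cover every cell of the 16x16 grid exactly once:

```python
import numpy as np


def tile_origins(n=16, step=4):
    # Same (i, j) loop as the example: top-left corner of each step x step tile
    return [(i, j) for i in range(0, n, step) for j in range(0, n, step)]


def tile_coverage(n=16, step=4):
    """Count how many tiles write each (row, col) cell of an n x n grid."""
    counts = np.zeros((n, n), dtype=int)
    for i, j in tile_origins(n, step):
        counts[i:i+step, j:j+step] += 1
    return counts
```

If any cell had a count above 1, two workers could write it concurrently and the tasks would need a lock or a different partition.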

posted @ 2022-11-12 19:31  岁时