Python 3.8 multiprocessing: sharing NumPy arrays through shared memory
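Multiprocessing is one way to make full use of the CPU when processing data in Python. If each child process needs a lot of memory, though, it is hard to run many of them at once. If the work does not write to that memory, the child processes can share a single copy of the data. This lowers the memory overhead, so you can start more child processes and process the data faster.
Below are my notes on sharing NumPy arrays between processes in Python 3.8.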
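Main references:
Python3.8多进程之共享内存 (Python 3.8 multiprocessing: shared memory) - Zhihu (zhihu.com)
Python 多进程编程《*》:shared_memory 模块 (Python multiprocessing programming 《*》: the shared_memory module) - pudn.com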
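Shared memory exists to make communication between processes more efficient. In practice, each process should therefore create its own SharedMemory instance, and all of these instances should point to the same shared memory block. Some child processes may finish before others. When a process exits, it should call close() on its own SharedMemory instance, which detaches that instance from the block. Once no process needs the shared memory any more, the last process to exit should call unlink() on its SharedMemory instance to free the block. unlink() should be called only once across all processes.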
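The lifecycle above can be shown inside a single process. Two SharedMemory instances attach to the same block, each instance is closed, and unlink() is called exactly once. This is a minimal sketch that assumes a 4-element float64 array. The function name lifecycle_demo is hypothetical. In a real program the second instance would live in a child process.

```python
import numpy as np
from multiprocessing.shared_memory import SharedMemory


def lifecycle_demo():
    # Hypothetical demo: "parent" instance creates a block big enough for 4 float64 values
    owner = SharedMemory(create=True, size=4 * np.dtype(np.float64).itemsize)
    # Second instance attached purely by name, as a child process would do
    other = SharedMemory(name=owner.name, create=False)

    writer = np.ndarray((4,), dtype=np.float64, buffer=owner.buf)
    reader = np.ndarray((4,), dtype=np.float64, buffer=other.buf)
    writer[:] = [1.0, 2.0, 3.0, 4.0]
    result = reader.tolist()  # the write is visible through the other instance

    # Drop the ndarray views first: close() raises BufferError while an
    # array still holds a pointer into .buf
    del writer, reader
    other.close()   # detach the second instance; the block still exists
    owner.close()   # detach the owner as well
    owner.unlink()  # free the block, exactly once, after every user is done
    return result
```

Calling lifecycle_demo() returns the four values written through the first instance and read back through the second.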
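This approach uses the multiprocessing.shared_memory module. The parent process allocates a block of shared memory. Each child process then attaches to that block by its name and reads or writes the data through it.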
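The parent/child split can be sketched as follows. The assumption is that only the block's name, shape and dtype string are sent to the workers. The helper names to_shared and worker_sum are hypothetical and are not part of multiprocessing. worker_sum is an ordinary top-level function, so it could be passed to Pool.map. In the test it is called directly in the same process, where it attaches by name exactly as a child would.

```python
import numpy as np
from multiprocessing.shared_memory import SharedMemory


def to_shared(arr):
    """Hypothetical helper: copy `arr` into a new block, return the block and a picklable descriptor."""
    shm = SharedMemory(create=True, size=arr.nbytes)
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    np.copyto(view, arr)
    del view  # release the export so shm.close() will not raise later
    # Only the name, shape and dtype string need to travel to a worker
    return shm, (shm.name, arr.shape, arr.dtype.str)


def worker_sum(desc):
    """Hypothetical worker: attach by name, compute, then detach."""
    name, shape, dtype = desc
    shm = SharedMemory(name=name, create=False)
    arr = np.ndarray(shape, dtype=np.dtype(dtype), buffer=shm.buf)
    total = float(arr.sum())
    del arr      # drop the view before close()
    shm.close()  # detach only; the parent still owns the block and unlinks it
    return total
```

The parent keeps the returned shm object and calls shm.close() and then shm.unlink() after all workers have finished.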
The full code example follows:
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 9 14:06:37 2022
@author: pan
"""
import time
import numpy as np
import multiprocessing as mul
from multiprocessing.shared_memory import SharedMemory
from multiprocessing.managers import SharedMemoryManager
import tracemalloc


def main_samp_extract(pars):
    shm_tmp_name, shm_rad_name, shm_vpd_name = pars[0], pars[1], pars[2]
    shape_tmp, shape_rad, shape_vpd = pars[3], pars[4], pars[5]
    dtype_tmp, dtype_rad, dtype_vpd = pars[6], pars[7], pars[8]
    i, j = pars[9], pars[10]
    # Locate the shared memory by its name
    shm_tmp_child = SharedMemory(name=shm_tmp_name, create=False)
    shm_rad_child = SharedMemory(name=shm_rad_name, create=False)
    shm_vpd_child = SharedMemory(name=shm_vpd_name, create=False)
    # Create the np.ndarray from the buffer of the shared memory
    tmp_ts = np.ndarray(shape_tmp, dtype_tmp, buffer=shm_tmp_child.buf)
    rad_ts = np.ndarray(shape_rad, dtype_rad, buffer=shm_rad_child.buf)
    vpd_ts = np.ndarray(shape_vpd, dtype_vpd, buffer=shm_vpd_child.buf)
    # Each task writes only its own 4x4 tile, so the writes never overlap
    vpd_ts[:, i:i+4, j:j+4] = tmp_ts[:, i:i+4, j:j+4] + rad_ts[:, i:i+4, j:j+4]
    # Drop the array views, then detach from the blocks
    # (close() raises BufferError while an ndarray still references .buf)
    del tmp_ts, rad_ts, vpd_ts
    shm_tmp_child.close()
    shm_rad_child.close()
    shm_vpd_child.close()
    return 1


if __name__ == '__main__':
    tmp_ts = np.random.rand(365, 16, 16)
    rad_ts = np.random.rand(365, 16, 16)
    vpd_ts = np.zeros((365, 16, 16))
    # Start tracking memory usage (tracemalloc only sees Python allocations
    # made in this parent process, not in the children)
    tracemalloc.start()
    start_time = time.time()
    # Put the arrays above into shared memory so the child processes can use them
    with SharedMemoryManager() as smm:
        # Create shared memory of size ndarray.nbytes
        shm_tmp = smm.SharedMemory(size=tmp_ts.nbytes)
        shm_rad = smm.SharedMemory(size=rad_ts.nbytes)
        shm_vpd = smm.SharedMemory(size=vpd_ts.nbytes)
        # Create a np.ndarray using the buffer of shared memory
        shm_tmp_ts = np.ndarray(tmp_ts.shape, dtype=tmp_ts.dtype, buffer=shm_tmp.buf)
        shm_rad_ts = np.ndarray(rad_ts.shape, dtype=rad_ts.dtype, buffer=shm_rad.buf)
        shm_vpd_ts = np.ndarray(vpd_ts.shape, dtype=vpd_ts.dtype, buffer=shm_vpd.buf)
        # Copy the data into the shared memory
        np.copyto(shm_tmp_ts, tmp_ts)
        np.copyto(shm_rad_ts, rad_ts)
        np.copyto(shm_vpd_ts, vpd_ts)
        parameters = []
        for i in range(0, 16, 4):
            for j in range(0, 16, 4):
                parameters.append([shm_tmp.name, shm_rad.name, shm_vpd.name,
                                   tmp_ts.shape, rad_ts.shape, vpd_ts.shape,
                                   tmp_ts.dtype, rad_ts.dtype, vpd_ts.dtype,
                                   i, j])
        # Open a pool of 10 processes and start the computation
        pool = mul.Pool(10)
        rel = pool.map(main_samp_extract, parameters)
        pool.close()  # close the pool; it accepts no new tasks
        pool.join()   # block the main process until the workers have exited
        print(rel)
        # Copy the result out, release the views and detach, before the
        # manager unlinks the shared blocks on leaving the with-block
        vpd_ts = shm_vpd_ts.copy()
        del shm_tmp_ts, shm_rad_ts, shm_vpd_ts
        shm_tmp.close()
        shm_rad.close()
        shm_vpd.close()
    # Check memory usage
    current, peak = tracemalloc.get_traced_memory()
    print(f"Current memory usage {current/1e6}MB; Peak: {peak/1e6}MB")
    print(f'Time elapsed: {time.time()-start_time:.2f}s')
    tracemalloc.stop()
That is the overall approach.
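The script passes shared-memory names to the workers instead of the arrays themselves. Pool.map pickles each task's arguments to send them to a worker. Passing tmp_ts directly would therefore serialize the whole array once per task, while a name/shape/dtype descriptor is only a few dozen bytes. The segment name below is made up for illustration. The exact byte counts depend on the pickle protocol, so this sketch compares sizes rather than asserting fixed values.

```python
import pickle

import numpy as np

# Same shape and dtype as tmp_ts in the script above
tmp_ts = np.random.rand(365, 16, 16)

# Bytes pool.map would send per task if the array itself were the argument
array_payload = len(pickle.dumps(tmp_ts))
# Bytes it sends when only a shared-memory descriptor is the argument
# ("psm_example" is a hypothetical segment name, for illustration only)
descriptor_payload = len(pickle.dumps(("psm_example", tmp_ts.shape, tmp_ts.dtype.str)))

print(tmp_ts.nbytes, array_payload, descriptor_payload)
```

With 16 tasks, passing the array would copy roughly 16 times tmp_ts.nbytes through the pool's pipes, and each worker would hold its own deserialized copy of the array.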