python并行之flask-socketio
1、服务器端
from flask import *
from flask_socketio import *
from flask_socketio import SocketIO
from nasbench_lib.nasbench_201 import NASBench201
import random
import subprocess


class Server:
    """Flask-SocketIO coordinator for distributed NAS workers.

    Waits for clients to connect and announce readiness, then asks a random
    subset of the ready clients to report their resources each round.
    """

    def __init__(self, gpu):
        self.app = Flask(__name__)
        # Huge ping timeout/interval keep long-running training clients from
        # being dropped; max_http_buffer_size is effectively unlimited so
        # large payloads can be exchanged.
        self.socketio = SocketIO(self.app,
                                 ping_timeout=3600000,
                                 ping_interval=3600000,
                                 max_http_buffer_size=int(1e32))
        self.gpu = gpu  # row index into nvidia-smi output (see get_slurm_allocated_gpus)
        self.MIN_NUM_WORKERS = 1
        self.current_round = -1  # -1 for not yet started
        self.NUM_CLIENTS_CONTACTED_PER_ROUND = 1
        self.ready_client_sids = set()
        self.nas = NASBench201()
        # Register the SocketIO event handlers.
        self.register_handles()

    def check_client_resource(self):
        """Ask a random sample of ready clients to report their resources."""
        self.client_resource = {}
        # Bug fix: cap the sample size at the number of ready clients —
        # random.sample raises ValueError when asked for more elements than
        # the population contains.
        sample_size = min(self.NUM_CLIENTS_CONTACTED_PER_ROUND,
                          len(self.ready_client_sids))
        client_sids_selected = random.sample(list(self.ready_client_sids),
                                             sample_size)
        for rid in client_sids_selected:
            emit('check_client_resource', {
                'round_number': self.current_round,
            }, room=rid)

    def register_handles(self):
        """Attach all SocketIO event handlers as closures over this server."""

        @self.socketio.on('connect')
        def handle_connect():
            print(request.sid, "connected")

        @self.socketio.on('reconnect')
        def handle_reconnect():
            print(request.sid, "reconnected")

        @self.socketio.on('disconnect')
        def handle_disconnect():
            print(request.sid, "disconnected")
            # discard() is a no-op when the sid never reported ready, so no
            # membership pre-check is needed.
            self.ready_client_sids.discard(request.sid)

        @self.socketio.on('client_wake_up')
        def handle_wake_up():
            print(f"服务器端被客户端{request.sid}唤醒.")
            emit('init')

        @self.socketio.on('client_ready')
        def handle_client_ready():
            print(f"服务器收到客户端{request.sid}准备完毕。开始check资源。")
            self.ready_client_sids.add(request.sid)
            if len(self.ready_client_sids) >= self.MIN_NUM_WORKERS:
                self.check_client_resource()
            else:
                print("没有足够的客户端连接....")

    # NOTE(review): SocketIO never calls this method — the decorated closure
    # inside register_handles handles the 'connect' event. Kept only so any
    # external caller of server.handle_connect() keeps working.
    def handle_connect(self):
        print("Client connected")

    def handle_sample(self):
        """Return 10 random architectures from the NAS-Bench-201 search space."""
        return self.nas.generate_random_for_multiview(10)

    def get_slurm_allocated_gpus(self):
        """Return the free memory (MiB) of GPU ``self.gpu``.

        NOTE(review): despite the name, this queries nvidia-smi directly,
        not SLURM — confirm the intended data source.
        """
        result = subprocess.run([
            'nvidia-smi',
            '--query-gpu=memory.total,memory.used',
            '--format=csv,nounits,noheader'
        ], stdout=subprocess.PIPE)
        output = result.stdout.decode('utf-8').strip().split('\n')
        # Each line is "total, used" for one GPU; int() tolerates the
        # leading space after the comma.
        gpu_memory = output[self.gpu].split(',')
        total_memory = int(gpu_memory[0])
        used_memory = int(gpu_memory[1])
        free_memory = total_memory - used_memory
        return free_memory

    def run(self, host='0.0.0.0', port=5000):
        """Start serving (blocks until the server shuts down)."""
        self.socketio.run(self.app, host=host, port=port)


if __name__ == '__main__':
    server = Server(0)
    server.run()
2、客户端
import socketio


class Worker:
    """Socket.IO client that registers with the coordination server and
    reacts to server-driven lifecycle events (init / resource checks)."""

    def __init__(self, server_host, server_port):
        self.sio = socketio.Client()
        self.server_url = f'http://{server_host}:{server_port}'
        self.register_handles()
        # Bug fix: do NOT call connect_to_server() here. It ends in
        # sio.wait(), which blocks forever, so the original constructor never
        # returned — and callers (e.g. GPUManager.run) invoke
        # connect_to_server() explicitly anyway, which would have connected
        # a second time.

    def on_init(self):
        """Handle the server's 'init' event: load the local model, then
        report readiness."""
        print('客户端进行初始化.')
        # load the model here
        print("客户端本地模型加载完毕.")
        # ready to be dispatched for training
        self.sio.emit('client_ready')

    def register_handles(self):
        """Attach the client-side Socket.IO event handlers."""

        @self.sio.event
        def connect():
            print('客户端请求连接...')
            self.sio.emit("client_wake_up")

        @self.sio.event
        def disconnect():
            print('Disconnected')

        # NOTE(review): python-socketio clients expose connect/disconnect/
        # connect_error — there is no 'reconnect' event, so this handler
        # will never fire; kept for compatibility.
        @self.sio.event
        def reconnect():
            print('Reconnected')

        @self.sio.on('check_client_resource')
        def on_check_client_resource(*args):
            self.sio.emit('check_client_resource_done')

        self.sio.on('init', self.on_init)

    def connect_to_server(self):
        """Connect to the server and block, processing events until
        disconnected."""
        print("Connecting to server...")
        self.sio.connect(self.server_url)
        self.sio.wait()

    def request_sample_architecture(self):
        """Ask the server for sample architectures; the reply is printed."""
        self.sio.emit('sample', callback=self.print_response)

    def request_evolve_architectures(self, architectures):
        """Ask the server to evolve *architectures*; the reply is printed."""
        self.sio.emit('evolve', architectures, callback=self.print_response)

    @staticmethod
    def print_response(data):
        print("Response from server:", data)
3、管理端
from client import Worker
import torch.multiprocessing as mp


def _client_process(host, port):
    """Process entry point: build the Worker inside the child process and
    run its blocking event loop."""
    worker = Worker(host, port)
    worker.connect_to_server()


class GPUManager(object):
    """Spawns a pool of Worker client processes against the local server."""

    def __init__(self):
        self.p_count = 10          # number of client processes to launch
        self.available_gpus = 6    # NOTE(review): currently unused
        self.port_list = [         # NOTE(review): currently unused
            78901,
            78902,
        ]

    def run_client(self):
        """Launch p_count client processes, wait up to 48h, then clean up."""
        TIMEOUT = 48 * 3600
        procs = []
        for _ in range(self.p_count):
            # Bug fix: construct the Worker inside the child process. The
            # original built it in the parent, where Worker.__init__
            # connected (and could block) before any process was started,
            # and a live socket client is not safely picklable across
            # fork/spawn.
            procs.append(mp.Process(target=_client_process,
                                    args=("127.0.0.1", 5000)))
        for p in procs:
            p.start()
        for p in procs:
            p.join(timeout=TIMEOUT)
        # Force-stop anything still running past the timeout.
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join()
        for p in procs:
            p.close()

    def run(self):
        """Run a single client in the current process (blocks)."""
        worker = Worker("127.0.0.1", 5000)
        worker.connect_to_server()


if __name__ == '__main__':
    GPUManager().run()