#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :IoD_data_analysis_tool
@File    :producer_consumer.py
@IDE     :PyCharm
@Author  :rengengchen
@Time    :2022/8/5 11:53
"""
import multiprocessing
from typing import Iterable, Callable

from tqdm import tqdm


class Stop:
    """Sentinel object passed through the queues to signal "no more items"."""


class AbstractPCConcurrencySystem:
    """Multiprocess producer -> consumer -> callback pipeline.

    Subclasses implement :meth:`produce` and :meth:`consume` and may override
    :meth:`callback`.  :meth:`run` starts ``num_producer`` producer processes,
    ``num_consumer`` consumer processes and ``num_callback`` callback
    processes, connected by ``task_queue`` -> ``result_queue`` ->
    ``callback_queue``.  Results are read back with :meth:`get_result`.

    End of stream is signalled with :class:`Stop` markers:

    * every producer puts ``num_consumer`` markers into ``task_queue`` and
      every consumer finishes after receiving ``max(num_producer, 1)`` of them;
    * every consumer puts ``max(num_callback, 1)`` markers into
      ``result_queue`` and every callback finishes after receiving
      ``num_consumer`` of them.

    The number of markers sent matches the number awaited, so every worker
    terminates and, when no worker fails, every item queued ahead of the
    markers is processed.

    TODO: keep handles to the started processes (e.g. so they can be joined).
    TODO: validate / tune the number of processes.
    """

    def __init__(self,
                 num_producer: int = 1,
                 num_consumer: int = 1,
                 num_callback: int = 0,
                 len_task_queue: int = 0,
                 len_result_queue: int = 0,
                 len_callback_queue: int = 0,
                 producer_lock=None,
                 consumer_lock=None,
                 callback_lock=None,
                 meta=None,
                 enable_progressbar=False,
                 num_total_result=None):
        """Create the queues and locks; no process is started here.

        @param num_producer: number of producer processes started by ``run``;
            each one calls :meth:`produce` on its own.
        @param num_consumer: number of consumer processes.
        @param num_callback: number of callback processes; forced to at least
            1 when ``enable_progressbar`` is true.
        @param len_task_queue: maxsize of ``task_queue`` (<= 0: unbounded).
        @param len_result_queue: maxsize of ``result_queue`` (<= 0: unbounded).
        @param len_callback_queue: maxsize of ``callback_queue`` (<= 0: unbounded).
        @param producer_lock: lock stored as ``self.producer_lock``; a new
            ``multiprocessing.Lock`` is created when omitted.  Not acquired by
            this base class.
        @param consumer_lock: same as ``producer_lock``, for consumers.
        @param callback_lock: same as ``producer_lock``, for callbacks; only
            created when there is at least one callback process.
        @param meta: arbitrary user data stored as ``self.meta``.
        @param enable_progressbar: show a ``tqdm`` bar in the callback
            process(es).
        @param num_total_result: ``total`` passed to the ``tqdm`` bar.
        """
        self.task_queue = multiprocessing.Queue(len_task_queue)
        self.num_producer = num_producer
        self.num_consumer = num_consumer
        self.num_callback = num_callback
        self.producer_lock = producer_lock if producer_lock is not None else multiprocessing.Lock()
        self.consumer_lock = consumer_lock if consumer_lock is not None else multiprocessing.Lock()
        self.meta = meta
        self.enable_progressbar = enable_progressbar
        # The progress bar lives in a callback process, so one must exist.
        if enable_progressbar and self.num_callback == 0:
            self.num_callback = 1
        self.result_queue = multiprocessing.Queue(len_result_queue)
        self.num_total_result = num_total_result
        # Always define these so get_result()/close() work without callbacks.
        self.callback_lock = None
        self.callback_queue = None
        if self.num_callback:
            self.callback_lock = callback_lock if callback_lock is not None else multiprocessing.Lock()
            self.callback_queue = multiprocessing.Queue(len_callback_queue)

    def get_result(self):
        """Block until one result is available and return it.

        Reads ``callback_queue`` when callback processes are configured.
        Otherwise it reads ``result_queue`` directly, which also carries the
        consumers' :class:`Stop` markers (one per consumer).
        """
        queue = self.callback_queue if self.callback_queue is not None else self.result_queue
        return queue.get()

    def produce(self):
        """Return the source of tasks for one producer process.

        Must return either an iterable of tasks, or a zero-argument callable
        that returns one task per call and a :class:`Stop` instance when
        exhausted.  Any other return value (e.g. a ``Stop``) means "no tasks".
        """
        raise NotImplementedError

    def consume(self, consumer_params):
        """Process one task and return its result.

        @return: task result or Stop()
        """
        raise NotImplementedError

    def callback(self, result):
        """Post-process one consumer result; the default returns it unchanged."""
        return result

    def _produce(self):
        """Producer process body: enqueue all tasks, then ``num_consumer`` Stops.

        The Stop markers are sent in ``finally`` so consumers are released even
        when ``produce()`` (or the callable it returned) raises, or returns
        something that is neither iterable nor callable.
        """
        try:
            producer = self.produce()
            if isinstance(producer, Iterable):
                for params in producer:
                    self.task_queue.put(params, block=True)
            elif isinstance(producer, Callable):
                while True:
                    task = producer()
                    if isinstance(task, Stop):
                        break
                    self.task_queue.put(task, block=True)
        finally:
            stop = Stop()
            for _ in range(self.num_consumer):
                self.task_queue.put(stop, block=True)

    def _consume(self):
        """Consumer process body: ``consume`` tasks until every producer is done.

        Each producer sends ``num_consumer`` Stops, so a consumer waits for
        ``num_producer`` of them; quitting on the first one could leave tasks
        from a slower producer unprocessed.  ``max(..., 1)`` keeps the
        single-Stop behaviour when no producer processes are configured.
        """
        remaining_stops = max(self.num_producer, 1)
        try:
            while remaining_stops > 0:
                consumer_params = self.task_queue.get(block=True)
                if isinstance(consumer_params, Stop):
                    remaining_stops -= 1
                    continue
                self.result_queue.put(self.consume(consumer_params))
        finally:
            # One Stop per callback process (each waits for num_consumer of
            # them); sent even on error so callbacks do not hang.
            stop = Stop()
            for _ in range(max(self.num_callback, 1)):
                self.result_queue.put(stop)

    def _callback(self):
        """Callback process body: apply ``callback`` to results until all consumers finish.

        NOTE: with several callback processes and ``enable_progressbar``, each
        process shows its own bar.
        """
        bar = tqdm(total=self.num_total_result) if self.enable_progressbar else None
        remaining_stops = self.num_consumer
        try:
            while remaining_stops > 0:
                result = self.result_queue.get(block=True)
                if isinstance(result, Stop):
                    remaining_stops -= 1
                    continue
                self.callback_queue.put(self.callback(result))
                if bar is not None:
                    bar.update(1)
        finally:
            if bar is not None:
                bar.close()

    def run(self):
        """Start all producer, consumer and callback processes; return ``self``.

        Process handles are not kept (see the class TODO), so this does not
        wait for the processes to finish.
        """
        # Create and start the producers.
        for i in range(self.num_producer):
            multiprocessing.Process(target=self._produce, name=f'producer_{i}').start()
        # Create and start the consumers.
        for i in range(self.num_consumer):
            multiprocessing.Process(target=self._consume, name=f'consumer_{i}').start()
        # Handle results (zero iterations when there are no callbacks).
        for i in range(self.num_callback):
            multiprocessing.Process(target=self._callback, name=f'callback_{i}').start()
        return self

    def close(self):
        """Close every queue created by ``__init__``."""
        for queue in (self.task_queue, self.result_queue, self.callback_queue):
            if queue is not None:
                queue.close()