util/lib/analysis_package/code_template/concurrency/producer_consumer.py

128 lines
4.1 KiB
Python
Raw Normal View History

2024-05-12 12:18:24 +00:00
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project IoD_data_analysis_tool
@File producer_consumer.py
@IDE PyCharm
@Author rengengchen
@Time 2022/8/5 11:53
"""
import multiprocessing
from typing import Iterable, Callable
from tqdm import tqdm
class Stop:
    """Sentinel put on a queue to tell the downstream stage that one upstream worker is done."""


class AbstractPCConcurrencySystem:
    """
    Skeleton for a multi-process producer -> consumer -> callback pipeline.

    Subclasses implement :meth:`produce` and :meth:`consume` and may override
    :meth:`callback`. :meth:`run` starts ``num_producer`` producer processes,
    ``num_consumer`` consumer processes and ``num_callback`` callback processes.
    They are connected by three queues:

    * ``task_queue``: producer -> consumer.
    * ``result_queue``: consumer -> callback.
    * ``callback_queue``: callback -> caller, read via :meth:`get_result`.

    Shutdown uses :class:`Stop` sentinels:

    * every producer sends one ``Stop`` per consumer;
    * every consumer sends one ``Stop`` per callback worker (at least one);
    * every callback worker appends one ``Stop`` to ``callback_queue`` when it finishes.

    A worker exits once it has received one ``Stop`` from each upstream worker.

    @todo maintenance of the started processes (they are not joined or monitored)
    @todo choice of the number of processes
    """

    def __init__(self, num_producer: int = 1, num_consumer: int = 1, num_callback: int = 0,
                 len_task_queue: int = 0, len_result_queue: int = 0, len_callback_queue: int = 0,
                 producer_lock=None, consumer_lock=None, callback_lock=None,
                 meta=None, enable_progressbar=False, num_total_result=None):
        """
        :param num_producer: number of producer processes started by :meth:`run`.
        :param num_consumer: number of consumer processes started by :meth:`run`.
        :param num_callback: number of callback processes. It is forced to 1 when
            ``enable_progressbar`` is set and this is 0.
        :param len_task_queue: maxsize of ``task_queue`` (0 = unbounded).
        :param len_result_queue: maxsize of ``result_queue`` (0 = unbounded).
        :param len_callback_queue: maxsize of ``callback_queue`` (0 = unbounded).
        :param producer_lock: lock made available to subclasses as ``self.producer_lock``.
            A new ``multiprocessing.Lock`` is created when omitted. Not acquired by
            this base class.
        :param consumer_lock: same as ``producer_lock``, stored as ``self.consumer_lock``.
        :param callback_lock: same as ``producer_lock``, stored as ``self.callback_lock``.
        :param meta: arbitrary user data stored as ``self.meta``.
        :param enable_progressbar: show a tqdm progress bar in the callback stage.
        :param num_total_result: expected number of results, used as the progress-bar total.
        """
        self.num_producer = num_producer
        self.num_consumer = num_consumer
        self.num_callback = num_callback
        self.meta = meta
        self.enable_progressbar = enable_progressbar
        # The progress bar is driven from the callback stage, so it needs a callback worker.
        if enable_progressbar and self.num_callback == 0:
            self.num_callback = 1
        self.num_total_result = num_total_result

        self.producer_lock = producer_lock or multiprocessing.Lock()
        self.consumer_lock = consumer_lock or multiprocessing.Lock()
        # Created unconditionally so close() and get_result() never hit a missing attribute.
        self.callback_lock = callback_lock or multiprocessing.Lock()

        self.task_queue = multiprocessing.Queue(len_task_queue)
        self.result_queue = multiprocessing.Queue(len_result_queue)
        self.callback_queue = multiprocessing.Queue(len_callback_queue)

    def get_result(self):
        """
        Block until the next output item is available and return it.

        With a callback stage, items come from ``callback_queue``. Each callback
        worker ends its output with a :class:`Stop`.

        Without one (``num_callback == 0``), the consumers' results are read directly
        from ``result_queue``. Each consumer ends its output with a :class:`Stop`.
        """
        queue = self.callback_queue if self.num_callback else self.result_queue
        return queue.get()

    def produce(self):
        """
        Describe the tasks to process. Called once in every producer process.

        Return one of:

        * an iterable of task parameters; each item is put on the task queue;
        * a zero-argument callable that returns the next task on every call and
          returns a :class:`Stop` instance when there are no more tasks;
        * a :class:`Stop` instance when there is nothing to do.

        NOTE: with ``num_producer > 1`` every producer calls this method, so a returned
        iterable is enqueued once per producer. To share a single task source, presumably
        return a callable that draws from it under ``self.producer_lock``.
        """
        raise NotImplementedError

    def consume(self, consumer_params):
        """
        Process one task taken from the task queue.

        :param consumer_params: one task item as described by :meth:`produce`.
        :return: the task result, which is put on the result queue. A :class:`Stop`
            return value is not forwarded, because downstream it would be mistaken
            for a finished consumer.
        """
        raise NotImplementedError

    def callback(self, result):
        """Hook applied to each consumer result in the callback stage (default: identity)."""
        return result

    def _produce(self):
        """
        Producer-process body.

        Enqueues the tasks described by :meth:`produce`, then sends one :class:`Stop`
        per consumer. The sentinels are sent from ``finally`` so consumers are released
        even when ``produce`` returns a callable, returns ``Stop``, or raises.

        :raises TypeError: if ``produce`` returns an unsupported object.
        """
        try:
            producer = self.produce()
            if isinstance(producer, Stop):
                return  # nothing to produce; the finally clause still releases consumers
            if isinstance(producer, Iterable):
                for params in producer:
                    self.task_queue.put(params, block=True)
            elif isinstance(producer, Callable):
                while True:
                    task = producer()
                    if isinstance(task, Stop):
                        break
                    self.task_queue.put(task, block=True)
            else:
                raise TypeError(
                    f'produce() must return an iterable, a callable or Stop, '
                    f'got {type(producer).__name__}')
        finally:
            stop = Stop()
            for _ in range(self.num_consumer):
                self.task_queue.put(stop, block=True)

    def _consume(self):
        """
        Consumer-process body.

        Applies :meth:`consume` to tasks until one :class:`Stop` from every producer has
        been received. It then sends one ``Stop`` per callback worker (at least one) on
        the result queue.
        """
        # Each producer puts one Stop per consumer after its own tasks. Every consumer
        # waiting for num_producer stops therefore drains all tasks before everyone exits.
        stops_seen = 0
        while stops_seen < self.num_producer:
            consumer_params = self.task_queue.get(block=True)
            if isinstance(consumer_params, Stop):
                stops_seen += 1
                continue
            info = self.consume(consumer_params)
            if not isinstance(info, Stop):
                self.result_queue.put(info)
        stop = Stop()
        for _ in range(max(self.num_callback, 1)):
            self.result_queue.put(stop)

    def _callback(self):
        """
        Callback-process body.

        Puts ``callback(result)`` on ``callback_queue`` for each consumer result until
        one :class:`Stop` from every consumer has been received. It then appends a
        ``Stop`` to ``callback_queue`` to mark the end of its output.

        When the progress bar is enabled, each callback worker drives its own tqdm bar.
        """
        bar = tqdm(total=self.num_total_result) if self.enable_progressbar else None
        stops_seen = 0
        try:
            while stops_seen < self.num_consumer:
                result = self.result_queue.get(block=True)
                if isinstance(result, Stop):
                    stops_seen += 1
                    continue
                self.callback_queue.put(self.callback(result))
                if bar is not None:
                    bar.update(1)
        finally:
            if bar is not None:
                bar.close()
        self.callback_queue.put(Stop())

    def run(self):
        """
        Start all producer, consumer and callback processes and return ``self``.

        The processes are not joined here. Read output with :meth:`get_result`.
        """
        # Start the producers.
        for i in range(self.num_producer):
            multiprocessing.Process(target=self._produce, name=f'producer_{i}').start()
        # Start the consumers.
        for i in range(self.num_consumer):
            multiprocessing.Process(target=self._consume, name=f'consumer_{i}').start()
        # Start the result handlers (range(0) starts none).
        for i in range(self.num_callback):
            multiprocessing.Process(target=self._callback, name=f'callback_{i}').start()
        return self

    def close(self):
        """Close the current process's handles on all three queues (no further puts from here)."""
        self.task_queue.close()
        self.result_queue.close()
        self.callback_queue.close()