util/lib/analysis_package/code_template/concurrency/producer_consumer.py

128 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project IoD_data_analysis_tool
@File producer_consumer.py
@IDE PyCharm
@Author rengengchen
@Time 2022/8/5 11:53
"""
import multiprocessing
from typing import Iterable, Callable
from tqdm import tqdm
class Stop:
    """Sentinel sent through the pipeline queues to mark end of stream.

    Workers recognise it with ``isinstance(obj, Stop)``; an instance carries
    no data of its own.
    """
class AbstractPCConcurrencySystem:
    """
    Skeleton of a multiprocess producer -> consumer -> callback pipeline.

    Subclasses implement ``produce`` and ``consume`` (and optionally
    ``callback``). ``run`` starts ``num_producer`` producer processes that fill
    ``task_queue``, ``num_consumer`` consumer processes that turn tasks into
    results on ``result_queue``, and ``num_callback`` callback processes that
    post-process results onto ``callback_queue``.

    End of stream is signalled with ``Stop`` sentinels. Every producer sends
    one ``Stop`` per consumer, and a consumer finishes after it has received
    ``num_producer`` of them. Every consumer sends one ``Stop`` per callback
    process, and a callback process finishes after it has received
    ``num_consumer`` of them. Because the queues are FIFO, the last sentinel a
    stage reads is enqueued after all upstream work, so no task is lost.

    ``producer_lock``, ``consumer_lock`` and ``callback_lock`` are stored but
    not used by this class. Presumably they are for subclasses that must guard
    shared resources; confirm against the subclasses.

    @todo number of processes
    """

    def __init__(self, num_producer: int = 1, num_consumer: int = 1, num_callback: int = 0,
                 len_task_queue: int = 0, len_result_queue: int = 0, len_callback_queue: int = 0,
                 producer_lock=None, consumer_lock=None, callback_lock=None,
                 meta=None, enable_progressbar=False, num_total_result=None):
        """
        @param num_producer: number of producer processes; each one calls ``produce()``.
        @param num_consumer: number of consumer processes.
        @param num_callback: number of callback processes (forced to 1 when the
            progress bar is enabled and this is 0).
        @param len_task_queue: maxsize of ``task_queue``; 0 means unbounded.
        @param len_result_queue: maxsize of ``result_queue``; 0 means unbounded.
        @param len_callback_queue: maxsize of ``callback_queue``; 0 means unbounded.
        @param producer_lock: lock exposed as ``self.producer_lock`` (a new one if omitted).
        @param consumer_lock: lock exposed as ``self.consumer_lock`` (a new one if omitted).
        @param callback_lock: lock exposed as ``self.callback_lock`` (a new one if
            omitted and a callback stage exists).
        @param meta: arbitrary user data stored as ``self.meta``; unused here.
        @param enable_progressbar: show a tqdm bar updated once per callback result.
        @param num_total_result: expected number of results, used as the bar total.
        """
        self.task_queue = multiprocessing.Queue(len_task_queue)
        self.num_producer = num_producer
        self.num_consumer = num_consumer
        self.num_callback = num_callback
        self.producer_lock = producer_lock or multiprocessing.Lock()
        self.consumer_lock = consumer_lock or multiprocessing.Lock()
        self.meta = meta
        self.enable_progressbar = enable_progressbar
        # The progress bar is driven from a callback process, so ensure one exists.
        if enable_progressbar and self.num_callback == 0:
            self.num_callback = 1
        self.result_queue = multiprocessing.Queue(len_result_queue)
        self.num_total_result = num_total_result
        # Always define the callback attributes so get_result() and close()
        # also work when no callback stage is configured.
        self.callback_lock = callback_lock
        self.callback_queue = None
        if self.num_callback:
            self.callback_lock = callback_lock or multiprocessing.Lock()
            self.callback_queue = multiprocessing.Queue(len_callback_queue)
        # Handles of the processes started by run(); never pickled (see __getstate__).
        self._processes = []

    def __getstate__(self):
        """Return the picklable state, without the Process handles kept by ``run``.

        Under the ``spawn``/``forkserver`` start methods, ``self`` is pickled
        each time a child is started, because the targets are bound methods.
        Started ``Process`` handles belong to the parent and cannot be reliably
        pickled.
        """
        state = self.__dict__.copy()
        state.pop('_processes', None)
        return state

    def get_result(self):
        """Block until the next result is available and return it.

        Reads ``callback_queue`` when a callback stage exists. Otherwise it
        reads ``result_queue``, which also carries the ``Stop`` sentinels that
        the consumers post when they finish.
        """
        queue = self.callback_queue if self.callback_queue is not None else self.result_queue
        return queue.get()

    def produce(self):
        """
        Build the task source; called once in every producer process.

        @return: one of the following:
            - an iterable of task parameters;
            - a callable that returns one task per call, and a ``Stop`` when
              there are no more tasks;
            - a ``Stop`` object, meaning no tasks.
        """
        raise NotImplementedError

    def consume(self, consumer_params):
        """
        Process one task taken from ``task_queue``.

        @return: task result, which is put on ``result_queue``.
        """
        raise NotImplementedError

    def callback(self, result):
        """Post-process one consumer result; the default returns it unchanged."""
        return result

    def _produce(self):
        """Producer process body: enqueue tasks, then one ``Stop`` per consumer."""
        try:
            producer = self.produce()
            if isinstance(producer, Stop):
                return
            if isinstance(producer, Iterable):
                for params in producer:
                    self.task_queue.put(params, block=True)
            elif isinstance(producer, Callable):
                while True:
                    task = producer()
                    if isinstance(task, Stop):
                        break
                    self.task_queue.put(task, block=True)
            else:
                raise TypeError('produce() must return an iterable, a callable or a Stop object, '
                                f'got {type(producer).__name__}')
        finally:
            # Always release the consumers, including on error or early return;
            # otherwise they would block on task_queue forever.
            stop = Stop()
            for _ in range(self.num_consumer):
                self.task_queue.put(stop, block=True)

    def _consume(self):
        """Consumer process body: turn tasks into results until every producer is done."""
        # Each producer sends one Stop per consumer, so wait for one Stop from
        # each producer. At least one is awaited, so tasks fed manually
        # (num_producer == 0) still end on a single Stop, as before.
        stops_needed = max(self.num_producer, 1)
        stops_seen = 0
        try:
            while stops_seen < stops_needed:
                consumer_params = self.task_queue.get(block=True)
                if isinstance(consumer_params, Stop):
                    stops_seen += 1
                    continue
                self.result_queue.put(self.consume(consumer_params))
        finally:
            # One Stop per callback process; each of them waits for num_consumer stops.
            stop = Stop()
            for _ in range(max(self.num_callback, 1)):
                self.result_queue.put(stop)

    def _callback(self):
        """Callback process body: post-process results until every consumer is done."""
        bar = tqdm(total=self.num_total_result) if self.enable_progressbar else None
        stops_seen = 0
        try:
            while stops_seen < self.num_consumer:
                result = self.result_queue.get(block=True)
                if isinstance(result, Stop):
                    stops_seen += 1
                    continue
                self.callback_queue.put(self.callback(result))
                if bar is not None:
                    bar.update(1)
        finally:
            if bar is not None:
                bar.close()

    def _start_processes(self, target, prefix, count):
        """Start ``count`` processes that run ``target``, named ``{prefix}_{i}``."""
        processes = []
        for i in range(count):
            process = multiprocessing.Process(target=target, name=f'{prefix}_{i}')
            process.start()
            processes.append(process)
        return processes

    def run(self):
        """Start producer, consumer and callback processes (in that order); return self."""
        producers = self._start_processes(self._produce, 'producer', self.num_producer)
        consumers = self._start_processes(self._consume, 'consumer', self.num_consumer)
        callbackers = self._start_processes(self._callback, 'callback', self.num_callback)
        # Keep the handles so join() can wait for completion.
        self._processes = producers + consumers + callbackers
        return self

    def join(self, timeout=None):
        """
        Wait for every process started by ``run`` to exit.

        Drain the results with ``get_result`` first. A process that has put
        data on a ``multiprocessing.Queue`` does not exit until that data has
        been flushed, so joining before the results are read can block.

        @param timeout: per-process timeout in seconds, passed to ``Process.join``.
        @return: self
        """
        for process in getattr(self, '_processes', ()):
            process.join(timeout)
        return self

    def close(self):
        """Close the queues this instance created (skipping an absent callback queue)."""
        for queue in (self.task_queue, self.result_queue, self.callback_queue):
            if queue is not None:
                queue.close()