128 lines
4.1 KiB
Python
128 lines
4.1 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: UTF-8 -*-
|
||
"""
|
||
@Project :IoD_data_analysis_tool
|
||
@File :producer_consumer.py
|
||
@IDE :PyCharm
|
||
@Author :rengengchen
|
||
@Time :2022/8/5 11:53
|
||
"""
|
||
import multiprocessing
|
||
from typing import Iterable, Callable
|
||
|
||
from tqdm import tqdm
|
||
|
||
|
||
class Stop:
    """End-of-stream sentinel passed through the task and result queues.

    Receivers detect it with ``isinstance``, so every instance is equivalent.
    """
|
||
|
||
|
||
class AbstractPCConcurrencySystem:
    """
    Base class for a multiprocess producer -> consumer -> callback pipeline.

    Subclasses implement ``produce`` and ``consume`` (and optionally ``callback``).
    Data flows through ``multiprocessing.Queue`` objects:

    * ``task_queue``: filled by producer processes, drained by consumer processes;
    * ``result_queue``: filled by consumers with the return values of ``consume``;
    * ``callback_queue``: filled by callback processes with the return values of
      ``callback`` and read with ``get_result``. It only exists when there is a
      callback stage (``num_callback > 0`` or ``enable_progressbar=True``),
      otherwise it is None.

    End of stream is signalled with ``Stop`` instances. When the last producer
    finishes, one ``Stop`` per consumer is put on ``task_queue``. Each consumer
    puts exactly one ``Stop`` on ``result_queue`` when it exits.

    ``run`` is meant to be called once per instance: the shared end-of-stream
    counters are not reset between runs.

    @todo further lifecycle management of started processes (e.g. terminate on error)
    @todo number of processes
    """

    def __init__(self, num_producer: int = 1, num_consumer: int = 1, num_callback: int = 0,
                 len_task_queue: int = 0, len_result_queue: int = 0, len_callback_queue: int = 0,
                 producer_lock=None, consumer_lock=None, callback_lock=None,
                 meta=None, enable_progressbar=False, num_total_result=None):
        """
        @param num_producer: number of producer processes started by ``run``.
        @param num_consumer: number of consumer processes started by ``run``.
        @param num_callback: number of callback processes started by ``run``;
            forced to 1 when ``enable_progressbar`` is set and this is 0.
        @param len_task_queue: maxsize of ``task_queue`` (<= 0 means unbounded).
        @param len_result_queue: maxsize of ``result_queue`` (<= 0 means unbounded).
        @param len_callback_queue: maxsize of ``callback_queue`` (<= 0 means unbounded).
        @param producer_lock: stored as ``self.producer_lock`` (a new
            ``multiprocessing.Lock`` when omitted); not used by this base class,
            presumably available to subclasses.
        @param consumer_lock: same as ``producer_lock``, stored as ``self.consumer_lock``.
        @param callback_lock: same as ``producer_lock``, stored as ``self.callback_lock``;
            only created when there is a callback stage, otherwise None.
        @param meta: arbitrary user data stored as ``self.meta``.
        @param enable_progressbar: show a tqdm progress bar in each callback process.
        @param num_total_result: ``total`` passed to the tqdm progress bar.
        """
        self.task_queue = multiprocessing.Queue(len_task_queue)

        self.num_producer = num_producer
        self.num_consumer = num_consumer
        self.num_callback = num_callback
        self.producer_lock = producer_lock or multiprocessing.Lock()
        self.consumer_lock = consumer_lock or multiprocessing.Lock()
        self.meta = meta
        self.enable_progressbar = enable_progressbar
        # The progress bar is drawn by a callback process, so one must exist.
        if enable_progressbar and self.num_callback == 0:
            self.num_callback = 1
        self.result_queue = multiprocessing.Queue(len_result_queue)
        self.num_total_result = num_total_result
        # Always define these so close()/get_result() never hit a missing attribute.
        self.callback_lock = None
        self.callback_queue = None
        if self.num_callback:
            self.callback_lock = callback_lock or multiprocessing.Lock()
            self.callback_queue = multiprocessing.Queue(len_callback_queue)
        # Shared counters coordinating end-of-stream across processes.
        self._finished_producers = multiprocessing.Value('i', 0)
        self._received_stops = multiprocessing.Value('i', 0)
        # Process handles started by run(); see __getstate__.
        self._processes = []

    def __getstate__(self):
        """Pickle support: drop the Process handles, which only make sense in the parent.

        Under the ``spawn`` start method ``self`` is pickled for every child
        (via the bound-method target). Started Process objects are generally
        not picklable, so they must not travel with it.
        """
        state = self.__dict__.copy()
        state.pop('_processes', None)
        return state

    def get_result(self):
        """Block until the next callback result is available and return it.

        Requires a callback stage (``num_callback > 0`` or ``enable_progressbar``).
        Otherwise ``callback_queue`` is None and this raises AttributeError.
        """
        return self.callback_queue.get()

    def produce(self):
        """
        Must return one of:

        * an iterable of task parameters, each put on ``task_queue``;
        * a callable that returns the next task on every call, or a ``Stop``
          instance when there are no more tasks;
        * a ``Stop`` instance, meaning there is nothing to produce.
        """
        raise NotImplementedError

    def consume(self, consumer_params):
        """
        @param consumer_params: one task item taken from ``task_queue``.
        @return: task result, or Stop() to make this consumer exit
            (a returned Stop is not forwarded as a result).
        """
        raise NotImplementedError

    def callback(self, result):
        """Transform one consumer result; the return value goes to ``callback_queue``.

        The default implementation returns ``result`` unchanged.
        """
        return result

    def _produce(self):
        """Body of each producer process: push every task onto ``task_queue``.

        The producer is counted as finished even if ``produce`` raises, so the
        consumers are always released.

        @raise TypeError: ``produce`` returned an unsupported value.
        """
        try:
            producer = self.produce()
            if isinstance(producer, Stop):
                pass  # nothing to produce
            elif isinstance(producer, Iterable):
                for params in producer:
                    self.task_queue.put(params, block=True)
            elif isinstance(producer, Callable):
                # Pull tasks until the callable signals the end with a Stop instance.
                while True:
                    task = producer()
                    if isinstance(task, Stop):
                        break
                    self.task_queue.put(task, block=True)
            else:
                raise TypeError('produce() must return an iterable, a callable or a Stop, '
                                f'got {type(producer).__name__}')
        finally:
            self._producer_finished()

    def _producer_finished(self):
        """Record one finished producer; the last one releases every consumer."""
        with self._finished_producers.get_lock():
            self._finished_producers.value += 1
            is_last = self._finished_producers.value == self.num_producer
        if is_last:
            # One sentinel per consumer: each consumer exits on the first Stop it reads.
            # Put outside the lock because put() may block on a bounded queue.
            stop = Stop()
            for _ in range(self.num_consumer):
                self.task_queue.put(stop, block=True)

    def _consume(self):
        """Body of each consumer process: turn tasks into results until a Stop arrives.

        Exactly one Stop is put on ``result_queue`` when this consumer exits,
        even if ``consume`` raises, so the callback processes can terminate.
        """
        try:
            while True:
                consumer_params = self.task_queue.get(block=True)
                if isinstance(consumer_params, Stop):
                    break
                info = self.consume(consumer_params)
                # A Stop returned by consume() ends this consumer. It must not be
                # forwarded, or the callback stage would miscount finished consumers.
                if isinstance(info, Stop):
                    break
                self.result_queue.put(info)
        finally:
            self.result_queue.put(Stop())

    def _callback(self):
        """Body of each callback process: apply ``callback`` to every result.

        Each callback result is put on ``callback_queue``. The process exits
        once every consumer has finished.
        """
        bar = tqdm(total=self.num_total_result) if self.enable_progressbar else None
        try:
            # With no consumers no Stop will ever arrive, so there is nothing to wait for.
            while self.num_consumer > 0:
                result = self.result_queue.get(block=True)
                if isinstance(result, Stop):
                    if self._consumer_stop_received():
                        break
                    continue
                self.callback_queue.put(self.callback(result))
                if bar is not None:
                    bar.update(1)
        finally:
            if bar is not None:
                bar.close()

    def _consumer_stop_received(self):
        """Count one Stop read from ``result_queue``; return True if this process should exit.

        All callback processes share one ``result_queue``, so a single process
        cannot expect to see all ``num_consumer`` Stops itself. Stops are counted
        in a shared counter instead. Whoever reads the last consumer's Stop
        queues one extra Stop for each other callback process. Any Stop read at
        or after that point tells its reader to exit.
        """
        with self._received_stops.get_lock():
            self._received_stops.value += 1
            count = self._received_stops.value
        if count == self.num_consumer:
            stop = Stop()
            for _ in range(self.num_callback - 1):
                self.result_queue.put(stop)
        return count >= self.num_consumer

    @staticmethod
    def _start_processes(target, prefix, count):
        """Create and start ``count`` processes running ``target``; return them as a list."""
        processes = []
        for i in range(count):
            process = multiprocessing.Process(target=target, name=f'{prefix}_{i}')
            process.start()
            processes.append(process)
        return processes

    def run(self):
        """Start all producer, consumer and callback processes and return ``self``.

        Processes are not joined here; use ``join`` to wait for them.
        """
        # Create and start the producers, then the consumers, then the result handlers.
        producers = self._start_processes(self._produce, 'producer', self.num_producer)
        consumers = self._start_processes(self._consume, 'consumer', self.num_consumer)
        callbackers = self._start_processes(self._callback, 'callback', self.num_callback)
        # Stored only after every start(), so the handles are never part of the
        # state pickled for a child (see __getstate__).
        self._processes = producers + consumers + callbackers
        return self

    def join(self, timeout=None):
        """Wait for every process started by ``run`` to exit; return ``self``.

        Drain ``callback_queue`` (``get_result``) first. A process that has put
        items on a queue may not terminate until those items are consumed.

        @param timeout: passed to ``Process.join`` for each process individually.
        """
        for process in getattr(self, '_processes', ()):
            process.join(timeout)
        return self

    def close(self):
        """Close this instance's queues: no more data will be put on them from this process."""
        self.task_queue.close()
        self.result_queue.close()
        # callback_queue only exists when there is a callback stage.
        if getattr(self, 'callback_queue', None) is not None:
            self.callback_queue.close()
|