payment/payment_backend/config/utils.py

146 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project payment
@File utils.py
@IDE PyCharm
@Author rengengchen
@Time 2024/11/06 16:11
"""
import argparse
import os
import random
import sys
from argparse import Namespace
from configparser import ConfigParser
import requests
from loguru import logger
# Seed passed to setup_seed() for the stdlib ``random`` module at import time.
random_seed = 20240717
# Project root: the parent of the directory that contains this file.
# abspath() is applied before the dirname() calls so that a relative
# ``__file__`` (e.g. this module run as a script from its own directory)
# still resolves to the real parent directory instead of the cwd.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def setup_seed(seed):
    """Seed the standard-library ``random`` generator for reproducible runs.

    Only Python's ``random`` module is seeded. NumPy / PyTorch seeding is not
    done here; add ``np.random.seed`` / ``torch.manual_seed`` /
    ``torch.cuda.manual_seed_all`` if those libraries come into use.

    :param seed: value forwarded to :func:`random.seed`.
    """
    random.seed(a=seed)
# Make the stdlib RNG deterministic as soon as this module is imported.
setup_seed(seed=random_seed)
class Setting:
    """Hierarchical, attribute-accessible configuration container.

    A ``Setting`` is populated from up to three sources, applied in order:

    1. ``config_parser`` -- each section becomes a child ``Setting`` whose
       options hold the raw strings returned by ``ConfigParser.get``.
    2. ``argument_parser`` -- each attribute of the ``argparse.Namespace`` is
       copied onto this object.
    3. ``**kwargs`` -- stored through :meth:`update`; ``dict`` values become
       child ``Setting`` objects.

    Values are reachable as attributes (``cfg.section.option``), by item
    (``cfg['section']``) or through :meth:`get`. Attribute access and
    :meth:`get` fall back to the parent chain when a key is missing locally,
    and finally yield ``None`` (or the supplied default) instead of raising.
    """

    def __init__(self,
                 config_parser: ConfigParser = None,
                 argument_parser: Namespace = None,
                 _visited=None,
                 _parent=None,
                 **kwargs):
        # Set first: get() and __getattr__ consult the parent chain.
        self._parent = _parent
        if config_parser:
            for section in config_parser.sections():
                section_config = Setting(_parent=self)
                for option in config_parser.options(section):
                    section_config[option] = config_parser.get(section, option)
                self.__dict__[section] = section_config
        if argument_parser:
            for k in vars(argument_parser):
                self.__dict__[k] = getattr(argument_parser, k)
        if _visited is None:
            # update() indexes and assigns into this memo, so it must be a
            # dict (a set made every nested-dict kwarg raise TypeError).
            _visited = {}
        self.update(_visited=_visited, **kwargs)

    def update(self, _visited=None, **kwargs):
        """Store ``kwargs`` on this object, converting nested dicts to ``Setting``.

        :param _visited: internal memo mapping ``id(dict)`` to the ``Setting``
            already built for that dict, so a dict reached more than once
            (including one that contains itself) maps to a single ``Setting``.
        :raises KeyError: if a key would shadow a callable class attribute
            such as ``get`` or ``update``.
        """
        if _visited is None:
            _visited = {}
        for k, v in kwargs.items():
            cls_attr = getattr(self.__class__, k, None)
            if callable(cls_attr):
                raise KeyError(f"The key '{k}' conflicts with an existing class method. you can use {k}_ instead.")
            if isinstance(v, dict):
                obj_id = id(v)
                if obj_id in _visited:
                    logger.warning(f"Circular reference detected in key: '{k}'.")
                    v = _visited[obj_id]
                else:
                    child = Setting(_parent=self)
                    # Register before recursing: a dict that (directly or
                    # indirectly) contains itself then resolves to this child
                    # instead of recursing forever.
                    _visited[obj_id] = child
                    child.update(_visited=_visited, **v)
                    v = child
            self.__dict__[k] = v

    def get(self, item, default=None):
        """Return ``item`` from this node, else from the nearest ancestor, else ``default``."""
        if item in self.__dict__:
            return self.__dict__[item]
        # Read _parent straight from __dict__: attribute access would re-enter
        # __getattr__ -> get() and recurse without bound on instances built
        # without __init__ (e.g. by copy/pickle).
        parent = self.__dict__.get('_parent')
        if parent is not None:
            return parent.get(item, default)
        return default

    def get_int(self, item, default=None):
        """Return ``get(item, default)`` converted with ``int()``.

        :raises TypeError: if the key is missing and no ``default`` is given.
        :raises ValueError: if the value is not an integer literal.
        """
        return int(self.get(item, default))

    def get_float(self, item, default=None):
        """Return ``get(item, default)`` converted with ``float()``.

        :raises TypeError: if the key is missing and no ``default`` is given.
        :raises ValueError: if the value is not a float literal.
        """
        return float(self.get(item, default))

    def get_bool(self, item, default=None):
        """Return the value of ``item`` interpreted as a boolean.

        Strings (the form in which ``ConfigParser`` values are stored) are
        matched case-insensitively against ``ConfigParser.BOOLEAN_STATES``
        (``1/yes/true/on`` -> True, ``0/no/false/off`` -> False). An empty
        string is False. Any other string, and any non-string value (including
        ``None`` for a missing key), is converted with ``bool()`` as before.
        """
        value = self.get(item, default)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if not normalized:
                return False
            if normalized in ConfigParser.BOOLEAN_STATES:
                return ConfigParser.BOOLEAN_STATES[normalized]
        return bool(value)

    def set(self, item, value):
        """Store ``value`` under ``item`` on this node (no dict conversion)."""
        self.__dict__[item] = value

    def __getitem__(self, item):
        # Local lookup only; raises KeyError and does not consult the parent.
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails. Dunder names must
        # raise AttributeError so protocol probes (copy/pickle looking for
        # __setstate__, __deepcopy__, ...) see "not implemented" instead of
        # a None they would then try to call.
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        return self.get(key)

    def __str__(self):
        """Render public keys as an indented tree (two spaces per level)."""
        def _render(node, depth, ancestors):
            pad = '  ' * depth
            out = []
            for key, value in node.__dict__.items():
                if key.startswith('_'):
                    continue
                if not isinstance(value, Setting):
                    out.append(f"{pad}{key}: {value}")
                elif id(value) in ancestors:
                    # Only a node on the current path is a real cycle; shared
                    # but acyclic children are rendered in full.
                    out.append(f"{pad}{key}: <Circular Reference>")
                else:
                    out.append(f"{pad}{key}:")
                    nested = _render(value, depth + 1, ancestors | {id(value)})
                    if nested:
                        out.append(nested)
            return '\n'.join(out)

        return _render(self, 0, {id(self)})
def log_config(config):
    """Install this application's loguru sinks, replacing any existing ones.

    Registers three sinks:

    * ``sys.stdout`` at ``config.log_level``;
    * ``sys.stderr`` at ``"ERROR"``;
    * a UTF-8 file ``<ROOT_DIR>/logs/{time}.log`` at ``"DEBUG"``, with
      ``rotation="100 MB"`` and ``retention=3``.

    :param config: object exposing ``log_level`` (e.g. a ``Setting``);
        presumably a loguru level name such as ``"INFO"`` -- confirm against
        the callers / param.ini.
    """
    # Drop every previously added sink before installing ours.
    logger.remove()
    stdout_level = config.log_level
    logger.add(sys.stdout, level=stdout_level)
    logger.add(sys.stderr, level="ERROR")
    log_path = os.path.join(ROOT_DIR, "logs", "{time}.log")
    logger.add(
        log_path,
        level="DEBUG",
        encoding='utf8',
        rotation="100 MB",
        retention=3,
    )
def get_config(config_file=fr'{ROOT_DIR}/config/param.ini', argv=None) -> Setting:
    """Load the application configuration.

    Reads ``config_file`` (INI format) plus command-line options and merges
    them into one :class:`Setting`: each INI section becomes a child
    ``Setting``, and parsed command-line options become top-level attributes.

    :param config_file: path of the INI file; defaults to
        ``<ROOT_DIR>/config/param.ini``. A missing or unreadable file is
        logged as a warning and yields a ``Setting`` without INI sections.
    :param argv: arguments to parse; ``None`` means ``sys.argv[1:]``.
        Options this parser does not define are ignored (with a warning)
        instead of terminating the process.
    :return: the merged configuration.
    """
    # NOTE(review): requests' HTTPAdapter binds DEFAULT_RETRIES as the default
    # value of its ``max_retries`` parameter when the class is defined, so
    # reassigning it here likely has no effect on adapters created later --
    # verify, and prefer mounting HTTPAdapter(max_retries=...) on a Session.
    requests.adapters.DEFAULT_RETRIES = 3

    config_parser = ConfigParser()
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    loaded = config_parser.read(config_file, encoding='utf-8')
    if not loaded:
        # ConfigParser.read() silently skips files it cannot open.
        logger.warning(f"Config file not found or unreadable: {config_file}")

    parser = argparse.ArgumentParser(description='payment system')
    # NOTE(review): --seed is parsed but not applied in this function; the
    # module seeds ``random`` with ``random_seed`` at import -- confirm intent.
    parser.add_argument("--seed", type=int, default=2024)
    # parse_known_args() tolerates flags owned by the host process (server
    # launcher, test runner, ...) instead of exiting with status 2.
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        logger.warning(f"Ignoring unrecognized command-line arguments: {unknown}")
    return Setting(config_parser, args)