#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :payment
@File    :__init__.py.py
@IDE     :PyCharm
@Author  :rengengchen
@Time    :2024/11/06 16:11

Project-wide configuration helpers: a global random seed, the ``Setting``
container (case-insensitive attribute/item access over INI, argparse and YAML
sources), loguru sink setup and the ``get_config`` loader.
"""
import argparse
import os
import random
import sys
from argparse import Namespace
from configparser import ConfigParser

import requests
import yaml
from loguru import logger

random_seed = 20240717
ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


def setup_seed(seed):
    """Seed Python's built-in ``random`` module with *seed*.

    The commented-out lines show how to extend seeding to torch / numpy if the
    project starts depending on them.
    """
    # import torch
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # import numpy as np
    # np.random.seed(seed)
    random.seed(seed)


setup_seed(random_seed)


class SettingKeyError(KeyError, AttributeError):
    """Raised when a key is missing from a :class:`Setting`.

    It subclasses ``KeyError`` because item access has always raised that, so
    existing ``except KeyError`` handlers keep working. It also subclasses
    ``AttributeError``, which the data model requires from ``__getattr__``;
    without it ``hasattr``, ``getattr(obj, name, default)``, ``copy.deepcopy``
    and ``pickle`` all break on a ``Setting``.
    """


class Setting:
    """Case-insensitive, nestable configuration container.

    Keys are stored lower-cased in the instance ``__dict__``. They can be read as
    items (``s['Key']``), as attributes (``s.key``) or via :meth:`get`. Nested
    dicts become child ``Setting`` objects, and :meth:`get` on a child falls back
    to its parent for missing keys. Keys starting with ``_`` are reserved for
    internal state (currently ``_parent``).
    """

    def __init__(self, settings: dict = None, config_parser: ConfigParser = None,
                 argument_parser: Namespace = None, _visited=None, _parent=None, **kwargs):
        """Build a setting from any combination of sources.

        Sources are applied in this order, so later ones override earlier ones:
        *settings* dict, *config_parser* sections, *argument_parser* namespace,
        then ``**kwargs``.

        :param settings: mapping of keys to values; nested dicts become child settings.
        :param config_parser: each section becomes a child ``Setting`` of string options.
        :param argument_parser: parsed ``argparse`` namespace; each attribute becomes a key.
        :param _visited: internal: ``id(dict) -> Setting`` map used to reuse shared
            or self-referencing dicts.
        :param _parent: internal: enclosing ``Setting`` used by :meth:`get` fallback.
        """
        self._parent = _parent
        if settings is not None:
            self.update(_visited=_visited, **settings)
        if config_parser is not None:
            self.update_config_parser(config_parser)
        if argument_parser is not None:
            self.update_argument_parser(argument_parser)
        self.update(_visited=_visited, **kwargs)

    def update(self, _visited=None, **kwargs):
        """Merge ``**kwargs`` into this setting, converting nested dicts to ``Setting``.

        :param _visited: internal ``id(dict) -> Setting`` map shared across recursion.
            A dict seen before is reused instead of being converted again, which also
            terminates self-referencing structures.
        :raises KeyError: if a key (case-insensitively) names a method of this
            class, or starts with ``_``.
        """
        if _visited is None:
            _visited = {}
        for k, v in kwargs.items():
            # Keys are stored lower-cased, so the conflict check must use the
            # lower-cased name too; otherwise ``Keys=...`` would shadow ``keys()``.
            cls_attr = getattr(self.__class__, k.lower(), None)
            if callable(cls_attr):
                raise KeyError(f"The key '{k}' conflicts with an existing class method. you can use {k}_ instead.")
            if k.startswith('_'):
                raise KeyError(f"The key '{k}' is a private attribute.")
            if isinstance(v, dict):
                obj_id = id(v)
                if obj_id in _visited:
                    logger.warning(f"Circular reference detected in key: '{k}'.")
                    v = _visited[obj_id]
                else:
                    # Register the child BEFORE filling it in, so a dict that
                    # (directly or indirectly) contains itself resolves to this
                    # same child instead of recursing forever.
                    child = Setting(_parent=self)
                    _visited[obj_id] = child
                    child.update(_visited=_visited, **v)
                    v = child
            self[k] = v

    def update_config_parser(self, config_parser: ConfigParser):
        """Add one child ``Setting`` per section of *config_parser*.

        Option values are stored as the strings returned by ``ConfigParser.get``
        (interpolation applied); use :meth:`get_int`, :meth:`get_float` or
        :meth:`get_bool` to convert them.
        """
        for section in config_parser.sections():
            section_config = Setting(_parent=self)
            for option in config_parser.options(section):
                section_config[option] = config_parser.get(section, option)
            self[section] = section_config

    def update_argument_parser(self, argument_parser: Namespace):
        """Copy every attribute of a parsed ``argparse`` namespace into this setting."""
        for name, value in vars(argument_parser).items():
            self[name] = value

    def get(self, item, default=None):
        """Return the value for *item* (case-insensitive).

        If the key is absent here, the lookup is delegated to the parent setting.
        *default* is returned when no setting in the chain defines the key.
        """
        item = item.lower()
        if item not in self.__dict__ and self._parent is not None:
            return self._parent.get(item, default)
        return self.__dict__.get(item, default)

    def get_int(self, item):
        """Return ``int(self.get(item))``; raises ``TypeError`` if the key is missing."""
        return int(self.get(item))

    def get_float(self, item):
        """Return ``float(self.get(item))``; raises ``TypeError`` if the key is missing."""
        return float(self.get(item))

    def get_bool(self, item):
        """Return the value for *item* interpreted as a boolean.

        Strings recognised by ``ConfigParser.BOOLEAN_STATES`` (``1/yes/true/on``
        and ``0/no/false/off``, case-insensitive, surrounding whitespace ignored)
        map to the corresponding boolean, because INI values always arrive as
        strings and ``bool("false")`` is ``True``. Any other value, including
        unrecognised strings and a missing key (``None``), falls back to ``bool()``.
        """
        value = self.get(item)
        if isinstance(value, str):
            state = ConfigParser.BOOLEAN_STATES.get(value.strip().lower())
            if state is not None:
                return state
        return bool(value)

    def set(self, item, value):
        """Store *value* under *item*, lower-casing the key like every other accessor."""
        self[item] = value

    def __getitem__(self, item):
        key = item.lower()
        try:
            return self.__dict__[key]
        except KeyError:
            raise SettingKeyError(key) from None

    def __setitem__(self, key, value):
        self.__dict__[key.lower()] = value

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. ``__getitem__`` raises
        # SettingKeyError, which is also an AttributeError as the data model requires.
        return self[key]

    def keys(self):
        """Return the set of public (non-underscore) keys."""
        return {k for k in self.__dict__ if not k.startswith('_')}

    def __iter__(self):
        """Yield ``(key, value)`` pairs for public keys, consistent with :meth:`keys`."""
        for k, v in self.__dict__.items():
            if not k.startswith('_'):
                yield k, v

    def __contains__(self, item):
        if item is None:
            raise ValueError('None is not a valid key')
        return item.lower() in self.__dict__

    def __str__(self):
        """Render the setting as an indented ``key: value`` tree.

        Only settings on the current path are treated as cycles. A sub-setting
        shared by two keys is printed under both; a true back-reference is
        printed as ``<circular reference>``.
        """

        def _str_helper(settings, ancestors, indent_count):
            indent_str = ' ' * indent_count
            lines = []
            for key, value in settings:
                if not isinstance(value, Setting):
                    lines.append(f"{indent_str}{key}: {value}")
                elif id(value) in ancestors:
                    lines.append(f"{indent_str}{key}: <circular reference>")
                else:
                    lines.append(f"{indent_str}{key}:")
                    body = _str_helper(value, ancestors | {id(value)}, indent_count + 1)
                    if body:  # an empty sub-setting should not leave a blank line
                        lines.append(body)
            return '\n'.join(lines)

        return _str_helper(self, {id(self)}, 0)

    def __repr__(self):
        """Return ``Setting(key=value, ...)``; cyclic back-references render as ``...``."""

        def _repr_helper(settings, ancestors):
            parts = []
            for key, value in settings:
                if not isinstance(value, Setting):
                    parts.append(f"{key}={value!r}")
                elif id(value) in ancestors:
                    parts.append(f"{key}=...")
                else:
                    parts.append(f"{key}={_repr_helper(value, ancestors | {id(value)})}")
            return f'Setting({", ".join(parts)})'

        return _repr_helper(self, {id(self)})


def log_config(config):
    """Replace loguru's sinks with the project's standard set.

    - stdout at ``config.log_level`` (so *config* must define ``log_level``);
    - stderr at ERROR and above;
    - ``<ROOT_DIR>/logs/{time}.log`` at DEBUG, UTF-8, rotated every 100 MB, with
      ``retention=3`` (loguru keeps that many log files).

    NOTE(review): unless ``log_level`` is above ERROR, error records match both
    the stdout and the stderr sink, so they appear twice on a terminal.
    """
    # fmt = '%(asctime)s [%(name)s] %(levelname)s: %(message)s'
    # datefmt = "%Y-%m-%d %H:%M:%S"
    logger.remove()
    logger.add(sys.stdout, level=config.log_level)
    logger.add(sys.stderr, level="ERROR")
    logger.add(os.path.join(ROOT_DIR, "logs", "{time}.log"), level="DEBUG", encoding='utf8',
               rotation="100 MB", retention=3)


def get_config(yaml_file=fr'{ROOT_DIR}/config/param.yaml',
               ini_file=fr'{ROOT_DIR}/config/param.ini') -> Setting:
    """Load the project configuration.

    Sources are merged in this order, so later ones override earlier ones:

    1. *ini_file*, if it exists: one child setting per section;
    2. command-line arguments, currently only ``--seed`` (default 2024);
    3. *yaml_file*, if it exists: its top-level mapping.

    Both files are read as UTF-8. Unrecognised command-line arguments are
    ignored with a warning, so this also works when called from a host process
    (test runner, notebook) that has its own argv.

    NOTE(review): YAML currently overrides the command line, and the parsed
    ``seed`` is not passed to :func:`setup_seed`. Confirm both are intended.

    :raises TypeError: if the YAML document's top level is not a mapping.
    """
    # NOTE(review): this assignment most likely has no effect. ``HTTPAdapter``
    # binds ``DEFAULT_RETRIES`` as its ``max_retries`` default when requests is
    # imported. To get retries, mount ``HTTPAdapter(max_retries=3)`` on a Session.
    requests.adapters.DEFAULT_RETRIES = 3
    config = Setting()

    if os.path.exists(ini_file):
        ini_parser = ConfigParser()
        ini_parser.read(ini_file, encoding='utf-8')
        config.update_config_parser(ini_parser)

    parser = argparse.ArgumentParser(description='payment system')
    parser.add_argument("--seed", type=int, default=2024)
    args, unknown = parser.parse_known_args()
    if unknown:
        logger.warning(f"Ignoring unrecognized command-line arguments: {unknown}")
    config.update_argument_parser(args)

    if os.path.exists(yaml_file):
        with open(yaml_file, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        if data is None:  # empty YAML document
            data = {}
        if not isinstance(data, dict):
            raise TypeError(f"Top-level YAML in '{yaml_file}' must be a mapping, "
                            f"got {type(data).__name__}.")
        config.update(**data)
    return config