diff --git a/app.py b/app.py new file mode 100644 index 0000000..8b3525b --- /dev/null +++ b/app.py @@ -0,0 +1,7 @@ +from flask import Flask + +from config import get_config + +config = get_config() + +app = Flask(__name__) diff --git a/config/__init__.py b/config/__init__.py index b2e49c5..1f12718 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -1 +1 @@ -import settings \ No newline at end of file +from .utils import get_config diff --git a/config/param.ini b/config/param.ini new file mode 100644 index 0000000..42a7e9c --- /dev/null +++ b/config/param.ini @@ -0,0 +1,12 @@ +[APIKey] +tronscan=cc87d361-7cd6-4f69-a57b-f0a77a213355 + +[PaymentAddresses] +usdt=TB592A5QwHvvcJoCmvALmzT3S9Pux91Gub + +[MYSQL] +user: 'your_mysql_username' +password: 'your_mysql_password' +host: 'localhost' +database: 'your_database_name' +autocommit: false diff --git a/config/settings.py b/config/settings.py deleted file mode 100644 index 4c79049..0000000 --- a/config/settings.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -@Project :payment -@File :__init__.py.py -@IDE :PyCharm -@Author :rengengchen -@Time :2024/11/06 16:11 -""" -import argparse -import os -import random -import sys -from argparse import Namespace -from configparser import ConfigParser - -import requests -from loguru import logger - -random_seed = 20240717 -ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - - -def setup_seed(seed): - # import torch - # torch.manual_seed(seed) - # torch.cuda.manual_seed_all(seed) - # torch.backends.cudnn.deterministic = True - # import numpy as np - # np.random.seed(seed) - random.seed(seed) - - -setup_seed(random_seed) - - -class Config: - def __init__(self, - config_parser: ConfigParser = None, - argument_parser: Namespace = None, - **kwargs): - if argument_parser: - for k in vars(argument_parser): - kwargs[k] = getattr(argument_parser, k) - if config_parser: - kwargs1 = kwargs.copy() - for section in config_parser.sections(): - section1 = kwargs.setdefault(section, Config(**kwargs1)) - for option in config_parser.options(section): - section1[option] = config_parser.get(section, option) - self.config = {k: Config(**v) if isinstance(v, dict) else v for k, v in kwargs.items()} - - def refresh(self, **kwargs): - kwargs = {k: Config(**v) if isinstance(v, dict) else v for k, v in kwargs.items()} - self.config.update(kwargs) - - def get(self, item, default=None): - if item in self.config: - return self.config[item] - return default - - def get_int(self, item): - v = self.get(item) - if v is not None: - return int(v) - - def get_float(self, item): - v = self.get(item) - if v is not None: - return float(v) - - def get_bool(self, item): - v = self.get(item) - if v is not None: - return bool(v) - - def __getitem__(self, item): - return self.config[item] - - def __setitem__(self, key, value): - self.config[key] = value - - def __getattr__(self, key): - if key in self.__dict__: - return self.__dict__[key] - try: - return self.config[key] - except KeyError: - raise AttributeError(f"'{self}' object has no attribute '{key}'") - - def __str__(self): - return str(self.config) - - -def log_config(config): - # fmt = '%(asctime)s [%(name)s] %(levelname)s: %(message)s' - # datefmt = "%Y-%m-%d %H:%M:%S" - logger.remove() - logger.add(sys.stdout, level=config.log_level) - logger.add(sys.stderr, level="ERROR") - logger.add(os.path.join(ROOT_DIR, "logs", "{time}.log"), level="DEBUG", encoding='utf8', rotation="100 MB", retention=3) - - -def get_config() -> Config: - requests.adapters.DEFAULT_RETRIES = 3 - - configparser = ConfigParser() - configparser.read(fr'{ROOT_DIR}/config_file/param.conf') - parser = argparse.ArgumentParser(description='payment system') - parser.add_argument("--seed", type=int, default=2024) - args = parser.parse_args() - config = Config(configparser, args) - return config diff --git a/config/utils.py b/config/utils.py new file mode 100644 index 0000000..4a6d836 --- /dev/null +++ b/config/utils.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :payment +@File :__init__.py.py +@IDE :PyCharm +@Author :rengengchen +@Time :2024/11/06 16:11 +""" +import argparse +import os +import random +import sys +from argparse import Namespace +from configparser import ConfigParser + +import requests +from loguru import logger + +random_seed = 20240717 +ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + + +def setup_seed(seed): + # import torch + # torch.manual_seed(seed) + # torch.cuda.manual_seed_all(seed) + # torch.backends.cudnn.deterministic = True + # import numpy as np + # np.random.seed(seed) + random.seed(seed) + + +setup_seed(random_seed) + + +class Setting: + def __init__(self, + config_parser: ConfigParser = None, + argument_parser: Namespace = None, + _visited=None, + _parent=None, + **kwargs): + self._parent = _parent + if config_parser: + for section in config_parser.sections(): + section_config = Setting(_parent=self) + for option in config_parser.options(section): + section_config[option] = config_parser.get(section, option) + self.__dict__[section] = section_config + + if argument_parser: + for k in vars(argument_parser): + self.__dict__[k] = getattr(argument_parser, k) + + if _visited is None: + _visited = set() + self.update(_visited=_visited, **kwargs) + + def update(self, _visited=None, **kwargs): + if _visited is None: + _visited = {} + for k, v in kwargs.items(): + cls_attr = getattr(self.__class__, k, None) + if callable(cls_attr): + raise KeyError(f"The key '{k}' conflicts with an existing class method. you can use {k}_ instead.") + if isinstance(v, dict): + obj_id = id(v) + if obj_id in _visited: + logger.warning(f"Circular reference detected in key: '{k}'.") + v = _visited[obj_id] + else: + v = Setting(_visited=_visited, _parent=self, **v) + _visited[obj_id] = v + self.__dict__[k] = v + + def get(self, item, default=None): + if item not in self.__dict__ and self._parent is not None: + return self._parent.get(item, default) + return default + + def get_int(self, item): + return int(self.get(item)) + + def get_float(self, item): + return float(self.get(item)) + + def get_bool(self, item): + return bool(self.get(item)) + + def set(self, item, value): + self.__dict__[item] = value + + def __getitem__(self, item): + return self.__dict__[item] + + def __setitem__(self, key, value): + self.__dict__[key] = value + + def __getattr__(self, key): + return self.get(key) + + def __str__(self): + def _str_helper(config, indent=0, visited=None): + if visited is None: + visited = set() + lines = [] + indent_str = ' ' * indent + for key, value in config.__dict__.items(): + if key.startswith('_'): + continue + if isinstance(value, Setting): + if id(value) in visited: + lines.append(f"{indent_str}{key}: ") + else: + visited.add(id(value)) + lines.append(f"{indent_str}{key}:") + lines.append(_str_helper(value, indent + 1, visited)) + else: + lines.append(f"{indent_str}{key}: {value}") + return '\n'.join(lines) + + return _str_helper(self) + + +def log_config(config): + # fmt = '%(asctime)s [%(name)s] %(levelname)s: %(message)s' + # datefmt = "%Y-%m-%d %H:%M:%S" + logger.remove() + logger.add(sys.stdout, level=config.log_level) + logger.add(sys.stderr, level="ERROR") + logger.add(os.path.join(ROOT_DIR, "logs", "{time}.log"), level="DEBUG", encoding='utf8', rotation="100 MB", + retention=3) + + +def get_config(config_file=fr'{ROOT_DIR}/config/param.ini') -> Setting: + requests.adapters.DEFAULT_RETRIES = 3 + + configparser = ConfigParser() + configparser.read(config_file) + parser = argparse.ArgumentParser(description='payment system') + parser.add_argument("--seed", type=int, default=2024) + args = parser.parse_args() + config = Setting(configparser, args) + return config diff --git a/utils/db_connection.py b/database.py similarity index 100% rename from utils/db_connection.py rename to database.py diff --git a/models/__init__.py b/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/repositories/order.py b/repositories/order.py index cb92484..ac92ad8 100644 --- a/repositories/order.py +++ b/repositories/order.py @@ -1,20 +1,41 @@ from ruamel_yaml.util import create_timestamp from custom_decorators import singleton +from database import Database +from utils.datetime import current_timestamp @singleton class OrderRepository: - def __init__(self): - self.db = None + def __init__(self, config): + self.db = Database(config['MYSQL']) def create(self, order_id, from_address, to_address): - create_time = create_timestamp() - pass + cur_time = current_timestamp() + try: + self.db.execute_query( + "INSERT INTO orders (order_id, from_address, to_address, create_timestamp, update_timestamp) " + "VALUES (%s, %s, %s, %s, %s)", + [order_id, from_address, to_address, cur_time, cur_time] + ) + self.db.commit() + except Exception: + self.db.rollback() + raise def update_status(self, order_id, status): - # 更新状态和时间 - pass + try: + self.db.execute_query("UPDATE orders " + "SET status = %s, update_timestamp = %s " + "WHERE order_id = %s", + [status, current_timestamp(), order_id]) + self.db.commit() + except Exception: + self.db.rollback() + raise def get_order_info(self, order_id): - pass + self.db.execute_query("SELECT quant, from_address, to_address, create_timestamp " + "FROM orders " + "WHERE order_id = %s", + [order_id]) diff --git a/repositories/user.py b/repositories/user.py new file mode 100644 index 0000000..4eec1e9 --- /dev/null +++ b/repositories/user.py @@ -0,0 +1,40 @@ +from custom_decorators import singleton +from database import Database +from utils.database import pack_params + + +@singleton +class UserRepository: + def __init__(self, config): + self.db = Database(config['MYSQL']) + + def get_user(self, phone=None, email=None, address=None): + params_sql, params = pack_params(phone=phone, email=email, address=address) + sql = f"SELECT name, phone, email, address FROM user WHERE {params_sql}" + cursor = self.db.execute_query(sql, params) + users = cursor.fetchall() + return users + + def create_user(self, phone=None, email=None, address=None): + _, params = pack_params(phone=phone, email=email, address=address) + sql = f"INSERT INTO user (phone, email, address) VALUES (%s, %s, %s)" + self.db.execute_query(sql, params) + self.db.commit() + + def record_exists(self, phone=None, email=None, address=None): + params_sql, params = pack_params(phone=phone, email=email, address=address) + sql = f"SELECT EXISTS(SELECT 1 FROM user WHERE {params_sql} LIMIT 1)" + cursor = self.db.execute_query(sql, params) + result = cursor.fetchone() + return bool(result[0]) + + def create_if_not_exists(self, phone=None, email=None, address=None): + if not self.record_exists(phone=phone, email=email, address=address): + self.create_user(phone=phone, email=email, address=address) + + def get_and_create_if_not_exists(self, phone=None, email=None, address=None): + users = self.get_user(phone=phone, email=email, address=address) + if len(users) == 0: + self.create_user(phone=phone, email=email, address=address) + else: + return users diff --git a/services/order.py b/services/order.py index 3d944ec..b86a9a9 100644 --- a/services/order.py +++ b/services/order.py @@ -8,11 +8,11 @@ from utils.datetime import current, current_timestamp, is_time_difference_greate @singleton class OrderService: - def __init__(self): - self.order_repo = OrderRepository() + def __init__(self, config): + self.order_repo = OrderRepository(config) self.payment_service = PaymentService() - def create_order(self, from_address, to_address, *args, **kwargs): + def create_order(self, from_address, to_address): date_str = current().strftime('%Y%m%d%H%M%S') unique_id = str(uuid.uuid4()).split('-')[0] order_id = f"{date_str}-{unique_id}" @@ -21,14 +21,14 @@ class OrderService: def finish_order(self, order_id): # 判断支付时间是否超过订单存活时间 - quant, from_address, to_address, creation_timestamp = self.order_repo.get_order_info(order_id) + quant, from_address, to_address, create_timestamp = self.order_repo.get_order_info(order_id) current = current_timestamp() status = 0 - if is_time_difference_greater_than(creation_timestamp, current, minutes=15): + if is_time_difference_greater_than(create_timestamp, current, minutes=15): # 订单超时 status = 4 else: - correct_quant, confirmed = self.payment_service.check_payment(quant, from_address, to_address, creation_timestamp, current) + correct_quant, confirmed = self.payment_service.check_payment(quant, from_address, to_address, create_timestamp, current) if correct_quant and confirmed: # 支付成功 status = 1 diff --git a/services/user.py b/services/user.py new file mode 100644 index 0000000..7fbae0d --- /dev/null +++ b/services/user.py @@ -0,0 +1,6 @@ +from custom_decorators import singleton + + +@singleton +class UserService: + pass diff --git a/utils/database.py b/utils/database.py new file mode 100644 index 0000000..128d6a5 --- /dev/null +++ b/utils/database.py @@ -0,0 +1,10 @@ +def pack_params(**kwargs): + params = [] + param_sqls = [] + flag = False + for k, v in kwargs.items(): + flag = True + params.append(v) + param_sqls.append(f"{k}=%s") + if flag: + return " AND ".join(param_sqls), params