config_helper.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import yaml
  3. class ConfigHelper:
  4. _instance = None
  5. # 默认配置文件路径
  6. default_config_path = os.path.join(os.path.dirname(__file__), "..", "config.yml")
  7. # 类变量存储加载的配置
  8. _config = None
  9. _path = None
  10. def __new__(cls, *args, **kwargs):
  11. if not cls._instance:
  12. cls._instance = super(ConfigHelper, cls).__new__(cls)
  13. return cls._instance
  14. def load_config(self, path=None):
  15. if self._config is None:
  16. if not path:
  17. # print(f"使用默认配置文件:{self.default_config_path}")
  18. self._path = self.default_config_path
  19. else:
  20. self._path = path
  21. if not os.path.exists(self._path):
  22. raise FileNotFoundError(f"没有找到文件或目录:'{self._path}'")
  23. with open(self._path, "r", encoding="utf-8") as file:
  24. self._config = yaml.safe_load(file)
  25. # 合并环境变量配置
  26. self._merge_env_vars()
  27. # print(f"加载的配置文件内容:{self._config}")
  28. return self._config
  29. def _merge_env_vars(self, env_prefix="APP_"): # 环境变量前缀为 APP_
  30. for key, value in os.environ.items():
  31. if key.startswith(env_prefix):
  32. config_key = key[len(env_prefix) :].lower()
  33. self._set_nested_key(self._config, config_key.split("__"), value)
  34. def _set_nested_key(self, config, keys, value):
  35. if len(keys) > 1:
  36. if keys[0] not in config or not isinstance(config[keys[0]], dict):
  37. config[keys[0]] = {}
  38. self._set_nested_key(config[keys[0]], keys[1:], value)
  39. else:
  40. config[keys[0]] = value
  41. def get(self, key: str, default: str = None):
  42. if self._config is None:
  43. self.load_config(self._path)
  44. keys = key.split(".")
  45. config = self._config
  46. for k in keys:
  47. if isinstance(config, dict) and k in config:
  48. config = config[k]
  49. else:
  50. return default
  51. return config
  52. def get_bool(self, key: str) -> bool:
  53. val = str(self.get(key, "0"))
  54. return True if val.lower() == "true" or val == "1" else False
  55. def get_int(self, key: str, default: int = 0) -> int:
  56. val = self.get(key)
  57. if not val:
  58. return default
  59. try:
  60. return int(val)
  61. except ValueError:
  62. return default
  63. def get_object(self, key: str, default: dict = None):
  64. val = self.get(key)
  65. if not val:
  66. return default
  67. if isinstance(val, dict):
  68. return val
  69. try:
  70. return yaml.safe_load(val)
  71. except yaml.YAMLError as e:
  72. print(f"Error loading YAML object: {e}")
  73. return default
  74. def get_all(self):
  75. if self._config is None:
  76. self.load_config(self._path)
  77. return self._config