# mysqy_test.py
import copy
import unittest

from sqlalchemy import Column, Integer, String

import tools.db_helper.mysql_helper
from tools.db_helper.base import Base
  5. # 定义测试用的模型类
  6. class TestUser(Base):
  7. __tablename__ = "test_users"
  8. id = Column(Integer, primary_key=True)
  9. name = Column(String(50), nullable=False)
  10. email = Column(String(100), unique=True)
  11. class TestMySQLHelper(unittest.TestCase):
  12. @classmethod
  13. def setUpClass(cls):
  14. """测试类初始化"""
  15. cls.db_helper = tools.db_helper.mysql_helper.MySQLHelper()
  16. # 设置测试数据库配置
  17. cls.test_config = {
  18. "host": "localhost",
  19. "port": 3306,
  20. "user": "test_user",
  21. "password": "test_password",
  22. "db": "test_db",
  23. "charset": "utf8mb4",
  24. }
  25. cls.test_db = "test_db"
  26. def setUp(self):
  27. """每个测试用例执行前的设置"""
  28. # 创建测试表
  29. engine = self.db_helper.get_engine(self.test_db, self.test_config)
  30. Base.metadata.create_all(engine)
  31. def tearDown(self):
  32. """每个测试用例执行后的清理"""
  33. # 删除测试表
  34. engine = self.db_helper.get_engine(self.test_db, self.test_config)
  35. Base.metadata.drop_all(engine)
  36. self.db_helper.dispose_all()
  37. def test_singleton(self):
  38. """测试单例模式"""
  39. mysql1 = tools.db_helper.mysql_helper.MySQLHelper()
  40. mysql2 = tools.db_helper.mysql_helper.MySQLHelper()
  41. self.assertIs(mysql1, mysql2)
  42. def test_set_default_config(self):
  43. """测试设置默认配置"""
  44. test_config = {"host": "test_host", "port": 3307}
  45. self.db_helper.set_default_config(test_config)
  46. self.assertEqual(self.db_helper._default_config["host"], "test_host")
  47. self.assertEqual(self.db_helper._default_config["port"], 3307)
  48. def test_get_config_for_database(self):
  49. """测试获取数据库配置"""
  50. try:
  51. config = self.db_helper.get_config_for_database(self.test_db)
  52. self.assertIsInstance(config, dict)
  53. except Exception as e:
  54. self.fail(f"获取数据库配置失败: {str(e)}")
  55. def test_execute_query(self):
  56. """测试查询操作"""
  57. try:
  58. # 创建测试表
  59. create_table_sql = """
  60. CREATE TABLE IF NOT EXISTS test_table (
  61. id INT PRIMARY KEY AUTO_INCREMENT,
  62. name VARCHAR(50) NOT NULL
  63. )"""
  64. self.db_helper.execute_non_query(self.test_db, create_table_sql)
  65. # 插入测试数据
  66. insert_sql = "INSERT INTO test_table (name) VALUES (%s)"
  67. self.db_helper.execute_non_query(self.test_db, insert_sql, ("test_name",))
  68. # 测试查询
  69. query_sql = "SELECT * FROM test_table WHERE name = %s"
  70. results = self.db_helper.execute_query(
  71. self.test_db, query_sql, ("test_name",)
  72. )
  73. self.assertTrue(len(results) > 0)
  74. self.assertEqual(results[0][1], "test_name")
  75. except Exception as e:
  76. self.fail(f"查询操作测试失败: {str(e)}")
  77. def test_execute_scalar(self):
  78. """测试标量查询"""
  79. try:
  80. # 创建测试表并插入数据
  81. self.db_helper.execute_non_query(
  82. self.test_db,
  83. """
  84. CREATE TABLE IF NOT EXISTS test_scalar (
  85. id INT PRIMARY KEY AUTO_INCREMENT,
  86. value INT NOT NULL
  87. )
  88. """,
  89. )
  90. self.db_helper.execute_non_query(
  91. self.test_db, "INSERT INTO test_scalar (value) VALUES (%s)", (42,)
  92. )
  93. # 测试标量查询
  94. result = self.db_helper.execute_scalar(
  95. self.test_db, "SELECT value FROM test_scalar WHERE id = 1"
  96. )
  97. self.assertEqual(result, 42)
  98. except Exception as e:
  99. self.fail(f"标量查询测试失败: {str(e)}")
  100. def test_session_scope(self):
  101. """测试会话作用域和事务管理"""
  102. try:
  103. # 测试成功的事务
  104. with self.db_helper.session_scope(
  105. self.test_db, self.test_config
  106. ) as session:
  107. user = TestUser(name="test_user", email="test@example.com")
  108. session.add(user)
  109. # 验证数据已保存
  110. with self.db_helper.session_scope(
  111. self.test_db, self.test_config
  112. ) as session:
  113. saved_user = session.query(TestUser).filter_by(name="test_user").first()
  114. self.assertIsNotNone(saved_user)
  115. self.assertEqual(saved_user.email, "test@example.com")
  116. # 测试事务回滚
  117. with self.assertRaises(Exception):
  118. with self.db_helper.session_scope(
  119. self.test_db, self.test_config
  120. ) as session:
  121. user = TestUser(name="rollback_user", email="invalid_email")
  122. session.add(user)
  123. raise Exception("测试回滚")
  124. # 验证数据已回滚
  125. with self.db_helper.session_scope(
  126. self.test_db, self.test_config
  127. ) as session:
  128. rollback_user = (
  129. session.query(TestUser).filter_by(name="rollback_user").first()
  130. )
  131. self.assertIsNone(rollback_user)
  132. except Exception as e:
  133. self.fail(f"会话作用域测试失败: {str(e)}")
  134. if __name__ == "__main__":
  135. unittest.main()