123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import unittest
- import tools.db_helper.mysql_helper
- from tools.db_helper.base import Base
- from sqlalchemy import Column, Integer, String
# ORM model used only by the tests below; its table is created and dropped
# around every test in TestMySQLHelper.
class TestUser(Base):
    """Minimal mapped model used by ``TestMySQLHelper.test_session_scope``."""

    # The "Test" prefix matches pytest's default class-collection pattern;
    # this flag tells pytest that this model is not a test class to collect.
    __test__ = False

    __tablename__ = "test_users"

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    email = Column(String(100), unique=True)
class TestMySQLHelper(unittest.TestCase):
    """Integration tests for ``MySQLHelper`` against a MySQL test database."""

    @classmethod
    def setUpClass(cls):
        """Create the shared helper instance and the test connection settings."""
        cls.db_helper = tools.db_helper.mysql_helper.MySQLHelper()
        # Connection settings for the test database.
        cls.test_config = {
            "host": "localhost",
            "port": 3306,
            "user": "test_user",
            "password": "test_password",
            "db": "test_db",
            "charset": "utf8mb4",
        }
        cls.test_db = "test_db"

    def setUp(self):
        """Create the ORM test table before each test."""
        # Names of raw-SQL tables created by the current test; tearDown drops them.
        self._scratch_tables = []
        engine = self.db_helper.get_engine(self.test_db, self.test_config)
        # Base comes from the project package and may carry other models'
        # tables, so only touch the table owned by this module.
        Base.metadata.create_all(engine, tables=[TestUser.__table__])

    def tearDown(self):
        """Drop every table the test created, then dispose the helper's engines."""
        try:
            for table_name in self._scratch_tables:
                self.db_helper.execute_non_query(
                    self.test_db, f"DROP TABLE IF EXISTS {table_name}"
                )
            engine = self.db_helper.get_engine(self.test_db, self.test_config)
            Base.metadata.drop_all(engine, tables=[TestUser.__table__])
        finally:
            # Dispose engines even if dropping a table failed.
            self.db_helper.dispose_all()

    def _create_scratch_table(self, table_name, create_sql):
        """Create a fresh raw-SQL table that tearDown will drop.

        Any leftover table with the same name (e.g. from an aborted earlier
        run) is dropped first, so the test always starts from an empty table.

        Args:
            table_name: Name of the table; must match the table in ``create_sql``.
            create_sql: The ``CREATE TABLE`` statement to execute.
        """
        self.db_helper.execute_non_query(
            self.test_db, f"DROP TABLE IF EXISTS {table_name}"
        )
        self._scratch_tables.append(table_name)
        self.db_helper.execute_non_query(self.test_db, create_sql)

    def test_singleton(self):
        """Every MySQLHelper() call returns the same instance."""
        mysql1 = tools.db_helper.mysql_helper.MySQLHelper()
        mysql2 = tools.db_helper.mysql_helper.MySQLHelper()
        self.assertIs(mysql1, mysql2)
        self.assertIs(mysql1, self.db_helper)

    def test_set_default_config(self):
        """set_default_config stores the given values as the defaults."""
        import copy

        # MySQLHelper is a singleton shared by every test, so snapshot its
        # defaults and restore them afterwards; otherwise "test_host" would
        # leak into tests that run later.
        saved_config = copy.deepcopy(self.db_helper._default_config)
        self.addCleanup(setattr, self.db_helper, "_default_config", saved_config)

        test_config = {"host": "test_host", "port": 3307}
        self.db_helper.set_default_config(test_config)
        self.assertEqual(self.db_helper._default_config["host"], "test_host")
        self.assertEqual(self.db_helper._default_config["port"], 3307)

    def test_get_config_for_database(self):
        """get_config_for_database returns a dict for the test database."""
        config = self.db_helper.get_config_for_database(self.test_db)
        self.assertIsInstance(config, dict)

    def test_execute_query(self):
        """execute_query returns the rows matching a parameterized SELECT."""
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS test_table (
                id INT PRIMARY KEY AUTO_INCREMENT,
                name VARCHAR(50) NOT NULL
            )"""
        self._create_scratch_table("test_table", create_table_sql)

        insert_sql = "INSERT INTO test_table (name) VALUES (%s)"
        self.db_helper.execute_non_query(self.test_db, insert_sql, ("test_name",))

        query_sql = "SELECT * FROM test_table WHERE name = %s"
        results = self.db_helper.execute_query(
            self.test_db, query_sql, ("test_name",)
        )
        # The table was freshly created, so exactly one row matches.
        self.assertEqual(len(results), 1)
        # Rows are indexed positionally; column 1 of SELECT * is ``name``.
        self.assertEqual(results[0][1], "test_name")

    def test_execute_scalar(self):
        """execute_scalar returns the single value selected by the query."""
        self._create_scratch_table(
            "test_scalar",
            """
            CREATE TABLE IF NOT EXISTS test_scalar (
                id INT PRIMARY KEY AUTO_INCREMENT,
                value INT NOT NULL
            )
            """,
        )
        self.db_helper.execute_non_query(
            self.test_db, "INSERT INTO test_scalar (value) VALUES (%s)", (42,)
        )
        # The table was just recreated, so (with default AUTO_INCREMENT
        # settings) the row inserted above is the one with id 1.
        result = self.db_helper.execute_scalar(
            self.test_db, "SELECT value FROM test_scalar WHERE id = 1"
        )
        self.assertEqual(result, 42)

    def test_session_scope(self):
        """session_scope persists work on success and discards it on error."""

        class _RollbackSignal(Exception):
            """Raised on purpose inside a session to force a rollback."""

        # A block that completes normally should have its changes saved.
        with self.db_helper.session_scope(
            self.test_db, self.test_config
        ) as session:
            user = TestUser(name="test_user", email="test@example.com")
            session.add(user)

        # Verify the data was saved.
        with self.db_helper.session_scope(
            self.test_db, self.test_config
        ) as session:
            saved_user = session.query(TestUser).filter_by(name="test_user").first()
            self.assertIsNotNone(saved_user)
            self.assertEqual(saved_user.email, "test@example.com")

        # The exception raised inside the block must propagate out of
        # session_scope. A dedicated exception type keeps unrelated failures
        # (e.g. a connection error) from satisfying this check.
        with self.assertRaises(_RollbackSignal):
            with self.db_helper.session_scope(
                self.test_db, self.test_config
            ) as session:
                user = TestUser(name="rollback_user", email="invalid_email")
                session.add(user)
                raise _RollbackSignal("测试回滚")

        # Verify nothing from the failed block was persisted.
        with self.db_helper.session_scope(
            self.test_db, self.test_config
        ) as session:
            rollback_user = (
                session.query(TestUser).filter_by(name="rollback_user").first()
            )
            self.assertIsNone(rollback_user)
if __name__ == "__main__":
    # Allow running this file directly by handing control to unittest's
    # command-line runner for the tests defined in this module.
    unittest.main(module="__main__")
|