database.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from sqlalchemy import create_engine
  2. from sqlalchemy.orm import sessionmaker
  3. from sqlalchemy.ext.declarative import declarative_base
  4. import os
  5. from dotenv import load_dotenv # type: ignore
  6. import logging
  7. from sqlalchemy.exc import SQLAlchemyError
  8. # 配置日志系统 - 检查是否已经配置过,避免重复配置
  9. if not logging.getLogger().handlers:
  10. logging.basicConfig(level=logging.INFO)
  11. # 关闭SQLAlchemy的详细SQL执行日志,只保留错误日志
  12. logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
  13. logger = logging.getLogger(__name__)
  14. # 创建Base类
  15. Base = declarative_base()
  16. Base.metadata.clear()
  17. # 加载环境变量
  18. load_dotenv("config.env")
  19. # 从环境变量获取数据库连接信息
  20. DB_USER = os.getenv("DB_USER", "postgres")
  21. DB_PASSWORD = os.getenv("DB_PASSWORD", "scau2025")
  22. DB_HOST = os.getenv("DB_HOST", "localhost")
  23. DB_PORT = os.getenv("DB_PORT", "5432")
  24. DB_NAME = os.getenv("DB_NAME", "data_db")
  25. print(DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME)
  26. # 构建数据库连接URL
  27. SQLALCHEMY_DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}?options=-c client_encoding=utf8"
  28. def create_database_engine():
  29. """创建并配置数据库引擎"""
  30. try:
  31. engine = create_engine(
  32. SQLALCHEMY_DATABASE_URL,
  33. pool_pre_ping=True,
  34. pool_size=5,
  35. max_overflow=10,
  36. pool_timeout=30,
  37. pool_recycle=1800
  38. )
  39. return engine
  40. except Exception as e:
  41. logger.error(f"创建数据库引擎失败: {str(e)}")
  42. raise
  43. def test_database_connection(engine):
  44. """测试数据库连接"""
  45. try:
  46. with engine.connect() as conn:
  47. logger.info("数据库连接测试成功")
  48. return True
  49. except SQLAlchemyError as e:
  50. logger.error(f"数据库连接测试失败: {str(e)}")
  51. return False
  52. # 创建数据库引擎
  53. engine = create_database_engine()
  54. # 测试数据库连接
  55. if not test_database_connection(engine):
  56. raise Exception("无法连接到数据库,请检查数据库配置")
  57. # 创建会话工厂
  58. SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
  59. def get_db():
  60. """获取数据库会话"""
  61. db = SessionLocal()
  62. try:
  63. yield db
  64. except SQLAlchemyError as e:
  65. logger.error(f"数据库操作错误: {str(e)}")
  66. raise
  67. finally:
  68. db.close()
  69. def execute_sql(sql_statement):
  70. """
  71. 执行原始SQL语句
  72. Args:
  73. sql_statement: 要执行的SQL语句
  74. Returns:
  75. 执行结果
  76. """
  77. try:
  78. with engine.begin() as connection:
  79. result = connection.execute(sql_statement)
  80. return result
  81. except SQLAlchemyError as e:
  82. logger.error(f"执行SQL语句失败: {str(e)}")
  83. raise
  84. # 新增:自动创建数据库表(关键!)
  85. def create_tables():
  86. try:
  87. # 必须导入所有模型,否则 Base 不知道要创建哪些表
  88. # 替换成你项目中实际的模型文件路径(根据你的目录结构调整)
  89. from app.models.orm_models import Base # 确保模型继承自这个 Base
  90. from app.models.vector import VectorData # 导入需要创建的表
  91. # 创建所有表
  92. Base.metadata.create_all(bind=engine)
  93. logger.info("数据库表自动创建成功!")
  94. except ImportError as e:
  95. logger.warning(f"未找到模型文件,可能需要手动创建表:{str(e)}")
  96. except Exception as e:
  97. logger.error(f"创建表失败:{str(e)}")
  98. raise
  99. # 执行建表(在连接测试成功后)
  100. create_tables()