123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- from sqlalchemy import create_engine
- from sqlalchemy.orm import sessionmaker
- from sqlalchemy.ext.declarative import declarative_base
- import os
- from dotenv import load_dotenv # type: ignore
- import logging
- from sqlalchemy.exc import SQLAlchemyError
- import logging
- # 开启SQLAlchemy的SQL执行日志(会打印所有执行的SQL语句和错误)
- logging.basicConfig()
- logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
- # 配置日志
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- # 创建Base类
- Base = declarative_base()
- Base.metadata.clear()
- # 加载环境变量
- load_dotenv("config.env")
- # 从环境变量获取数据库连接信息
- DB_USER = os.getenv("DB_USER", "postgres")
- DB_PASSWORD = os.getenv("DB_PASSWORD", "scau2025")
- DB_HOST = os.getenv("DB_HOST", "localhost")
- DB_PORT = os.getenv("DB_PORT", "5432")
- DB_NAME = os.getenv("DB_NAME", "data_db")
- print(DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME)
- # 构建数据库连接URL
- SQLALCHEMY_DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}?options=-c client_encoding=utf8"
- def create_database_engine():
- """创建并配置数据库引擎"""
- try:
- engine = create_engine(
- SQLALCHEMY_DATABASE_URL,
- pool_pre_ping=True,
- pool_size=5,
- max_overflow=10,
- pool_timeout=30,
- pool_recycle=1800
- )
- return engine
- except Exception as e:
- logger.error(f"创建数据库引擎失败: {str(e)}")
- raise
- def test_database_connection(engine):
- """测试数据库连接"""
- try:
- with engine.connect() as conn:
- logger.info("数据库连接测试成功")
- return True
- except SQLAlchemyError as e:
- logger.error(f"数据库连接测试失败: {str(e)}")
- return False
- # 创建数据库引擎
- engine = create_database_engine()
- # 测试数据库连接
- if not test_database_connection(engine):
- raise Exception("无法连接到数据库,请检查数据库配置")
- # 创建会话工厂
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
- def get_db():
- """获取数据库会话"""
- db = SessionLocal()
- try:
- yield db
- except SQLAlchemyError as e:
- logger.error(f"数据库操作错误: {str(e)}")
- raise
- finally:
- db.close()
- def execute_sql(sql_statement):
- """
- 执行原始SQL语句
-
- Args:
- sql_statement: 要执行的SQL语句
-
- Returns:
- 执行结果
- """
- try:
- with engine.begin() as connection:
- result = connection.execute(sql_statement)
- return result
- except SQLAlchemyError as e:
- logger.error(f"执行SQL语句失败: {str(e)}")
- raise
- # 新增:自动创建数据库表(关键!)
- def create_tables():
- try:
- # 必须导入所有模型,否则 Base 不知道要创建哪些表
- # 替换成你项目中实际的模型文件路径(根据你的目录结构调整)
- from app.models.orm_models import Base # 确保模型继承自这个 Base
- from app.models.vector import VectorData # 导入需要创建的表
-
- # 创建所有表
- Base.metadata.create_all(bind=engine)
- logger.info("数据库表自动创建成功!")
- except ImportError as e:
- logger.warning(f"未找到模型文件,可能需要手动创建表:{str(e)}")
- except Exception as e:
- logger.error(f"创建表失败:{str(e)}")
- raise
- # 执行建表(在连接测试成功后)
- create_tables()
|