database.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from sqlalchemy import create_engine , text
  2. from app.models.orm_models import Base
  3. from sqlalchemy.orm import sessionmaker
  4. from sqlalchemy.ext.declarative import declarative_base
  5. import os
  6. from dotenv import load_dotenv # type: ignore
  7. import logging
  8. from sqlalchemy.exc import SQLAlchemyError
  9. import logging
  10. # 开启SQLAlchemy的SQL执行日志(会打印所有执行的SQL语句和错误)
  11. logging.basicConfig()
  12. logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
  13. # 配置日志
  14. logging.basicConfig(level=logging.INFO)
  15. logger = logging.getLogger(__name__)
  16. # 创建Base类
  17. #Base = declarative_base()
  18. Base.metadata.clear()
  19. # 加载环境变量
  20. load_dotenv("config.env")
  21. # 从环境变量获取数据库连接信息
  22. DB_USER = os.getenv("DB_USER", "postgres")
  23. DB_PASSWORD = os.getenv("DB_PASSWORD", "123456")
  24. DB_HOST = os.getenv("DB_HOST", "localhost")
  25. DB_PORT = os.getenv("DB_PORT", "5432")
  26. DB_NAME = os.getenv("DB_NAME", "data_db")
  27. print(DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME)
  28. # 构建数据库连接URL
  29. SQLALCHEMY_DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}?options=-c client_encoding=utf8"
  30. def create_database_engine():
  31. """创建并配置数据库引擎"""
  32. try:
  33. engine = create_engine(
  34. SQLALCHEMY_DATABASE_URL,
  35. pool_pre_ping=True,
  36. pool_size=5,
  37. max_overflow=10,
  38. pool_timeout=30,
  39. pool_recycle=1800
  40. )
  41. return engine
  42. except Exception as e:
  43. logger.error(f"创建数据库引擎失败: {str(e)}")
  44. raise
  45. def test_database_connection(engine):
  46. """测试数据库连接"""
  47. try:
  48. with engine.connect() as conn:
  49. logger.info("数据库连接测试成功")
  50. return True
  51. except SQLAlchemyError as e:
  52. logger.error(f"数据库连接测试失败: {str(e)}")
  53. return False
  54. # 创建数据库引擎
  55. engine = create_database_engine()
  56. # 测试数据库连接
  57. if not test_database_connection(engine):
  58. raise Exception("无法连接到数据库,请检查数据库配置")
  59. # 创建会话工厂
  60. SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
  61. def get_db():
  62. """获取数据库会话"""
  63. db = SessionLocal()
  64. try:
  65. yield db
  66. except SQLAlchemyError as e:
  67. logger.error(f"数据库操作错误: {str(e)}")
  68. raise
  69. finally:
  70. db.close()
  71. def execute_sql(sql_statement):
  72. """
  73. 执行原始SQL语句
  74. Args:
  75. sql_statement: 要执行的SQL语句
  76. Returns:
  77. 执行结果
  78. """
  79. try:
  80. with engine.begin() as connection:
  81. result = connection.execute(sql_statement)
  82. return result
  83. except SQLAlchemyError as e:
  84. logger.error(f"执行SQL语句失败: {str(e)}")
  85. raise
  86. # 新增:自动创建数据库表(关键!)
  87. def create_tables():
  88. try:
  89. # ✨ 新增:同时启用 PostGIS 核心和 Raster 扩展
  90. with engine.begin() as conn:
  91. # 启用 PostGIS 核心(矢量)
  92. conn.execute(text("CREATE EXTENSION IF NOT EXISTS postgis;")) # pyright: ignore[reportUndefinedVariable]
  93. # 启用 PostGIS Raster(栅格)
  94. conn.execute(text("CREATE EXTENSION IF NOT EXISTS postgis_raster;")) # pyright: ignore[reportUndefinedVariable]
  95. logger.info("PostGIS 核心及 Raster 扩展已启用(或已存在)")
  96. # 必须导入所有模型,否则 Base 不知道要创建哪些表
  97. # 替换成你项目中实际的模型文件路径(根据你的目录结构调整)
  98. from app.models.orm_models import Base # 确保模型继承自这个 Base
  99. from app.models.vector import VectorData # 导入需要创建的表
  100. #from app.models.raster import RasterData # 补充其他模型
  101. from app.models.county import County
  102. from app.models.farmland import FarmlandData
  103. from app.models.agricultural import AgriculturalData # 假设类名是 Agricultural
  104. from app.models.assessment import Assessment
  105. from app.models.atmo_company import AtmoCompany
  106. from app.models.atmo_sample import AtmoSampleData
  107. from app.models.CropCd_input import CropCdInputData
  108. from app.models.CropCd_output import CropCdOutputData
  109. from app.models.cross_section import CrossSection
  110. from app.models.EffCd_input import EffCdInputData
  111. from app.models.EffCd_output import EffCdOutputData
  112. from app.models.FluxCd_input import FluxCdInputData
  113. from app.models.FluxCd_output import FluxCdOutputData
  114. from app.models.MSM_input import MSMInputData
  115. from app.models.MSM_output import MSMOutputData
  116. from app.models.parameters import Parameters
  117. from app.models.soil import SoilData
  118. from app.models.water_sample import WaterSampleData
  119. # 创建所有表
  120. Base.metadata.create_all(bind=engine)
  121. logger.info("数据库表自动创建成功!")
  122. except ImportError as e:
  123. logger.warning(f"未找到模型文件,可能需要手动创建表:{str(e)}")
  124. except Exception as e:
  125. logger.error(f"创建表失败:{str(e)}")
  126. raise
  127. # 执行建表(在连接测试成功后)
  128. create_tables()