Explorar o código

实现数据库迁移检查逻辑

drggboy hai 1 mes
pai
achega
eded3b38ea
Modificáronse 2 ficheiros con 329 adicións e 2 borrados
  1. 71 2
      app/main.py
  2. 258 0
      scripts/db_health_check.py

+ 71 - 2
app/main.py

@@ -2,9 +2,78 @@ from fastapi import FastAPI
 from .api import vector, raster, cd_prediction
 from .database import engine, Base
 from fastapi.middleware.cors import CORSMiddleware
+import logging
+import sys
+from alembic.config import Config
+from alembic import command
+from alembic.runtime.migration import MigrationContext
+from alembic.script import ScriptDirectory
+import os
 
-# 创建数据库表
-Base.metadata.create_all(bind=engine)
+# 设置日志
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def check_and_upgrade_database():
+    """
+    检查数据库迁移状态并自动升级到最新版本
+    
+    @description: 在应用启动前检查数据库版本,如果需要升级则自动执行
+    @returns: None
+    @throws: SystemExit 当数据库操作失败时退出程序
+    """
+    try:
+        # 配置 Alembic
+        alembic_cfg = Config(os.path.join(os.path.dirname(os.path.dirname(__file__)), "alembic.ini"))
+        
+        # 获取当前数据库版本
+        with engine.connect() as connection:
+            context = MigrationContext.configure(connection)
+            current_rev = context.get_current_revision()
+            
+        # 获取脚本目录和最新版本
+        script_dir = ScriptDirectory.from_config(alembic_cfg)
+        head_rev = script_dir.get_current_head()
+        
+        logger.info(f"当前数据库版本: {current_rev}")
+        logger.info(f"最新迁移版本: {head_rev}")
+        
+        # 检查是否需要升级
+        if current_rev != head_rev:
+            logger.warning("数据库版本不是最新版本,正在自动升级...")
+            
+            # 执行升级
+            command.upgrade(alembic_cfg, "head")
+            logger.info("数据库升级成功")
+        else:
+            logger.info("数据库版本已是最新")
+            
+    except Exception as e:
+        logger.error(f"数据库迁移检查失败: {str(e)}")
+        logger.error("程序将退出,请手动检查数据库状态")
+        sys.exit(1)
+
+def safe_create_tables():
+    """
+    安全地创建数据库表
+    
+    @description: 在确保迁移状态正确后创建表结构
+    """
+    try:
+        # 先检查和升级数据库
+        check_and_upgrade_database()
+        
+        # 创建数据库表(如果迁移已正确应用,这里应该不会有冲突)
+        Base.metadata.create_all(bind=engine)
+        logger.info("数据库表结构检查完成")
+        
+    except Exception as e:
+        logger.error(f"数据库表创建失败: {str(e)}")
+        logger.error("请检查数据库连接和迁移状态")
+        sys.exit(1)
+
+# 执行数据库初始化
+safe_create_tables()
 
 app = FastAPI(
     title="地图数据处理系统",

+ 258 - 0
scripts/db_health_check.py

@@ -0,0 +1,258 @@
+"""
+数据库健康检查工具
+@description: 检查数据库连接状态、迁移状态和表结构完整性
+@author: AcidMap Team
+@version: 1.0.0
+"""
+
+import os
+import sys
+import logging
+from datetime import datetime
+
+# 添加项目根目录到Python路径
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from app.database import engine, Base
+from alembic.config import Config
+from alembic.runtime.migration import MigrationContext
+from alembic.script import ScriptDirectory
+from sqlalchemy import text
+import traceback
+
+# 设置日志
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+class DatabaseHealthChecker:
+    """
+    数据库健康检查器
+    
+    @description: 全面检查数据库状态,包括连接、迁移、表结构等
+    """
+    
+    def __init__(self):
+        self.engine = engine
+        self.alembic_cfg = Config(os.path.join(os.path.dirname(os.path.dirname(__file__)), "alembic.ini"))
+        self.results = {
+            "timestamp": datetime.now().isoformat(),
+            "connection": False,
+            "migration_status": "unknown",
+            "tables_exist": False,
+            "issues": [],
+            "recommendations": []
+        }
+    
+    def check_database_connection(self):
+        """
+        检查数据库连接
+        
+        @returns: bool 连接是否成功
+        """
+        try:
+            with self.engine.connect() as conn:
+                # 执行简单查询测试连接
+                result = conn.execute(text("SELECT 1"))
+                result.fetchone()
+                
+            self.results["connection"] = True
+            logger.info("✓ 数据库连接正常")
+            return True
+            
+        except Exception as e:
+            self.results["connection"] = False
+            self.results["issues"].append(f"数据库连接失败: {str(e)}")
+            logger.error(f"✗ 数据库连接失败: {str(e)}")
+            return False
+    
+    def check_migration_status(self):
+        """
+        检查数据库迁移状态
+        
+        @returns: dict 迁移状态信息
+        """
+        try:
+            with self.engine.connect() as connection:
+                context = MigrationContext.configure(connection)
+                current_rev = context.get_current_revision()
+                
+            script_dir = ScriptDirectory.from_config(self.alembic_cfg)
+            head_rev = script_dir.get_current_head()
+            
+            migration_info = {
+                "current_revision": current_rev,
+                "head_revision": head_rev,
+                "is_up_to_date": current_rev == head_rev
+            }
+            
+            if migration_info["is_up_to_date"]:
+                self.results["migration_status"] = "up_to_date"
+                logger.info(f"✓ 数据库迁移状态正常 (版本: {current_rev})")
+            else:
+                self.results["migration_status"] = "outdated"
+                self.results["issues"].append(f"数据库版本过时: 当前 {current_rev}, 最新 {head_rev}")
+                self.results["recommendations"].append("执行 'python db_migrate.py upgrade' 升级数据库")
+                logger.warning(f"⚠ 数据库需要升级: {current_rev} -> {head_rev}")
+            
+            return migration_info
+            
+        except Exception as e:
+            self.results["migration_status"] = "error"
+            self.results["issues"].append(f"迁移状态检查失败: {str(e)}")
+            logger.error(f"✗ 迁移状态检查失败: {str(e)}")
+            return {"error": str(e)}
+    
+    def check_table_structure(self):
+        """
+        检查表结构完整性
+        
+        @returns: dict 表结构检查结果
+        """
+        try:
+            with self.engine.connect() as conn:
+                # 获取所有表名
+                tables_query = text("""
+                    SELECT table_name 
+                    FROM information_schema.tables 
+                    WHERE table_schema = 'public'
+                """)
+                existing_tables = [row[0] for row in conn.execute(tables_query).fetchall()]
+                
+                # 获取模型定义的表名
+                model_tables = [table.name for table in Base.metadata.tables.values()]
+                
+                # 检查缺失的表
+                missing_tables = set(model_tables) - set(existing_tables)
+                extra_tables = set(existing_tables) - set(model_tables) - {'alembic_version'}
+                
+                table_info = {
+                    "existing_tables": existing_tables,
+                    "model_tables": model_tables,
+                    "missing_tables": list(missing_tables),
+                    "extra_tables": list(extra_tables)
+                }
+                
+                if missing_tables:
+                    self.results["tables_exist"] = False
+                    self.results["issues"].append(f"缺失表: {', '.join(missing_tables)}")
+                    self.results["recommendations"].append("执行数据库升级或重新创建表结构")
+                    logger.warning(f"⚠ 缺失表: {', '.join(missing_tables)}")
+                else:
+                    self.results["tables_exist"] = True
+                    logger.info("✓ 所有必需的表都存在")
+                
+                if extra_tables:
+                    logger.info(f"发现额外表: {', '.join(extra_tables)}")
+                
+                return table_info
+                
+        except Exception as e:
+            self.results["tables_exist"] = False
+            self.results["issues"].append(f"表结构检查失败: {str(e)}")
+            logger.error(f"✗ 表结构检查失败: {str(e)}")
+            return {"error": str(e)}
+    
+    def check_spatial_extensions(self):
+        """
+        检查空间扩展 (PostGIS)
+        
+        @returns: dict 空间扩展状态
+        """
+        try:
+            with self.engine.connect() as conn:
+                # 检查PostGIS扩展
+                postgis_query = text("""
+                    SELECT extname, extversion 
+                    FROM pg_extension 
+                    WHERE extname = 'postgis'
+                """)
+                postgis_result = conn.execute(postgis_query).fetchall()
+                
+                if postgis_result:
+                    version = postgis_result[0][1]
+                    logger.info(f"✓ PostGIS 扩展已安装 (版本: {version})")
+                    return {"installed": True, "version": version}
+                else:
+                    self.results["issues"].append("PostGIS 扩展未安装")
+                    self.results["recommendations"].append("安装 PostGIS 扩展: CREATE EXTENSION postgis;")
+                    logger.warning("⚠ PostGIS 扩展未安装")
+                    return {"installed": False}
+                    
+        except Exception as e:
+            self.results["issues"].append(f"空间扩展检查失败: {str(e)}")
+            logger.error(f"✗ 空间扩展检查失败: {str(e)}")
+            return {"error": str(e)}
+    
+    def run_full_check(self):
+        """
+        执行完整的健康检查
+        
+        @returns: dict 完整的检查结果
+        """
+        logger.info("开始数据库健康检查...")
+        logger.info("=" * 50)
+        
+        # 检查数据库连接
+        if not self.check_database_connection():
+            logger.error("数据库连接失败,跳过后续检查")
+            return self.results
+        
+        # 检查迁移状态
+        migration_info = self.check_migration_status()
+        
+        # 检查表结构
+        table_info = self.check_table_structure()
+        
+        # 检查空间扩展
+        spatial_info = self.check_spatial_extensions()
+        
+        # 汇总结果
+        self.results.update({
+            "migration_info": migration_info,
+            "table_info": table_info,
+            "spatial_info": spatial_info
+        })
+        
+        logger.info("=" * 50)
+        logger.info("健康检查完成")
+        
+        # 打印总结
+        if self.results["issues"]:
+            logger.warning(f"发现 {len(self.results['issues'])} 个问题:")
+            for issue in self.results["issues"]:
+                logger.warning(f"  - {issue}")
+        
+        if self.results["recommendations"]:
+            logger.info(f"建议执行 {len(self.results['recommendations'])} 项操作:")
+            for rec in self.results["recommendations"]:
+                logger.info(f"  - {rec}")
+        
+        if not self.results["issues"]:
+            logger.info("✓ 数据库状态良好")
+        
+        return self.results
+
+def main():
+    """
+    主函数
+    """
+    try:
+        checker = DatabaseHealthChecker()
+        results = checker.run_full_check()
+        
+        # 根据结果设置退出码
+        if results["issues"]:
+            sys.exit(1)  # 有问题时退出码为1
+        else:
+            sys.exit(0)  # 正常时退出码为0
+            
+    except Exception as e:
+        logger.error(f"健康检查执行失败: {str(e)}")
+        logger.error(f"错误详情: {traceback.format_exc()}")
+        sys.exit(2)  # 检查失败时退出码为2
+
+if __name__ == "__main__":
+    main()