main.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from fastapi import FastAPI
  2. from .api import vector, raster, cd_prediction
  3. from .database import engine, Base
  4. from fastapi.middleware.cors import CORSMiddleware
  5. import logging
  6. import sys
  7. from alembic.config import Config
  8. from alembic import command
  9. from alembic.runtime.migration import MigrationContext
  10. from alembic.script import ScriptDirectory
  11. import os
  12. # 设置日志
  13. logging.basicConfig(level=logging.INFO)
  14. logger = logging.getLogger(__name__)
  15. def check_and_upgrade_database():
  16. """
  17. 检查数据库迁移状态并自动升级到最新版本
  18. @description: 在应用启动前检查数据库版本,如果需要升级则自动执行
  19. @returns: None
  20. @throws: SystemExit 当数据库操作失败时退出程序
  21. """
  22. try:
  23. # 配置 Alembic
  24. alembic_cfg = Config(os.path.join(os.path.dirname(os.path.dirname(__file__)), "alembic.ini"))
  25. # 获取当前数据库版本
  26. with engine.connect() as connection:
  27. context = MigrationContext.configure(connection)
  28. current_rev = context.get_current_revision()
  29. # 获取脚本目录和最新版本
  30. script_dir = ScriptDirectory.from_config(alembic_cfg)
  31. head_rev = script_dir.get_current_head()
  32. logger.info(f"当前数据库版本: {current_rev}")
  33. logger.info(f"最新迁移版本: {head_rev}")
  34. # 检查是否需要升级
  35. if current_rev != head_rev:
  36. logger.warning("数据库版本不是最新版本,正在自动升级...")
  37. # 执行升级
  38. command.upgrade(alembic_cfg, "head")
  39. logger.info("数据库升级成功")
  40. else:
  41. logger.info("数据库版本已是最新")
  42. except Exception as e:
  43. logger.error(f"数据库迁移检查失败: {str(e)}")
  44. logger.error("程序将退出,请手动检查数据库状态")
  45. sys.exit(1)
  46. def safe_create_tables():
  47. """
  48. 安全地创建数据库表
  49. @description: 在确保迁移状态正确后创建表结构
  50. """
  51. try:
  52. # 先检查和升级数据库
  53. check_and_upgrade_database()
  54. # 创建数据库表(如果迁移已正确应用,这里应该不会有冲突)
  55. Base.metadata.create_all(bind=engine)
  56. logger.info("数据库表结构检查完成")
  57. except Exception as e:
  58. logger.error(f"数据库表创建失败: {str(e)}")
  59. logger.error("请检查数据库连接和迁移状态")
  60. sys.exit(1)
  61. # 执行数据库初始化
  62. safe_create_tables()
  63. app = FastAPI(
  64. title="地图数据处理系统",
  65. description="一个用于处理地图数据的API系统",
  66. version="1.0.0",
  67. openapi_tags=[
  68. {
  69. "name": "vector",
  70. "description": "矢量数据相关接口",
  71. },
  72. {
  73. "name": "raster",
  74. "description": "栅格数据相关接口",
  75. },
  76. {
  77. "name": "cd-prediction",
  78. "description": "Cd预测模型相关接口",
  79. }
  80. ]
  81. )
  82. # ---------------------------
  83. # 添加 CORS 配置(关键修改)
  84. # ---------------------------
  85. app.add_middleware(
  86. CORSMiddleware,
  87. allow_origins=["https://soilgd.com"], # 允许的前端域名(需与前端实际域名一致)
  88. allow_methods=["*"], # 允许的 HTTP 方法(GET/POST/PUT/DELETE等)
  89. allow_headers=["*"], # 允许的请求头
  90. allow_credentials=True, # 允许携带 Cookie(如需)
  91. )
  92. # 注册路由
  93. app.include_router(vector.router, prefix="/api/vector", tags=["vector"])
  94. app.include_router(raster.router, prefix="/api/raster", tags=["raster"])
  95. app.include_router(cd_prediction.router, prefix="/api/cd-prediction", tags=["cd-prediction"])
  96. @app.get("/")
  97. async def root():
  98. return {"message": "Welcome to the GIS Data Management API"}
  99. # if __name__ == "__main__":
  100. # import uvicorn
  101. # uvicorn.run(app, host="0.0.0.0", port=8000)