123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- from fastapi import FastAPI
- from .api import vector, raster, cd_prediction
- from .database import engine, Base
- from fastapi.middleware.cors import CORSMiddleware
- import logging
- import sys
- from alembic.config import Config
- from alembic import command
- from alembic.runtime.migration import MigrationContext
- from alembic.script import ScriptDirectory
- import os
- # 设置日志
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- def check_and_upgrade_database():
- """
- 检查数据库迁移状态并自动升级到最新版本
-
- @description: 在应用启动前检查数据库版本,如果需要升级则自动执行
- @returns: None
- @throws: SystemExit 当数据库操作失败时退出程序
- """
- try:
- # 配置 Alembic
- alembic_cfg = Config(os.path.join(os.path.dirname(os.path.dirname(__file__)), "alembic.ini"))
-
- # 获取当前数据库版本
- with engine.connect() as connection:
- context = MigrationContext.configure(connection)
- current_rev = context.get_current_revision()
-
- # 获取脚本目录和最新版本
- script_dir = ScriptDirectory.from_config(alembic_cfg)
- head_rev = script_dir.get_current_head()
-
- logger.info(f"当前数据库版本: {current_rev}")
- logger.info(f"最新迁移版本: {head_rev}")
-
- # 检查是否需要升级
- if current_rev != head_rev:
- logger.warning("数据库版本不是最新版本,正在自动升级...")
-
- # 执行升级
- command.upgrade(alembic_cfg, "head")
- logger.info("数据库升级成功")
- else:
- logger.info("数据库版本已是最新")
-
- except Exception as e:
- logger.error(f"数据库迁移检查失败: {str(e)}")
- logger.error("程序将退出,请手动检查数据库状态")
- sys.exit(1)
- def safe_create_tables():
- """
- 安全地创建数据库表
-
- @description: 在确保迁移状态正确后创建表结构
- """
- try:
- # 先检查和升级数据库
- check_and_upgrade_database()
-
- # 创建数据库表(如果迁移已正确应用,这里应该不会有冲突)
- Base.metadata.create_all(bind=engine)
- logger.info("数据库表结构检查完成")
-
- except Exception as e:
- logger.error(f"数据库表创建失败: {str(e)}")
- logger.error("请检查数据库连接和迁移状态")
- sys.exit(1)
- # 执行数据库初始化
- safe_create_tables()
- app = FastAPI(
- title="地图数据处理系统",
- description="一个用于处理地图数据的API系统",
- version="1.0.0",
- openapi_tags=[
- {
- "name": "vector",
- "description": "矢量数据相关接口",
- },
- {
- "name": "raster",
- "description": "栅格数据相关接口",
- },
- {
- "name": "cd-prediction",
- "description": "Cd预测模型相关接口",
- }
- ]
- )
- # ---------------------------
- # 添加 CORS 配置(关键修改)
- # ---------------------------
- app.add_middleware(
- CORSMiddleware,
- allow_origins=["https://soilgd.com"], # 允许的前端域名(需与前端实际域名一致)
- allow_methods=["*"], # 允许的 HTTP 方法(GET/POST/PUT/DELETE等)
- allow_headers=["*"], # 允许的请求头
- allow_credentials=True, # 允许携带 Cookie(如需)
- )
- # 注册路由
- app.include_router(vector.router, prefix="/api/vector", tags=["vector"])
- app.include_router(raster.router, prefix="/api/raster", tags=["raster"])
- app.include_router(cd_prediction.router, prefix="/api/cd-prediction", tags=["cd-prediction"])
- @app.get("/")
- async def root():
- return {"message": "Welcome to the GIS Data Management API"}
- # if __name__ == "__main__":
- # import uvicorn
- # uvicorn.run(app, host="0.0.0.0", port=8000)
|