drggboy 1 month ago
commit
ea8ff8a3c1

+ 5 - 0
.env

@@ -0,0 +1,5 @@
+DB_HOST=localhost
+DB_PORT=5432
+DB_NAME=testdb
+DB_USER=postgres
+DB_PASSWORD=root

+ 14 - 0
.gitignore

@@ -0,0 +1,14 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.env
+.venv
+env/
+venv/
+ENV/
+myenv/
+*.db
+.idea/
+.vscode/
+*.log

+ 91 - 0
README.md

@@ -0,0 +1,91 @@
+# 地图数据处理系统
+
+这是一个基于 FastAPI 开发的地图数据处理系统,支持栅格和矢量数据的处理、存储和管理。
+
+## 功能特点
+
+- 支持栅格数据的导入和导出
+- 支持矢量数据的导入和导出
+- 使用 PostgreSQL + PostGIS 存储空间数据
+- 提供 RESTful API 接口
+- 支持空间数据查询和分析
+
+## 系统架构
+
+```
+app/
+├── api/            # API 路由层
+│   ├── raster.py   # 栅格数据接口
+│   └── vector.py   # 矢量数据接口
+├── services/       # 业务逻辑层
+│   ├── raster_service.py
+│   └── vector_service.py
+├── models/         # 数据模型
+├── utils/          # 工具函数
+├── database.py     # 数据库配置
+└── main.py         # 主程序入口
+```
+
+## 安装依赖
+
+1. 确保已安装 Python 3.8+
+2. 安装 PostgreSQL 数据库并启用 PostGIS 扩展
+3. 安装项目依赖:
+```bash
+pip install -r requirements.txt
+```
+
+## 配置
+
+1. 复制 `.env.example` 文件为 `.env`
+2. 修改 `.env` 文件中的数据库连接信息:
+```
+DB_HOST=localhost
+DB_PORT=5432
+DB_NAME=your_database
+DB_USER=your_username
+DB_PASSWORD=your_password
+```
+
+## 运行
+
+1. 启动服务:
+```bash
+uvicorn app.main:app --reload
+```
+
+2. 访问 API 文档:
+- Swagger UI: http://localhost:8000/docs
+- ReDoc: http://localhost:8000/redoc
+
+## API 接口
+
+### 栅格数据接口 (/api/raster)
+- POST /import: 导入栅格数据
+- GET /export: 导出栅格数据
+- GET /query: 查询栅格数据
+
+### 矢量数据接口 (/api/vector)
+- POST /import: 导入矢量数据
+- GET /export: 导出矢量数据
+- GET /query: 查询矢量数据
+
+## 开发环境
+
+- Python 3.8+
+- PostgreSQL 12+
+- PostGIS 3.0+
+- FastAPI 0.104.1
+- Uvicorn 0.24.0
+
+## 贡献指南
+
+1. Fork 项目
+2. 创建特性分支
+3. 提交更改
+4. 推送到分支
+5. 创建 Pull Request
+
+## 许可证
+
+MIT License 

+ 3 - 0
app/__init__.py

@@ -0,0 +1,3 @@
+"""
+地图数据处理系统
+""" 

+ 67 - 0
app/api/raster.py

@@ -0,0 +1,67 @@
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
+from fastapi.responses import FileResponse
+from sqlalchemy.orm import Session
+from ..database import get_db
+from ..services import raster_service
+import os
+import shutil
+from fastapi import BackgroundTasks
+from typing import List
+from pydantic import BaseModel
+
+router = APIRouter()
+
+class RasterBatchExportRequest(BaseModel):
+    raster_ids: List[int]
+
+@router.get("/{raster_id}", summary="获取栅格数据", description="获取指定ID的栅格数据信息")
+async def get_raster(raster_id: int, db: Session = Depends(get_db)):
+    """获取指定ID的栅格数据信息"""
+    return raster_service.get_raster_data(db, raster_id)
+
+@router.post("/import", summary="导入栅格数据", description="将TIFF格式的栅格数据导入到数据库中")
+async def import_raster_data(file: UploadFile = File(...), db: Session = Depends(get_db)):
+    """导入栅格数据到数据库"""
+    # 检查文件类型
+    if not file.filename.endswith('.tif'):
+        raise HTTPException(status_code=400, detail="只支持TIFF格式的栅格数据")
+    
+    return await raster_service.import_raster_data(file, db)
+
+@router.get("/{raster_id}/export", summary="导出栅格数据", description="将指定ID的栅格数据导出为TIFF文件")
+async def export_raster_data(raster_id: int, db: Session = Depends(get_db)):
+    """导出指定ID的栅格数据为TIFF文件"""
+    result = raster_service.export_raster_data(db, raster_id)
+    
+    # 检查文件是否存在
+    if not os.path.exists(result["file_path"]):
+        if "temp_dir" in result and os.path.exists(result["temp_dir"]):
+            shutil.rmtree(result["temp_dir"])
+        raise HTTPException(status_code=404, detail="导出文件不存在")
+    
+    # 返回文件下载
+    return FileResponse(
+        path=result["file_path"],
+        filename=f"raster_{raster_id}.tif",
+        media_type="image/tiff",
+        background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
+    )
+
+@router.post("/export/batch", summary="批量导出栅格数据", description="将多个ID的栅格数据批量导出为TIFF文件")
+async def export_raster_data_batch(request: RasterBatchExportRequest, db: Session = Depends(get_db)):
+    """批量导出栅格数据为TIFF文件"""
+    result = raster_service.export_raster_data_batch(db, request.raster_ids)
+    
+    # 检查文件是否存在
+    if not os.path.exists(result["file_path"]):
+        if "temp_dir" in result and os.path.exists(result["temp_dir"]):
+            shutil.rmtree(result["temp_dir"])
+        raise HTTPException(status_code=404, detail="导出文件不存在")
+    
+    # 返回文件下载
+    return FileResponse(
+        path=result["file_path"],
+        filename=os.path.basename(result["file_path"]),
+        media_type="application/zip",
+        background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
+    )

+ 63 - 0
app/api/vector.py

@@ -0,0 +1,63 @@
+# 矢量数据API
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks
+from fastapi.responses import FileResponse
+from sqlalchemy.orm import Session
+from ..database import get_db
+from ..services import vector_service
+import os
+import shutil
+from typing import List
+
+router = APIRouter()
+
+@router.post("/import", summary="导入GeoJSON文件", description="将GeoJSON文件导入到数据库中")
+async def import_vector_data(file: UploadFile = File(...), db: Session = Depends(get_db)):
+    """导入GeoJSON文件到数据库"""
+    # 检查文件类型
+    if not file.filename.endswith('.geojson'):
+        raise HTTPException(status_code=400, detail="只支持GeoJSON文件")
+    
+    return await vector_service.import_vector_data(file, db)
+
+@router.get("/{vector_id}", summary="获取矢量数据", description="根据ID获取一条矢量数据记录")
+def get_vector_data(vector_id: int, db: Session = Depends(get_db)):
+    """获取指定ID的矢量数据"""
+    return vector_service.get_vector_data(db, vector_id)
+
+@router.get("/{vector_id}/export", summary="导出矢量数据", description="将指定ID的矢量数据导出为GeoJSON文件")
+async def export_vector_data(vector_id: int, db: Session = Depends(get_db)):
+    """导出指定ID的矢量数据为GeoJSON文件"""
+    result = vector_service.export_vector_data(db, vector_id)
+    
+    # 检查文件是否存在
+    if not os.path.exists(result["file_path"]):
+        if "temp_dir" in result and os.path.exists(result["temp_dir"]):
+            shutil.rmtree(result["temp_dir"])
+        raise HTTPException(status_code=404, detail="导出文件不存在")
+    
+    # 返回文件下载
+    return FileResponse(
+        path=result["file_path"],
+        filename=os.path.basename(result["file_path"]),
+        media_type="application/json",
+        background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
+    )
+
+@router.post("/export/batch", summary="批量导出矢量数据", description="将多个ID的矢量数据批量导出为GeoJSON文件")
+async def export_vector_data_batch(vector_ids: List[int], db: Session = Depends(get_db)):
+    """批量导出矢量数据为GeoJSON文件"""
+    result = vector_service.export_vector_data_batch(db, vector_ids)
+    
+    # 检查文件是否存在
+    if not os.path.exists(result["file_path"]):
+        if "temp_dir" in result and os.path.exists(result["temp_dir"]):
+            shutil.rmtree(result["temp_dir"])
+        raise HTTPException(status_code=404, detail="导出文件不存在")
+    
+    # 返回文件下载
+    return FileResponse(
+        path=result["file_path"],
+        filename=os.path.basename(result["file_path"]),
+        media_type="application/json",
+        background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
+    )

+ 73 - 0
app/database.py

@@ -0,0 +1,73 @@
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.ext.declarative import declarative_base
+import os
+from dotenv import load_dotenv
+import logging
+from sqlalchemy.exc import SQLAlchemyError
+
+# 配置日志
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# 创建Base类
+Base = declarative_base()
+
+# 加载环境变量
+load_dotenv()
+
+# 从环境变量获取数据库连接信息
+DB_USER = os.getenv("DB_USER", "postgres")
+DB_PASSWORD = os.getenv("DB_PASSWORD", "root")
+DB_HOST = os.getenv("DB_HOST", "localhost")
+DB_PORT = os.getenv("DB_PORT", "5432")
+DB_NAME = os.getenv("DB_NAME", "testdb")
+
+# 构建数据库连接URL
+SQLALCHEMY_DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
+
+def create_database_engine():
+    """创建并配置数据库引擎"""
+    try:
+        engine = create_engine(
+            SQLALCHEMY_DATABASE_URL,
+            pool_size=5,
+            max_overflow=10,
+            pool_timeout=30,
+            pool_recycle=1800
+        )
+        return engine
+    except Exception as e:
+        logger.error(f"创建数据库引擎失败: {str(e)}")
+        raise
+
+def test_database_connection(engine):
+    """测试数据库连接"""
+    try:
+        with engine.connect() as conn:
+            logger.info("数据库连接测试成功")
+            return True
+    except SQLAlchemyError as e:
+        logger.error(f"数据库连接测试失败: {str(e)}")
+        return False
+
+# 创建数据库引擎
+engine = create_database_engine()
+
+# 测试数据库连接
+if not test_database_connection(engine):
+    raise Exception("无法连接到数据库,请检查数据库配置")
+
+# 创建会话工厂
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+def get_db():
+    """获取数据库会话"""
+    db = SessionLocal()
+    try:
+        yield db
+    except SQLAlchemyError as e:
+        logger.error(f"数据库操作错误: {str(e)}")
+        raise
+    finally:
+        db.close()

+ 0 - 0
app/logs/vector_service.log


+ 34 - 0
app/main.py

@@ -0,0 +1,34 @@
+from fastapi import FastAPI
+from .api import vector, raster
+from .database import engine, Base
+
+# 创建数据库表
+Base.metadata.create_all(bind=engine)
+
+app = FastAPI(
+    title="地图数据处理系统",
+    description="一个用于处理地图数据的API系统",
+    version="1.0.0",
+    openapi_tags=[
+        {
+            "name": "vector",
+            "description": "矢量数据相关接口",
+        },
+        {
+            "name": "raster",
+            "description": "栅格数据相关接口",
+        }
+    ]
+)
+
+# 注册路由
+app.include_router(vector.router, prefix="/api/vector", tags=["vector"])
+app.include_router(raster.router, prefix="/api/raster", tags=["raster"])
+
+@app.get("/")
+async def root():
+    return {"message": "Welcome to the GIS Data Management API"}
+
+# if __name__ == "__main__":
+#     import uvicorn
+#     uvicorn.run(app, host="0.0.0.0", port=8000)

+ 4 - 0
app/models/base.py

@@ -0,0 +1,4 @@
+# 基础模型
+
+from sqlalchemy.ext.declarative import declarative_base
+Base = declarative_base()

+ 11 - 0
app/models/raster.py

@@ -0,0 +1,11 @@
+# 栅格数据模型
+
+from sqlalchemy import Column, Integer
+from geoalchemy2 import Raster
+from ..database import Base
+
+class RasterData(Base):
+    __tablename__ = "raster_table"
+    
+    id = Column(Integer, primary_key=True, index=True)
+    rast = Column(Raster)  # 使用PostGIS的栅格类型

+ 16 - 0
app/models/vector.py

@@ -0,0 +1,16 @@
+# 矢量数据模型
+
+from sqlalchemy import Column, Integer, String, JSON, Table, MetaData
+from .base import Base
+from ..database import engine
+
+# 使用现有的表
+metadata = MetaData()
+surveydata = Table(
+    "surveydata",
+    metadata,
+    autoload_with=engine
+)
+
+class VectorData(Base):
+    __table__ = surveydata

+ 254 - 0
app/services/raster_service.py

@@ -0,0 +1,254 @@
+from fastapi import HTTPException, UploadFile
+from sqlalchemy.orm import Session
+from ..models.raster import RasterData
+import os
+from datetime import datetime
+import subprocess
+import tempfile
+import shutil
+from sqlalchemy import text
+import rasterio
+from rasterio.io import MemoryFile
+import numpy as np
+import zipfile
+from typing import List
+
+def get_raster_data(db: Session, raster_id: int):
+    """通过ID获取一条栅格数据记录"""
+    # 获取栅格数据
+    query = text("""
+        SELECT id, rast
+        FROM raster_table 
+        WHERE id = :raster_id
+    """)
+    
+    result = db.execute(query, {"raster_id": raster_id}).first()
+    
+    if not result:
+        raise HTTPException(status_code=404, detail="栅格数据不存在")
+    
+    # 将Row对象转换为字典
+    return dict(result._mapping)
+
+async def import_raster_data(file: UploadFile, db: Session) -> dict:
+    """导入栅格数据到数据库"""
+    try:
+        # 创建临时目录
+        temp_dir = tempfile.mkdtemp()
+        temp_file_path = os.path.join(temp_dir, file.filename)
+        
+        try:
+            # 保存上传的文件
+            with open(temp_file_path, "wb") as buffer:
+                content = await file.read()
+                buffer.write(content)
+            
+            # 检查文件是否存在且不为空
+            if not os.path.exists(temp_file_path):
+                raise Exception("临时文件创建失败")
+            if os.path.getsize(temp_file_path) == 0:
+                raise Exception("上传的文件为空")
+            
+            # 使用raster2pgsql命令导入数据
+            cmd = [
+                'raster2pgsql',
+                '-s', '4490',  # 空间参考系统
+                '-I',  # 创建空间索引
+                '-M',  # 创建空间索引
+                '-a',  # 追加模式
+                temp_file_path,
+                'raster_table'
+            ]
+            
+            # 执行命令
+            process = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
+            )
+            stdout, stderr = process.communicate()
+            
+            if process.returncode != 0:
+                error_msg = stderr.decode()
+                raise Exception(f"raster2pgsql命令执行失败: {error_msg}")
+            
+            # 检查导入的SQL是否为空
+            if not stdout:
+                raise Exception("raster2pgsql没有生成任何SQL语句")
+            
+            # 执行生成的SQL
+            sql_commands = stdout.decode().split(';')
+            for sql in sql_commands:
+                if sql.strip():
+                    db.execute(text(sql))
+            db.commit()
+            
+            # 获取最后插入的记录的ID
+            result = db.execute(text("""
+                SELECT id, ST_IsEmpty(rast) as is_empty, ST_Width(rast) as width
+                FROM raster_table 
+                ORDER BY id DESC 
+                LIMIT 1
+            """)).first()
+            
+            if not result:
+                raise Exception("无法获取导入的栅格数据ID")
+            
+            if result.is_empty or result.width is None:
+                raise Exception("导入的栅格数据为空或无效")
+            
+            return {
+                "message": "栅格数据导入成功",
+                "raster_id": result.id,
+                "file_path": temp_file_path
+            }
+            
+        finally:
+            # 清理临时目录
+            shutil.rmtree(temp_dir)
+            
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}")
+
+def export_raster_data(db: Session, raster_id: int):
+    """导出指定ID的栅格数据为TIFF文件"""
+    try:
+        # 创建临时目录
+        temp_dir = tempfile.mkdtemp()
+        temp_file_path = os.path.join(temp_dir, f"raster_{raster_id}.tif")
+        
+        # 从数据库获取栅格数据
+        query = text("""
+            SELECT 
+                ST_AsBinary(rast) as raster_data,
+                ST_Width(rast) as width,
+                ST_Height(rast) as height,
+                ST_NumBands(rast) as num_bands,
+                ST_UpperLeftX(rast) as upper_left_x,
+                ST_UpperLeftY(rast) as upper_left_y,
+                ST_ScaleX(rast) as scale_x,
+                ST_ScaleY(rast) as scale_y,
+                ST_SkewX(rast) as skew_x,
+                ST_SkewY(rast) as skew_y,
+                ST_SRID(rast) as srid,
+                ST_BandPixelType(rast, 1) as pixel_type
+            FROM raster_table 
+            WHERE id = :raster_id
+        """)
+        
+        result = db.execute(query, {"raster_id": raster_id}).first()
+        
+        if not result:
+            shutil.rmtree(temp_dir)
+            raise HTTPException(status_code=404, detail="栅格数据不存在")
+        
+        # 根据像素类型选择合适的数据类型
+        dtype_map = {
+            '8BUI': np.uint8,
+            '16BUI': np.uint16,
+            '32BUI': np.uint32,
+            '8BSI': np.int8,
+            '16BSI': np.int16,
+            '32BSI': np.int32,
+            '32BF': np.float32,
+            '64BF': np.float64
+        }
+        
+        dtype = dtype_map.get(result.pixel_type, np.float32)
+        
+        # 计算预期的数据大小
+        expected_size = result.width * result.height * result.num_bands * np.dtype(dtype).itemsize
+        
+        # 检查二进制数据大小
+        if len(result.raster_data) < expected_size:
+            shutil.rmtree(temp_dir)
+            raise Exception(f"数据大小不足: 预期至少 {expected_size} 字节,实际 {len(result.raster_data)} 字节")
+        
+        # 跳过头部信息,只取实际数据部分
+        # PostGIS 二进制格式的头部通常是66字节
+        header_size = 66
+        actual_data = result.raster_data[header_size:header_size + expected_size]
+        
+        # 将二进制数据转换为numpy数组
+        raster_data = np.frombuffer(actual_data, dtype=dtype)
+        # 重塑数组为正确的形状 (bands, height, width)
+        raster_data = raster_data.reshape((result.num_bands, result.height, result.width))
+        
+        # 创建内存文件
+        with MemoryFile() as memfile:
+            # 创建新的栅格数据集
+            with memfile.open(
+                driver='GTiff',
+                height=result.height,
+                width=result.width,
+                count=result.num_bands,
+                dtype=dtype,
+                crs=f'EPSG:{result.srid}',
+                transform=rasterio.transform.from_origin(
+                    result.upper_left_x,
+                    result.upper_left_y,
+                    result.scale_x,
+                    result.scale_y
+                )
+            ) as dataset:
+                # 写入数据
+                dataset.write(raster_data)
+            
+            # 将内存文件写入磁盘
+            with open(temp_file_path, 'wb') as f:
+                f.write(memfile.read())
+        
+        # 检查文件是否成功创建
+        if not os.path.exists(temp_file_path):
+            shutil.rmtree(temp_dir)
+            raise Exception("文件创建失败")
+        
+        return {
+            "message": "栅格数据导出成功",
+            "file_path": temp_file_path,
+            "temp_dir": temp_dir  # 返回临时目录路径,以便后续清理
+        }
+            
+    except Exception as e:
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+        raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
+
+def export_raster_data_batch(db: Session, raster_ids: List[int]):
+    """批量导出栅格数据为TIFF文件"""
+    try:
+        # 创建临时目录
+        temp_dir = tempfile.mkdtemp()
+        zip_file_path = os.path.join(temp_dir, "raster_batch.zip")
+        
+        # 创建ZIP文件
+        with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+            for raster_id in raster_ids:
+                try:
+                    # 导出单个栅格数据
+                    result = export_raster_data(db, raster_id)
+                    if result and os.path.exists(result["file_path"]):
+                        # 将文件添加到ZIP中
+                        zipf.write(
+                            result["file_path"],
+                            arcname=f"raster_{raster_id}.tif"
+                        )
+                except Exception as e:
+                    print(f"导出栅格数据 {raster_id} 失败: {str(e)}")
+                    continue
+        
+        # 检查ZIP文件是否成功创建
+        if not os.path.exists(zip_file_path):
+            shutil.rmtree(temp_dir)
+            raise Exception("批量导出文件创建失败")
+        
+        return {
+            "message": "栅格数据批量导出成功",
+            "file_path": zip_file_path,
+            "temp_dir": temp_dir
+        }
+            
+    except Exception as e:
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+        raise HTTPException(status_code=500, detail=f"批量导出失败: {str(e)}")

+ 168 - 0
app/services/vector_service.py

@@ -0,0 +1,168 @@
+# 矢量数据服务
+
+from fastapi import HTTPException, UploadFile
+from sqlalchemy.orm import Session
+from ..models.vector import VectorData
+import json
+import os
+from datetime import datetime
+from decimal import Decimal
+from typing import List
+import uuid
+import tempfile
+
+class DecimalEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, Decimal):
+            return float(obj)
+        return super(DecimalEncoder, self).default(obj)
+
+def get_vector_data(db: Session, vector_id: int):
+    """通过ID获取一条矢量数据记录"""
+    vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
+    if not vector_data:
+        raise HTTPException(status_code=404, detail="矢量数据不存在")
+    return vector_data
+
+def get_vector_data_batch(db: Session, vector_ids: List[int]):
+    """批量获取矢量数据记录"""
+    vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
+    if not vector_data_list:
+        raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
+    return vector_data_list
+
+async def import_vector_data(file: UploadFile, db: Session) -> dict:
+    """导入GeoJSON文件到数据库"""
+    try:
+        # 读取文件内容
+        content = await file.read()
+        data = json.loads(content)
+        
+        # 验证GeoJSON格式
+        if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
+            raise ValueError("无效的GeoJSON格式")
+        
+        features = data.get("features", [])
+        if not features:
+            raise ValueError("GeoJSON文件中没有要素数据")
+        
+        # 获取表的所有列名
+        columns = [column.name for column in VectorData.__table__.columns]
+        
+        # 导入每个要素
+        imported_count = 0
+        for feature in features:
+            if not isinstance(feature, dict) or feature.get("type") != "Feature":
+                continue
+                
+            # 获取属性
+            properties = feature.get("properties", {})
+            
+            # 创建新记录
+            vector_data = VectorData()
+            
+            # 设置每个字段的值(除了id)
+            for column in columns:
+                if column == 'id':  # 跳过id字段
+                    continue
+                if column in properties:
+                    value = properties[column]
+                    # 如果值是字典或列表,转换为JSON字符串
+                    if isinstance(value, (dict, list)):
+                        value = json.dumps(value, ensure_ascii=False)
+                    setattr(vector_data, column, value)
+            
+            # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
+            geometry = feature.get("geometry")
+            if geometry:
+                geometry_str = json.dumps(geometry, ensure_ascii=False)
+                setattr(vector_data, 'geometry', geometry_str)
+            elif 'geom' in properties:
+                setattr(vector_data, 'geometry', properties['geom'])
+            
+            try:
+                db.add(vector_data)
+                imported_count += 1
+            except Exception as e:
+                continue
+        
+        # 提交事务
+        try:
+            db.commit()
+        except Exception as e:
+            db.rollback()
+            raise ValueError(f"数据库操作失败: {str(e)}")
+        
+        return {
+            "message": f"成功导入 {imported_count} 条记录",
+            "imported_count": imported_count
+        }
+        
+    except json.JSONDecodeError as e:
+        raise ValueError(f"无效的JSON格式: {str(e)}")
+    except Exception as e:
+        db.rollback()
+        raise ValueError(f"导入失败: {str(e)}")
+
+def export_vector_data(db: Session, vector_id: int):
+    """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
+    vector_data = get_vector_data(db, vector_id)
+    return _export_vector_data_to_file([vector_data], f"export_{vector_id}")
+
+def export_vector_data_batch(db: Session, vector_ids: List[int]):
+    """批量导出矢量数据为GeoJSON格式并保存到文件"""
+    vector_data_list = get_vector_data_batch(db, vector_ids)
+    return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}")
+
+def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str):
+    """将矢量数据列表导出为GeoJSON文件"""
+    features = []
+    
+    for vector_data in vector_data_list:
+        # 获取表的所有列名
+        columns = [column.name for column in VectorData.__table__.columns]
+        
+        # 构建包含所有列数据的字典
+        data_dict = {}
+        for column in columns:
+            value = getattr(vector_data, column)
+            # 如果值是字符串且可能是JSON,尝试解析
+            if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
+                try:
+                    value = json.loads(value)
+                except:
+                    pass
+            data_dict[column] = value
+        
+        # 创建Feature
+        feature = {
+            "type": "Feature",
+            "properties": data_dict,
+            "geometry": json.loads(vector_data.geometry) if hasattr(vector_data, 'geometry') else None
+        }
+        features.append(feature)
+    
+    # 创建GeoJSON对象
+    geojson = {
+        "type": "FeatureCollection",
+        "features": features
+    }
+    
+    # 创建临时目录
+    temp_dir = tempfile.mkdtemp()
+    
+    # 生成文件名(使用时间戳避免重名)
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    filename = f"{base_filename}_{timestamp}.geojson"
+    file_path = os.path.join(temp_dir, filename)
+    
+    # 保存到文件,使用自定义编码器处理Decimal类型
+    with open(file_path, "w", encoding="utf-8") as f:
+        json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
+    
+    return {
+        "message": "数据导出成功",
+        "file_path": file_path,
+        "temp_dir": temp_dir,  # 返回临时目录路径,以便后续清理
+        "data": geojson
+    }

+ 22 - 0
app/utils/file_validators.py

@@ -0,0 +1,22 @@
+# 文件类型验证器
+from fastapi import UploadFile, HTTPException
+
+def validate_file_type(file: UploadFile) -> str:
+    """验证文件类型并返回文件类型"""
+    filename = file.filename.lower()
+    
+    SUPPORTED_RASTER_TYPES = {'.tif', '.tiff', '.img'}
+    SUPPORTED_VECTOR_TYPES = {'.geojson', '.json', '.shp'}
+    
+    file_ext = filename.split('.')[-1]
+    
+    if f'.{file_ext}' in SUPPORTED_RASTER_TYPES:
+        return 'raster'
+    elif f'.{file_ext}' in SUPPORTED_VECTOR_TYPES:
+        return 'vector'
+    else:
+        supported_formats = list(SUPPORTED_RASTER_TYPES) + list(SUPPORTED_VECTOR_TYPES)
+        raise HTTPException(
+            status_code=400,
+            detail=f"不支持的文件类型。支持的格式有: {', '.join(supported_formats)}"
+        )

+ 5 - 0
main.py

@@ -0,0 +1,5 @@
+from app.main import app
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000) 

BIN
uninstall.txt