Browse Source

增加获取整张表矢量数据接口,更改导出数据结构

tangbengaoyuan 3 months ago
parent
commit
f856d6f13b
2 changed files with 113 additions and 65 deletions
  1. 19 0
      app/api/vector.py
  2. 94 65
      app/services/vector_service.py

+ 19 - 0
app/api/vector.py

@@ -43,6 +43,25 @@ async def export_vector_data(vector_id: int, db: Session = Depends(get_db)):
         background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
         background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
     )
     )
 
 
+@router.get("/export/all", summary="导出所有矢量数据", description="将整个矢量数据表导出为GeoJSON文件")
+async def export_all_vector_data_api(db: Session = Depends(get_db)):
+    """导出整个矢量数据表为GeoJSON文件"""
+    result = vector_service.export_all_vector_data(db)
+    
+    # 检查文件是否存在
+    if not os.path.exists(result["file_path"]):
+        if "temp_dir" in result and os.path.exists(result["temp_dir"]):
+            shutil.rmtree(result["temp_dir"])
+        raise HTTPException(status_code=404, detail="导出文件不存在")
+    
+    # 返回文件下载
+    return FileResponse(
+        path=result["file_path"],
+        filename=os.path.basename(result["file_path"]),
+        media_type="application/json",
+        background=BackgroundTasks().add_task(shutil.rmtree, result["temp_dir"]) if "temp_dir" in result else None
+    )
+
 @router.post("/export/batch", summary="批量导出矢量数据", description="将多个ID的矢量数据批量导出为GeoJSON文件")
 @router.post("/export/batch", summary="批量导出矢量数据", description="将多个ID的矢量数据批量导出为GeoJSON文件")
 async def export_vector_data_batch(vector_ids: List[int], db: Session = Depends(get_db)):
 async def export_vector_data_batch(vector_ids: List[int], db: Session = Depends(get_db)):
     """批量导出矢量数据为GeoJSON文件"""
     """批量导出矢量数据为GeoJSON文件"""

+ 94 - 65
app/services/vector_service.py

@@ -1,5 +1,3 @@
-# 矢量数据服务
-
 from fastapi import HTTPException, UploadFile
 from fastapi import HTTPException, UploadFile
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 from ..models.vector import VectorData
 from ..models.vector import VectorData
@@ -10,6 +8,8 @@ from decimal import Decimal
 from typing import List
 from typing import List
 import uuid
 import uuid
 import tempfile
 import tempfile
+import struct
+
 
 
 class DecimalEncoder(json.JSONEncoder):
 class DecimalEncoder(json.JSONEncoder):
     def default(self, obj):
     def default(self, obj):
@@ -17,53 +17,22 @@ class DecimalEncoder(json.JSONEncoder):
             return float(obj)
             return float(obj)
         return super(DecimalEncoder, self).default(obj)
         return super(DecimalEncoder, self).default(obj)
 
 
+
 def get_vector_data(db: Session, vector_id: int):
 def get_vector_data(db: Session, vector_id: int):
     """通过ID获取一条矢量数据记录"""
     """通过ID获取一条矢量数据记录"""
     vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
     vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
     if not vector_data:
     if not vector_data:
         raise HTTPException(status_code=404, detail="矢量数据不存在")
         raise HTTPException(status_code=404, detail="矢量数据不存在")
-    
-    # 手动构建返回字典
-    result = {}
-    for column in vector_data.__table__.columns:
-        value = getattr(vector_data, column.name)
-        # 处理特殊类型
-        if isinstance(value, Decimal):
-            value = float(value)
-        elif isinstance(value, datetime):
-            value = value.isoformat()
-        elif str(column.type).startswith('geometry'):
-            # 如果是几何类型,直接使用字符串表示
-            if value is not None:
-                value = str(value)
-        result[column.name] = value
-    
-    return result
+    return vector_data
+
 
 
 def get_vector_data_batch(db: Session, vector_ids: List[int]):
 def get_vector_data_batch(db: Session, vector_ids: List[int]):
     """批量获取矢量数据记录"""
     """批量获取矢量数据记录"""
     vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
     vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
     if not vector_data_list:
     if not vector_data_list:
         raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
         raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
-    
-    result = []
-    for vector_data in vector_data_list:
-        item = {}
-        for column in vector_data.__table__.columns:
-            value = getattr(vector_data, column.name)
-            # 处理特殊类型
-            if isinstance(value, Decimal):
-                value = float(value)
-            elif isinstance(value, datetime):
-                value = value.isoformat()
-            elif str(column.type).startswith('geometry'):
-                # 如果是几何类型,直接使用字符串表示
-                if value is not None:
-                    value = str(value)
-            item[column.name] = value
-        result.append(item)
-    
-    return result
+    return vector_data_list
+
 
 
 async def import_vector_data(file: UploadFile, db: Session) -> dict:
 async def import_vector_data(file: UploadFile, db: Session) -> dict:
     """导入GeoJSON文件到数据库"""
     """导入GeoJSON文件到数据库"""
@@ -71,30 +40,30 @@ async def import_vector_data(file: UploadFile, db: Session) -> dict:
         # 读取文件内容
         # 读取文件内容
         content = await file.read()
         content = await file.read()
         data = json.loads(content)
         data = json.loads(content)
-        
+
         # 验证GeoJSON格式
         # 验证GeoJSON格式
         if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
         if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
             raise ValueError("无效的GeoJSON格式")
             raise ValueError("无效的GeoJSON格式")
-        
+
         features = data.get("features", [])
         features = data.get("features", [])
         if not features:
         if not features:
             raise ValueError("GeoJSON文件中没有要素数据")
             raise ValueError("GeoJSON文件中没有要素数据")
-        
+
         # 获取表的所有列名
         # 获取表的所有列名
         columns = [column.name for column in VectorData.__table__.columns]
         columns = [column.name for column in VectorData.__table__.columns]
-        
+
         # 导入每个要素
         # 导入每个要素
         imported_count = 0
         imported_count = 0
         for feature in features:
         for feature in features:
             if not isinstance(feature, dict) or feature.get("type") != "Feature":
             if not isinstance(feature, dict) or feature.get("type") != "Feature":
                 continue
                 continue
-                
+
             # 获取属性
             # 获取属性
             properties = feature.get("properties", {})
             properties = feature.get("properties", {})
-            
+
             # 创建新记录
             # 创建新记录
             vector_data = VectorData()
             vector_data = VectorData()
-            
+
             # 设置每个字段的值(除了id)
             # 设置每个字段的值(除了id)
             for column in columns:
             for column in columns:
                 if column == 'id':  # 跳过id字段
                 if column == 'id':  # 跳过id字段
@@ -105,7 +74,7 @@ async def import_vector_data(file: UploadFile, db: Session) -> dict:
                     if isinstance(value, (dict, list)):
                     if isinstance(value, (dict, list)):
                         value = json.dumps(value, ensure_ascii=False)
                         value = json.dumps(value, ensure_ascii=False)
                     setattr(vector_data, column, value)
                     setattr(vector_data, column, value)
-            
+
             # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
             # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
             geometry = feature.get("geometry")
             geometry = feature.get("geometry")
             if geometry:
             if geometry:
@@ -113,90 +82,150 @@ async def import_vector_data(file: UploadFile, db: Session) -> dict:
                 setattr(vector_data, 'geometry', geometry_str)
                 setattr(vector_data, 'geometry', geometry_str)
             elif 'geom' in properties:
             elif 'geom' in properties:
                 setattr(vector_data, 'geometry', properties['geom'])
                 setattr(vector_data, 'geometry', properties['geom'])
-            
+
             try:
             try:
                 db.add(vector_data)
                 db.add(vector_data)
                 imported_count += 1
                 imported_count += 1
             except Exception as e:
             except Exception as e:
                 continue
                 continue
-        
+
         # 提交事务
         # 提交事务
         try:
         try:
             db.commit()
             db.commit()
         except Exception as e:
         except Exception as e:
             db.rollback()
             db.rollback()
             raise ValueError(f"数据库操作失败: {str(e)}")
             raise ValueError(f"数据库操作失败: {str(e)}")
-        
+
         return {
         return {
             "message": f"成功导入 {imported_count} 条记录",
             "message": f"成功导入 {imported_count} 条记录",
             "imported_count": imported_count
             "imported_count": imported_count
         }
         }
-        
+
     except json.JSONDecodeError as e:
     except json.JSONDecodeError as e:
         raise ValueError(f"无效的JSON格式: {str(e)}")
         raise ValueError(f"无效的JSON格式: {str(e)}")
     except Exception as e:
     except Exception as e:
         db.rollback()
         db.rollback()
         raise ValueError(f"导入失败: {str(e)}")
         raise ValueError(f"导入失败: {str(e)}")
 
 
+
 def export_vector_data(db: Session, vector_id: int):
 def export_vector_data(db: Session, vector_id: int):
     """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
     """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
     vector_data = get_vector_data(db, vector_id)
     vector_data = get_vector_data(db, vector_id)
     return _export_vector_data_to_file([vector_data], f"export_{vector_id}")
     return _export_vector_data_to_file([vector_data], f"export_{vector_id}")
 
 
+
 def export_vector_data_batch(db: Session, vector_ids: List[int]):
 def export_vector_data_batch(db: Session, vector_ids: List[int]):
     """批量导出矢量数据为GeoJSON格式并保存到文件"""
     """批量导出矢量数据为GeoJSON格式并保存到文件"""
     vector_data_list = get_vector_data_batch(db, vector_ids)
     vector_data_list = get_vector_data_batch(db, vector_ids)
     return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}")
     return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}")
 
 
+
+def export_all_vector_data(db: Session):
+    """导出整个矢量数据表为GeoJSON格式并保存到文件"""
+    # 查询所有矢量数据
+    vector_data_list = db.query(VectorData).all()
+
+    # 如果没有数据,抛出异常
+    if not vector_data_list:
+        raise HTTPException(status_code=404, detail="数据库中没有矢量数据")
+
+    # 调用现有的导出函数
+    return _export_vector_data_to_file(vector_data_list, "export_all")
+
+
+
+def parse_geom_field(geom_value) -> dict:
+    """
+    解析 geom 字段为 GeoJSON 格式的 geometry。
+    如果解析失败,返回 None。
+    """
+    try:
+        # 将 geom_value 转换为字符串
+        geom_str = str(geom_value)
+        if geom_str and geom_str.startswith('0101000020'):
+            # 去掉前两个字符(字节序标记),并转换为字节对象
+            binary_geom = bytes.fromhex(geom_str[2:])
+
+            # 解析字节序(前1个字节)
+            byte_order = struct.unpack('B', binary_geom[:1])[0]
+            endian = '<' if byte_order == 1 else '>'
+
+            # 检查数据长度是否足够解析坐标
+            if len(binary_geom) >= 1 + 16:
+                # 从数据末尾往前找 16 字节作为坐标
+                coord_bytes = binary_geom[-16:]
+                x, y = struct.unpack(f'{endian}dd', coord_bytes)
+                return {
+                    "type": "Point",
+                    "coordinates": [x, y]
+                }
+            else:
+                print(f"Insufficient data in geom value: {geom_str}. Length: {len(binary_geom)}")
+    except (ValueError, IndexError, struct.error) as e:
+        print(f"Failed to parse geom field: {geom_str if 'geom_str' in locals() else geom_value}. Error: {e}")
+    return None
+
+
+
 def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str):
 def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str):
-    """将矢量数据列表导出为GeoJSON文件"""
+    """将矢量数据列表导出为 GeoJSON 文件"""
     features = []
     features = []
-    
+
     for vector_data in vector_data_list:
     for vector_data in vector_data_list:
         # 获取表的所有列名
         # 获取表的所有列名
         columns = [column.name for column in VectorData.__table__.columns]
         columns = [column.name for column in VectorData.__table__.columns]
-        
+
         # 构建包含所有列数据的字典
         # 构建包含所有列数据的字典
         data_dict = {}
         data_dict = {}
         for column in columns:
         for column in columns:
             value = getattr(vector_data, column)
             value = getattr(vector_data, column)
-            # 如果值是字符串且可能是JSON,尝试解析
+            # 如果值是字符串且可能是 JSON,尝试解析
             if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
             if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
                 try:
                 try:
                     value = json.loads(value)
                     value = json.loads(value)
                 except:
                 except:
                     pass
                     pass
-            data_dict[column] = value
-        
-        # 创建Feature
+            # 跳过 geom 字段,后续单独处理
+            if column != 'geom':
+                data_dict[column] = value
+
+        # 解析 geom 字段为 GeoJSON 格式的 geometry
+        geometry = None
+        if hasattr(vector_data, 'geom'):
+            geom_value = getattr(vector_data, 'geom')
+            geometry = parse_geom_field(geom_value)
+
+        # 创建 Feature
         feature = {
         feature = {
             "type": "Feature",
             "type": "Feature",
             "properties": data_dict,
             "properties": data_dict,
-            "geometry": json.loads(vector_data.geometry) if hasattr(vector_data, 'geometry') else None
+            "geometry": geometry
         }
         }
         features.append(feature)
         features.append(feature)
-    
-    # 创建GeoJSON对象
+
+    # 创建 GeoJSON 对象
     geojson = {
     geojson = {
         "type": "FeatureCollection",
         "type": "FeatureCollection",
         "features": features
         "features": features
     }
     }
-    
+
     # 创建临时目录
     # 创建临时目录
     temp_dir = tempfile.mkdtemp()
     temp_dir = tempfile.mkdtemp()
-    
+
     # 生成文件名(使用时间戳避免重名)
     # 生成文件名(使用时间戳避免重名)
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     filename = f"{base_filename}_{timestamp}.geojson"
     filename = f"{base_filename}_{timestamp}.geojson"
     file_path = os.path.join(temp_dir, filename)
     file_path = os.path.join(temp_dir, filename)
-    
-    # 保存到文件,使用自定义编码器处理Decimal类型
+
+    # 保存到文件,使用自定义编码器处理 Decimal 类型
     with open(file_path, "w", encoding="utf-8") as f:
     with open(file_path, "w", encoding="utf-8") as f:
         json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
         json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
-    
+
     return {
     return {
         "message": "数据导出成功",
         "message": "数据导出成功",
         "file_path": file_path,
         "file_path": file_path,
         "temp_dir": temp_dir,  # 返回临时目录路径,以便后续清理
         "temp_dir": temp_dir,  # 返回临时目录路径,以便后续清理
         "data": geojson
         "data": geojson
     }
     }
+
+