from fastapi import HTTPException, UploadFile from sqlalchemy.orm import Session from ..models.vector import VectorData import json import os from datetime import datetime from decimal import Decimal from typing import List import uuid import tempfile import struct from sqlalchemy.sql import text class DecimalEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, Decimal): return float(obj) return super(DecimalEncoder, self).default(obj) def get_vector_data(db: Session, vector_id: int): """通过ID获取一条矢量数据记录""" vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first() if not vector_data: raise HTTPException(status_code=404, detail="矢量数据不存在") # 手动构建返回字典 result = {} for column in vector_data.__table__.columns: value = getattr(vector_data, column.name) # 处理特殊类型 if isinstance(value, Decimal): value = float(value) elif isinstance(value, datetime): value = value.isoformat() elif str(column.type).startswith('geometry'): # 如果是几何类型,直接使用字符串表示 if value is not None: value = str(value) result[column.name] = value return result def get_vector_data_batch(db: Session, vector_ids: List[int]): """批量获取矢量数据记录""" vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all() if not vector_data_list: raise HTTPException(status_code=404, detail="未找到指定的矢量数据") result = [] for vector_data in vector_data_list: item = {} for column in vector_data.__table__.columns: value = getattr(vector_data, column.name) # 处理特殊类型 if isinstance(value, Decimal): value = float(value) elif isinstance(value, datetime): value = value.isoformat() elif str(column.type).startswith('geometry'): # 如果是几何类型,直接使用字符串表示 if value is not None: value = str(value) item[column.name] = value result.append(item) return result async def import_vector_data(file: UploadFile, db: Session) -> dict: """导入GeoJSON文件到数据库""" try: # 读取文件内容 content = await file.read() data = json.loads(content) # 验证GeoJSON格式 if not isinstance(data, dict) or data.get("type") != "FeatureCollection": raise ValueError("无效的GeoJSON格式") features = data.get("features", []) if not features: raise ValueError("GeoJSON文件中没有要素数据") # 获取表的所有列名 columns = [column.name for column in VectorData.__table__.columns] # 导入每个要素 imported_count = 0 for feature in features: if not isinstance(feature, dict) or feature.get("type") != "Feature": continue # 获取属性 properties = feature.get("properties", {}) # 创建新记录 vector_data = VectorData() # 设置每个字段的值(除了id) for column in columns: if column == 'id': # 跳过id字段 continue if column in properties: value = properties[column] # 如果值是字典或列表,转换为JSON字符串 if isinstance(value, (dict, list)): value = json.dumps(value, ensure_ascii=False) setattr(vector_data, column, value) # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段) geometry = feature.get("geometry") if geometry: geometry_str = json.dumps(geometry, ensure_ascii=False) setattr(vector_data, 'geometry', geometry_str) elif 'geom' in properties: setattr(vector_data, 'geometry', properties['geom']) try: db.add(vector_data) imported_count += 1 except Exception as e: continue # 提交事务 try: db.commit() except Exception as e: db.rollback() raise ValueError(f"数据库操作失败: {str(e)}") return { "message": f"成功导入 {imported_count} 条记录", "imported_count": imported_count } except json.JSONDecodeError as e: raise ValueError(f"无效的JSON格式: {str(e)}") except Exception as e: db.rollback() raise ValueError(f"导入失败: {str(e)}") def export_vector_data(db: Session, vector_id: int): """导出指定ID的矢量数据为GeoJSON格式并保存到文件""" vector_data = get_vector_data(db, vector_id) return _export_vector_data_to_file([vector_data], f"export_{vector_id}") def export_vector_data_batch(db: Session, vector_ids: List[int]): """批量导出矢量数据为GeoJSON格式并保存到文件""" vector_data_list = get_vector_data_batch(db, vector_ids) return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}") def export_all_vector_data(db: Session, table_name: str = "surveydata"): """导出指定的矢量数据表为GeoJSON格式并保存到文件 Args: db (Session): 数据库会话 table_name (str): 要导出的矢量数据表名,默认为'surveydata' Returns: dict: 包含导出文件路径和临时目录的字典 """ # 使用动态表名查询 query = text(f"SELECT * FROM {table_name}") vector_data_list = db.execute(query).fetchall() # 如果没有数据,抛出异常 if not vector_data_list: raise HTTPException(status_code=404, detail=f"表 {table_name} 中没有矢量数据") # 调用现有的导出函数 return _export_vector_data_to_file(vector_data_list, f"export_{table_name}") def parse_geom_field(geom_value) -> dict: """ 解析 geom 字段为 GeoJSON 格式的 geometry。 如果解析失败,返回 None。 """ try: # 将 geom_value 转换为字符串 geom_str = str(geom_value) if geom_str and geom_str.startswith('0101000020'): # 去掉前两个字符(字节序标记),并转换为字节对象 binary_geom = bytes.fromhex(geom_str[2:]) # 解析字节序(前1个字节) byte_order = struct.unpack('B', binary_geom[:1])[0] endian = '<' if byte_order == 1 else '>' # 检查数据长度是否足够解析坐标 if len(binary_geom) >= 1 + 16: # 从数据末尾往前找 16 字节作为坐标 coord_bytes = binary_geom[-16:] x, y = struct.unpack(f'{endian}dd', coord_bytes) return { "type": "Point", "coordinates": [x, y] } else: print(f"Insufficient data in geom value: {geom_str}. Length: {len(binary_geom)}") except (ValueError, IndexError, struct.error) as e: print(f"Failed to parse geom field: {geom_str if 'geom_str' in locals() else geom_value}. Error: {e}") return None def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str): """将矢量数据列表导出为 GeoJSON 文件""" features = [] for vector_data in vector_data_list: # 获取表的所有列名 columns = [column.name for column in VectorData.__table__.columns] # 构建包含所有列数据的字典 data_dict = {} for column in columns: value = getattr(vector_data, column) # 如果值是字符串且可能是 JSON,尝试解析 if isinstance(value, str) and (value.startswith('{') or value.startswith('[')): try: value = json.loads(value) except: pass # 跳过 geom 字段,后续单独处理 if column != 'geom': data_dict[column] = value # 解析 geom 字段为 GeoJSON 格式的 geometry geometry = None if hasattr(vector_data, 'geom'): geom_value = getattr(vector_data, 'geom') geometry = parse_geom_field(geom_value) # 创建 Feature feature = { "type": "Feature", "properties": data_dict, "geometry": geometry } features.append(feature) # 创建 GeoJSON 对象 geojson = { "type": "FeatureCollection", "features": features } # 创建临时目录 temp_dir = tempfile.mkdtemp() # 生成文件名(使用时间戳避免重名) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"{base_filename}_{timestamp}.geojson" file_path = os.path.join(temp_dir, filename) # 保存到文件,使用自定义编码器处理 Decimal 类型 with open(file_path, "w", encoding="utf-8") as f: json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder) return { "message": "数据导出成功", "file_path": file_path, "temp_dir": temp_dir, # 返回临时目录路径,以便后续清理 "data": geojson }