from fastapi import HTTPException, UploadFile
from sqlalchemy.orm import Session
from ..models.vector import VectorData
import json
import os
import datetime  # always import the module (not the class) to avoid name clashes
import math
import shutil
from decimal import Decimal
from typing import List, Optional
import uuid
import tempfile
import struct
from sqlalchemy.sql import text
import binascii

# Shapely is optional: when installed it is used to decode WKB/WKT geometries.
try:
    from shapely import wkb, wkt
    from shapely.geometry import mapping
    SHAPELY_AVAILABLE = True
except ImportError:
    SHAPELY_AVAILABLE = False

# Longest id list embedded verbatim in a batch-export file name; longer lists
# are summarised so the name stays under common 255-byte filesystem limits.
_MAX_ID_FILENAME_PART = 100

# Leading keywords that mark a geometry string as WKT.
_WKT_PREFIXES = (
    "POINT",
    "LINESTRING",
    "POLYGON",
    "MULTIPOINT",
    "MULTILINESTRING",
    "MULTIPOLYGON",
)

_HEX_DIGITS = frozenset("0123456789ABCDEFabcdef")


class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that writes ``Decimal`` values as floats."""

    def default(self, obj):
        if isinstance(obj, Decimal):
            return float(obj)
        return super().default(obj)


def _serialize_row(vector_data) -> dict:
    """Convert one VectorData ORM row into a dict of JSON-friendly values.

    Decimal -> float, datetime -> ISO-8601 string, non-null values of columns
    whose type string starts with ``geometry`` -> ``str(value)``; everything
    else is passed through unchanged.
    """
    item = {}
    for column in vector_data.__table__.columns:
        value = getattr(vector_data, column.name)
        if isinstance(value, Decimal):
            value = float(value)
        elif isinstance(value, datetime.datetime):
            value = value.isoformat()
        elif value is not None and str(column.type).startswith('geometry'):
            # Geometry objects are exposed via their string representation
            value = str(value)
        item[column.name] = value
    return item


def get_vector_data(db: Session, vector_id: int):
    """Fetch one VectorData record by id as a JSON-friendly dict.

    Raises:
        HTTPException: 404 when no record has the given id.
    """
    vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
    if not vector_data:
        raise HTTPException(status_code=404, detail="矢量数据不存在")
    return _serialize_row(vector_data)


def get_vector_data_batch(db: Session, vector_ids: List[int]):
    """Fetch all VectorData records whose id is in *vector_ids*.

    Returns a list of JSON-friendly dicts (database order, not input order).

    Raises:
        HTTPException: 404 when none of the ids exist.
    """
    vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
    if not vector_data_list:
        raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
    return [_serialize_row(vector_data) for vector_data in vector_data_list]


async def import_vector_data(file: UploadFile, db: Session) -> dict:
    """Import an uploaded GeoJSON FeatureCollection into the VectorData table.

    For every entry of type ``Feature``, the ``properties`` keys that match a
    VectorData column (except ``id``) are copied onto a new row; dict/list
    values are stored as JSON strings. The Feature ``geometry`` is written to
    ``geom`` as a GeoJSON string; when a Feature has no geometry,
    ``properties["geom"]`` is used instead. Other entries are skipped.

    Returns:
        dict with a human-readable ``message`` and the ``imported_count``.

    Raises:
        ValueError: for malformed JSON/GeoJSON or any database failure; the
            session is rolled back first.
    """
    try:
        content = await file.read()
        data = json.loads(content)

        if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
            raise ValueError("无效的GeoJSON格式")

        features = data.get("features", [])
        if not features:
            raise ValueError("GeoJSON文件中没有要素数据")

        columns = [column.name for column in VectorData.__table__.columns]

        imported_count = 0
        for feature in features:
            if not isinstance(feature, dict) or feature.get("type") != "Feature":
                continue

            # GeoJSON permits "properties": null; treat anything non-dict as empty
            properties = feature.get("properties")
            if not isinstance(properties, dict):
                properties = {}

            vector_data = VectorData()
            for column in columns:
                if column == 'id':  # primary key is left for the database to assign
                    continue
                if column in properties:
                    value = properties[column]
                    if isinstance(value, (dict, list)):
                        value = json.dumps(value, ensure_ascii=False)
                    setattr(vector_data, column, value)

            # Prefer the Feature geometry; fall back to properties["geom"].
            # NOTE(review): geom receives a GeoJSON *string* here; confirm the
            # column type / database accepts that representation.
            geometry = feature.get("geometry")
            if geometry:
                setattr(vector_data, 'geom', json.dumps(geometry, ensure_ascii=False))
            elif 'geom' in properties:
                geom_value = properties['geom']
                if isinstance(geom_value, (dict, list)):
                    geom_value = json.dumps(geom_value, ensure_ascii=False)
                setattr(vector_data, 'geom', geom_value)

            try:
                db.add(vector_data)
            except Exception as e:
                # Best effort: skip a feature the session rejects, keep importing
                print(f"跳过无法添加的要素: {e}")
                continue
            imported_count += 1

        try:
            db.commit()
        except Exception as e:
            db.rollback()
            raise ValueError(f"数据库操作失败: {str(e)}") from e

        return {
            "message": f"成功导入 {imported_count} 条记录",
            "imported_count": imported_count,
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"无效的JSON格式: {str(e)}") from e
    except Exception as e:
        db.rollback()
        raise ValueError(f"导入失败: {str(e)}") from e


def export_vector_data(db: Session, vector_id: int):
    """Export one VectorData record as a GeoJSON file (see _export_vector_data_to_file)."""
    vector_data = get_vector_data(db, vector_id)
    return _export_vector_data_to_file([vector_data], f"export_{vector_id}", "surveydata")


def export_vector_data_batch(db: Session, vector_ids: List[int]):
    """Export several VectorData records as one GeoJSON file.

    The file name embeds the ids; when that would be too long, it embeds the
    number of ids instead.
    """
    vector_data_list = get_vector_data_batch(db, vector_ids)
    id_part = '_'.join(map(str, vector_ids))
    if len(id_part) > _MAX_ID_FILENAME_PART:
        id_part = f"{len(vector_ids)}_items"
    return _export_vector_data_to_file(vector_data_list, f"export_batch_{id_part}", "surveydata")


def _quote_identifier(name: str) -> str:
    """Double-quote *name* as a SQL identifier, doubling any embedded ``"``."""
    return '"' + name.replace('"', '""') + '"'


def export_all_vector_data(db: Session, table_name: str = "surveydata"):
    """Export every row of *table_name* as a GeoJSON file.

    Raises:
        HTTPException: 500 when the query or the export fails.
    """
    try:
        # Quoting prevents a crafted name from escaping the identifier; ':' is
        # backslash-escaped so text() does not read it as a bind parameter.
        quoted_table = _quote_identifier(table_name).replace(":", "\\:")
        result = db.execute(text(f"SELECT * FROM {quoted_table}"))
        columns = list(result.keys())
        vector_data_list = [dict(zip(columns, row)) for row in result.fetchall()]
        return _export_vector_data_to_file(vector_data_list, f"export_{table_name}", table_name)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"查询表{table_name}失败:{str(e)}"
        )


def _geojson_from_wkt(wkt_text: str) -> Optional[dict]:
    """Parse WKT with Shapely into a GeoJSON geometry dict; None if Shapely is missing or parsing fails."""
    if not SHAPELY_AVAILABLE:
        return None
    try:
        return mapping(wkt.loads(wkt_text))
    except Exception as e:
        print(f"Shapely解析WKT失败: {e}")
        return None


def _postgis_geojson(sql: str, geom_text: str, failure_message: str) -> Optional[dict]:
    """Run a one-row query that binds ``:geom`` and selects a ``geojson`` column.

    Returns the decoded GeoJSON dict, or None when the query fails or returns
    nothing; failures are printed prefixed with *failure_message*.
    """
    try:
        from ..database import engine
        with engine.connect() as connection:
            row = connection.execute(text(sql), {"geom": geom_text}).fetchone()
        if row and row.geojson:
            return json.loads(row.geojson)
    except Exception as e:
        print(f"{failure_message}: {e}")
    return None


def parse_geom_field(geom_value) -> Optional[dict]:
    """Convert a stored geometry value into a GeoJSON geometry dict.

    Strategies, tried in order on the value's text form:
      * hex EWKB 2D points with SRID (prefix ``0101000020``): decoded with struct;
      * WKT and EWKT (``SRID=n;WKT``): parsed with Shapely when installed;
      * any other hex string: decoded as (E)WKB with Shapely, falling back to
        PostGIS ``ST_GeomFromEWKB``;
      * anything else: cast to ``geometry`` by PostGIS.
    Raw bytes are converted to hex first.

    Returns:
        The GeoJSON geometry dict, or None when the value is empty or cannot
        be parsed (problems are printed, not raised).
    """
    if geom_value is None:
        return None
    if isinstance(geom_value, (bytes, bytearray, memoryview)):
        geom_str = bytes(geom_value).hex()
    else:
        geom_str = str(geom_value)
    if not geom_str:
        return None

    try:
        if geom_str.startswith('0101000020'):
            binary_geom = bytes.fromhex(geom_str)
            # Layout: byte order (1) + type w/ SRID flag (4) + SRID (4) + X, Y (8 each)
            if len(binary_geom) >= 25:
                endian = '<' if binary_geom[0] == 1 else '>'
                x, y = struct.unpack_from(f'{endian}dd', binary_geom, 9)
                return {
                    "type": "Point",
                    "coordinates": [x, y]
                }
            print(f"数据长度不足: {geom_str}. 长度: {len(binary_geom)}")
        elif geom_str.startswith(_WKT_PREFIXES):
            print(f"检测到WKT格式几何数据: {geom_str[:30]}...")
            return _geojson_from_wkt(geom_str)
        elif geom_str.startswith('SRID='):
            print(f"检测到EWKT格式几何数据: {geom_str[:30]}...")
            # EWKT is "SRID=<n>;<WKT>": parse the part after the first ';'
            return _geojson_from_wkt(geom_str.partition(';')[2])
        elif set(geom_str) <= _HEX_DIGITS:
            print(f"检测到十六进制WKB格式几何数据: {geom_str[:30]}...")
            if SHAPELY_AVAILABLE:
                try:
                    if len(geom_str) % 2 != 0:
                        geom_str = '0' + geom_str
                    return mapping(wkb.loads(binascii.unhexlify(geom_str)))
                except Exception as e:
                    print(f"Shapely解析WKB失败: {e}")
            return _postgis_geojson(
                "SELECT ST_AsGeoJSON(ST_GeomFromEWKB(decode(:geom, 'hex'))) AS geojson",
                geom_str,
                "使用PostgreSQL解析WKB失败",
            )
        else:
            geojson = _postgis_geojson(
                "SELECT ST_AsGeoJSON(ST_Force2D(CAST(:geom AS geometry))) AS geojson",
                geom_str,
                "使用ST_AsGeoJSON转换几何数据失败",
            )
            if geojson is not None:
                return geojson
            print(f"未识别的几何数据格式: {geom_str[:50]}...")
    except (ValueError, IndexError, struct.error) as e:
        print(f"解析几何字段失败: {geom_str}. 错误: {e}")
    return None


def _coerce_coordinate(value):
    """Return *value* as a usable finite number (Decimal -> float), or None.

    bool is rejected even though it subclasses int; NaN/inf are rejected
    because they cannot be written as valid JSON.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, Decimal):
        if not value.is_finite():
            return None
        value = float(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float) and math.isfinite(value):
        return value
    return None


def _to_json_property(value):
    """Convert one column value into a form json.dump can write as valid JSON."""
    if isinstance(value, datetime.datetime):  # checked before date: datetime subclasses date
        return value.strftime("%Y-%m-%d %H:%M:%S")
    if isinstance(value, (datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, str):
        return None if value == 'nan' else value
    if isinstance(value, Decimal):
        return float(value) if value.is_finite() else None
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if isinstance(value, uuid.UUID):
        return str(value)
    if isinstance(value, (bytes, bytearray, memoryview)):
        return bytes(value).hex()
    return value


def _export_vector_data_to_file(vector_data_list, base_filename: str, table_name: str = "surveydata"):
    """Write row dicts as a GeoJSON FeatureCollection into a fresh temp directory.

    Each row becomes a Point feature built from its ``longitude``/``latitude``
    values; rows without two finite numeric coordinates are skipped (and
    printed). All row values are kept as feature properties, converted to
    JSON-safe forms. *table_name* is accepted for interface compatibility but
    is currently unused.

    Returns:
        dict with ``message``, ``file_path``, ``temp_dir`` (the caller owns
        and must clean it up) and the GeoJSON ``data``.
    """
    data = None
    try:
        features = []
        for data in vector_data_list:
            longitude = data.get('longitude')
            latitude = data.get('latitude')
            lon = _coerce_coordinate(longitude)
            lat = _coerce_coordinate(latitude)
            if lon is None or lat is None:
                print(f"跳过无效数据:经度={longitude}(类型{type(longitude)}),纬度={latitude}(类型{type(latitude)})")
                continue

            features.append({
                "type": "Feature",
                "geometry": {
                    "type": "Point",
                    "coordinates": [lon, lat]  # GeoJSON order: [longitude, latitude]
                },
                "properties": {key: _to_json_property(value) for key, value in data.items()},
            })

        geojson_data = {
            "type": "FeatureCollection",
            "features": features
        }

        temp_dir = tempfile.mkdtemp()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_path = os.path.join(temp_dir, f"{base_filename}_{timestamp}.geojson")
        try:
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(geojson_data, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
        except Exception:
            # Don't leak the temp directory when the write fails
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise

        return {
            "message": "数据导出成功",
            "file_path": file_path,
            "temp_dir": temp_dir,
            "data": geojson_data
        }
    except Exception as e:
        error_data = data if data is not None else "未知数据"
        print(f"生成矢量数据时出错:{str(e)},出错数据:{error_data}")
        raise