123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329 |
- from fastapi import HTTPException, UploadFile
- from sqlalchemy.orm import Session
- from ..models.vector import VectorData
- import json
- import os
- import datetime # 统一使用模块导入,避免冲突
- from decimal import Decimal
- from typing import List
- import uuid
- import tempfile
- import struct
- from sqlalchemy.sql import text
- import binascii
- # 导入shapely库用于解析WKB
- try:
- from shapely import wkb
- from shapely.geometry import mapping
- SHAPELY_AVAILABLE = True
- except ImportError:
- SHAPELY_AVAILABLE = False
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that writes ``decimal.Decimal`` values as floats.

    Every other object the standard encoder cannot handle is delegated to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, Decimal):
            return super().default(obj)
        return float(obj)
def get_vector_data(db: Session, vector_id: int):
    """Load one ``VectorData`` row by id and return it as a plain dict.

    Column values are made JSON-friendly: ``Decimal`` becomes ``float``,
    ``datetime`` becomes an ISO-8601 string, and non-null values of columns
    whose type name starts with ``geometry`` are converted with ``str()``.

    Raises:
        HTTPException: 404 when no row has the given id.
    """
    record = db.query(VectorData).filter(VectorData.id == vector_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="矢量数据不存在")

    def to_plain(column, raw):
        if isinstance(raw, Decimal):
            return float(raw)
        if isinstance(raw, datetime.datetime):
            return raw.isoformat()
        # Geometry objects are not JSON-serializable; use their string form.
        if str(column.type).startswith('geometry') and raw is not None:
            return str(raw)
        return raw

    return {
        column.name: to_plain(column, getattr(record, column.name))
        for column in record.__table__.columns
    }
def get_vector_data_batch(db: Session, vector_ids: List[int]):
    """Load the ``VectorData`` rows whose ids are in ``vector_ids`` as plain dicts.

    Ids without a matching row are simply absent from the result, and rows
    come back in the order the query returns them. Values are converted the
    same way as in ``get_vector_data``: ``Decimal`` -> ``float``,
    ``datetime`` -> ISO-8601 string, non-null geometry values -> ``str()``.

    Raises:
        HTTPException: 404 when none of the ids match a row.
    """
    rows = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
    if not rows:
        raise HTTPException(status_code=404, detail="未找到指定的矢量数据")

    def serialize(row):
        serialized = {}
        for col in row.__table__.columns:
            raw = getattr(row, col.name)
            if isinstance(raw, Decimal):
                raw = float(raw)
            elif isinstance(raw, datetime.datetime):
                raw = raw.isoformat()
            elif str(col.type).startswith('geometry') and raw is not None:
                raw = str(raw)
            serialized[col.name] = raw
        return serialized

    return [serialize(row) for row in rows]
async def import_vector_data(file: UploadFile, db: Session) -> dict:
    """Import the features of an uploaded GeoJSON FeatureCollection into the database.

    Every ``Feature`` becomes one ``VectorData`` row:

    * each model column except ``id`` is copied from the feature's
      ``properties`` when present (dict/list values are stored as JSON text);
    * the feature's ``geometry`` is stored on ``geom`` as GeoJSON text, or,
      if the feature has no geometry, ``properties["geom"]`` is used instead.

    Entries that are not ``Feature`` objects are skipped. All rows are
    committed in a single transaction.

    Args:
        file: The uploaded GeoJSON file.
        db: SQLAlchemy session used for the inserts.

    Returns:
        dict with a ``message`` and ``imported_count`` (features added).

    Raises:
        ValueError: on invalid JSON, a document that is not a
            FeatureCollection, an empty ``features`` list, or a database
            error (the session is rolled back first).
    """
    try:
        content = await file.read()
        data = json.loads(content)
        if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
            raise ValueError("无效的GeoJSON格式")
        features = data.get("features", [])
        if not features:
            raise ValueError("GeoJSON文件中没有要素数据")

        # The primary key is never taken from the file.
        columns = [
            column.name
            for column in VectorData.__table__.columns
            if column.name != 'id'
        ]

        imported_count = 0
        for feature in features:
            if not isinstance(feature, dict) or feature.get("type") != "Feature":
                continue

            # GeoJSON allows "properties": null; .get(..., {}) would hand back
            # None and the membership test below would raise.
            properties = feature.get("properties")
            if not isinstance(properties, dict):
                properties = {}

            vector_data = VectorData()
            for column in columns:
                if column in properties:
                    value = properties[column]
                    if isinstance(value, (dict, list)):
                        value = json.dumps(value, ensure_ascii=False)
                    setattr(vector_data, column, value)

            # NOTE(review): geom receives GeoJSON text here; confirm the column
            # type / driver accepts that representation.
            geometry = feature.get("geometry")
            if geometry:
                setattr(vector_data, 'geom', json.dumps(geometry, ensure_ascii=False))
            elif 'geom' in properties:
                # Fall back to properties["geom"] and write it to the same
                # ``geom`` attribute as the branch above (this previously set a
                # ``geometry`` attribute instead).
                fallback = properties['geom']
                if isinstance(fallback, (dict, list)):
                    fallback = json.dumps(fallback, ensure_ascii=False)
                setattr(vector_data, 'geom', fallback)

            try:
                db.add(vector_data)
            except Exception as e:
                # Best effort: skip a feature the session refuses, but report it.
                print(f"跳过无法添加的要素: {e}")
                continue
            imported_count += 1

        try:
            db.commit()
        except Exception as e:
            db.rollback()
            raise ValueError(f"数据库操作失败: {str(e)}") from e

        return {
            "message": f"成功导入 {imported_count} 条记录",
            "imported_count": imported_count
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"无效的JSON格式: {str(e)}") from e
    except Exception as e:
        db.rollback()
        raise ValueError(f"导入失败: {str(e)}") from e
def export_vector_data(db: Session, vector_id: int):
    """Export one vector-data record to a GeoJSON file.

    The record is loaded with ``get_vector_data`` (HTTP 404 if it does not
    exist). It is then written by ``_export_vector_data_to_file`` under the
    base name ``export_<vector_id>``, with table name ``"surveydata"``.
    """
    single_record = get_vector_data(db, vector_id)
    records = [single_record]
    base_name = f"export_{vector_id}"
    return _export_vector_data_to_file(records, base_name, "surveydata")
def export_vector_data_batch(db: Session, vector_ids: List[int]):
    """Export the records with the given ids to a single GeoJSON file.

    Records are loaded with ``get_vector_data_batch`` (HTTP 404 if none
    match). The output base name is ``export_batch_`` followed by the ids
    joined with ``_``. When that id list is too long for a file name, it is
    replaced by the id count plus a short SHA-1 digest of the joined ids.
    """
    import hashlib

    vector_data_list = get_vector_data_batch(db, vector_ids)

    id_part = "_".join(map(str, vector_ids))
    # Common filesystems cap a single file name at 255 bytes. Embedding
    # hundreds of ids would make the later open() fail, so summarize them.
    max_id_part_len = 100
    if len(id_part) > max_id_part_len:
        digest = hashlib.sha1(id_part.encode("utf-8")).hexdigest()[:12]
        id_part = f"{len(vector_ids)}ids_{digest}"

    return _export_vector_data_to_file(vector_data_list, f"export_batch_{id_part}", "surveydata")
def export_all_vector_data(db: Session, table_name: str = "surveydata"):
    """Export every row of ``table_name`` to a GeoJSON file.

    Runs ``SELECT *`` on the table and passes the rows (as dicts keyed by
    column name) to ``_export_vector_data_to_file`` with base name
    ``export_<table_name>``.

    Raises:
        HTTPException: 500 if the query or the export fails.
    """
    try:
        # A table name cannot be a bind parameter, so it is spliced in as a
        # double-quoted identifier. Doubling any embedded quote (PostgreSQL
        # identifier escaping) stops a crafted name from breaking out of the
        # identifier and injecting SQL.
        quoted_table = '"' + table_name.replace('"', '""') + '"'
        result = db.execute(text(f"SELECT * FROM {quoted_table}"))
        # Result.keys() is the driver-independent way to get column names.
        columns = list(result.keys())
        vector_data_list = [dict(zip(columns, row)) for row in result.fetchall()]
        return _export_vector_data_to_file(vector_data_list, f"export_{table_name}", table_name)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"查询表{table_name}失败:{str(e)}"
        ) from e
def _query_postgis_geojson(sql: str, params):
    """Run a one-row PostGIS query whose ``geojson`` column holds GeoJSON text; return it parsed, or None."""
    from ..database import engine
    with engine.connect() as connection:
        row = connection.execute(text(sql), params).fetchone()
    if row and row.geojson:
        return json.loads(row.geojson)
    return None


def parse_geom_field(geom_value) -> dict:
    """Convert a stored geometry value into a GeoJSON geometry dict.

    Handled inputs (``bytes``/``memoryview`` are hex-encoded first; anything
    else is converted with ``str()``):

    * hex EWKB of a 2D point with SRID (prefix ``0101000020``): decoded
      locally with ``struct``;
    * WKT / EWKT: only detected and reported, and None is returned;
    * other hex WKB/EWKB: decoded with shapely when available, otherwise by
      PostGIS ``ST_AsGeoJSON(ST_GeomFromWKB(...))``;
    * anything else: cast to ``geometry`` by PostGIS and converted with
      ``ST_AsGeoJSON(ST_Force2D(...))``.

    Returns:
        A GeoJSON geometry dict, or None when the value is empty or cannot be
        parsed. Failures are reported with ``print``, never raised.
    """
    if geom_value is None:
        return None
    if isinstance(geom_value, (bytes, bytearray, memoryview)):
        # Raw WKB from the driver: hex-encode it so the parsers below apply.
        geom_str = bytes(geom_value).hex()
    else:
        geom_str = str(geom_value)
    if not geom_str:
        return None

    try:
        if geom_str.startswith('0101000020'):
            # EWKB point layout: byte order (1) + type (4) + SRID (4) + X, Y (8 each).
            # The leading "01" byte-order marker means little-endian.
            binary_geom = bytes.fromhex(geom_str)
            if len(binary_geom) >= 25:
                x, y = struct.unpack_from('<dd', binary_geom, 9)
                return {
                    "type": "Point",
                    "coordinates": [x, y]
                }
            print(f"数据长度不足: {geom_str}. 长度: {len(binary_geom)}")
        elif geom_str.startswith(('POINT', 'LINESTRING', 'POLYGON',
                                  'MULTIPOINT', 'MULTILINESTRING', 'MULTIPOLYGON')):
            print(f"检测到WKT格式几何数据: {geom_str[:30]}...")
            return None
        elif geom_str.startswith('SRID='):
            print(f"检测到EWKT格式几何数据: {geom_str[:30]}...")
            return None
        elif all(c in '0123456789ABCDEFabcdef' for c in geom_str):
            print(f"检测到十六进制WKB格式几何数据: {geom_str[:30]}...")
            if len(geom_str) % 2 != 0:
                geom_str = '0' + geom_str

            if SHAPELY_AVAILABLE:
                try:
                    shape = wkb.loads(binascii.unhexlify(geom_str))
                    return mapping(shape)
                except Exception as e:
                    print(f"Shapely解析WKB失败: {e}")

            try:
                # SQLAlchemy text() binds named ":param" placeholders; "$1" is not supported.
                return _query_postgis_geojson(
                    "SELECT ST_AsGeoJSON(ST_GeomFromWKB(decode(:wkb_hex, 'hex'))) AS geojson",
                    {"wkb_hex": geom_str},
                )
            except Exception as e:
                print(f"使用PostgreSQL解析WKB失败: {e}")
            return None
        else:
            try:
                result = _query_postgis_geojson(
                    "SELECT ST_AsGeoJSON(ST_Force2D(CAST(:geom AS geometry))) AS geojson",
                    {"geom": geom_value},
                )
                if result is not None:
                    return result
            except Exception as e:
                print(f"使用ST_AsGeoJSON转换几何数据失败: {e}")
            print(f"未识别的几何数据格式: {geom_str[:50]}...")

    except (ValueError, IndexError, struct.error) as e:
        print(f"解析几何字段失败: {geom_str}. 错误: {e}")

    return None
def _export_vector_data_to_file(vector_data_list, base_filename: str, table_name: str = "surveydata"):
    """Write vector-data rows to a new GeoJSON file inside a fresh temp directory.

    Each row must be a mapping. Its ``longitude``/``latitude`` values become a
    GeoJSON ``Point`` geometry ([lon, lat]). Every key/value pair, including
    the coordinates, is copied into ``properties`` after JSON-friendly
    conversion. Rows whose coordinates are not finite numbers are skipped and
    reported with ``print``.

    Args:
        vector_data_list: Iterable of row dicts (e.g. from ``get_vector_data``
            or a raw ``SELECT *``).
        base_filename: File-name prefix. ``_<YYYYmmdd_HHMMSS>.geojson`` is
            appended, and path separators are replaced with ``_``.
        table_name: Not used by this function; kept for call compatibility.

    Returns:
        dict with ``message``, ``file_path``, ``temp_dir`` (the caller is
        responsible for deleting it) and ``data`` (the FeatureCollection).

    Raises:
        Exception: any failure is printed together with the row being
            processed and re-raised. A temp directory created before the
            failure is removed first.
    """
    import math
    import shutil

    def to_coordinate(value):
        """Return *value* as a usable coordinate, or None if it is not a finite number."""
        # bool is an int subclass but is never a meaningful coordinate.
        if isinstance(value, bool):
            return None
        # Raw-SQL rows return NUMERIC columns as Decimal; an int/float-only
        # check would silently drop every such row.
        if isinstance(value, Decimal):
            if not value.is_finite():
                return None
            value = float(value)
        if isinstance(value, int):
            return value
        if isinstance(value, float) and math.isfinite(value):
            return value
        return None

    def to_property(value):
        """Convert one attribute value into something json.dump can write as valid JSON."""
        if isinstance(value, datetime.datetime):
            return value.strftime("%Y-%m-%d %H:%M:%S")
        # Checked after datetime, which is a date subclass.
        if isinstance(value, datetime.date):
            return value.isoformat()
        if isinstance(value, str):
            return None if value == 'nan' else value
        if isinstance(value, Decimal):
            if not value.is_finite():
                return None
            value = float(value)
        # NaN/Infinity are not valid JSON; treat them like the 'nan' string.
        if isinstance(value, float) and not math.isfinite(value):
            return None
        if isinstance(value, uuid.UUID):
            return str(value)
        return value

    row = None
    temp_dir = None
    try:
        features = []
        for row in vector_data_list:
            raw_lon = row.get('longitude')
            raw_lat = row.get('latitude')
            longitude = to_coordinate(raw_lon)
            latitude = to_coordinate(raw_lat)
            if longitude is None or latitude is None:
                print(f"跳过无效数据:经度={raw_lon}(类型{type(raw_lon)}),纬度={raw_lat}(类型{type(raw_lat)})")
                continue

            features.append({
                "type": "Feature",
                "geometry": {
                    "type": "Point",
                    "coordinates": [longitude, latitude]  # [longitude, latitude]
                },
                "properties": {key: to_property(value) for key, value in row.items()},
            })

        geojson_data = {
            "type": "FeatureCollection",
            "features": features
        }

        # Keep the file inside temp_dir even if the base name carries a separator.
        safe_base = base_filename
        for sep in (os.sep, os.altsep):
            if sep:
                safe_base = safe_base.replace(sep, "_")

        temp_dir = tempfile.mkdtemp()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_path = os.path.join(temp_dir, f"{safe_base}_{timestamp}.geojson")

        # DecimalEncoder covers any Decimal still nested inside property values.
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(geojson_data, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)

        return {
            "message": "数据导出成功",
            "file_path": file_path,
            "temp_dir": temp_dir,
            "data": geojson_data
        }

    except Exception as e:
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        error_data = row if row is not None else "未知数据"
        print(f"生成矢量数据时出错:{str(e)},出错数据:{error_data}")
        raise
|