123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- from fastapi import HTTPException, UploadFile
- from sqlalchemy.orm import Session
- from ..models.vector import VectorData
- import json
- import os
- from datetime import datetime
- from decimal import Decimal
- from typing import List
- import uuid
- import tempfile
- import struct
- from sqlalchemy.sql import text
- import binascii
- # 导入shapely库用于解析WKB
- try:
- from shapely import wkb
- from shapely.geometry import mapping
- SHAPELY_AVAILABLE = True
- except ImportError:
- SHAPELY_AVAILABLE = False
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``Decimal`` values as floats."""

    def default(self, obj):
        """Return ``float(obj)`` for a ``Decimal``; defer to the base encoder otherwise."""
        if not isinstance(obj, Decimal):
            # The base implementation raises TypeError for unsupported types.
            return super().default(obj)
        return float(obj)
def get_vector_data(db: Session, vector_id: int):
    """Fetch one vector-data record by ID and return it as a plain dict.

    Args:
        db: Database session used for the query.
        vector_id: Value matched against ``VectorData.id``.

    Returns:
        dict mapping every table column name to its value.  ``Decimal``
        values become floats, ``datetime`` values become ISO-8601 strings and
        non-null values of columns whose type name starts with ``geometry``
        become their ``str()`` representation.

    Raises:
        HTTPException: 404 when no record matches ``vector_id``.
    """
    record = db.query(VectorData).filter(VectorData.id == vector_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="矢量数据不存在")

    def _plain_value(column):
        # Convert types that are awkward to serialise into plain Python values.
        raw = getattr(record, column.name)
        if isinstance(raw, Decimal):
            return float(raw)
        if isinstance(raw, datetime):
            return raw.isoformat()
        if str(column.type).startswith('geometry') and raw is not None:
            # Geometry values are exposed through their string representation.
            return str(raw)
        return raw

    return {column.name: _plain_value(column) for column in record.__table__.columns}
def get_vector_data_batch(db: Session, vector_ids: List[int]):
    """Fetch several vector-data records by ID and return them as plain dicts.

    Args:
        db: Database session used for the query.
        vector_ids: Values matched against ``VectorData.id``.

    Returns:
        list with one dict per record found, mapping column names to values.
        ``Decimal`` values become floats, ``datetime`` values become ISO-8601
        strings and non-null values of columns whose type name starts with
        ``geometry`` become their ``str()`` representation.

    Raises:
        HTTPException: 404 when none of the IDs matches a record.
    """
    records = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
    if not records:
        raise HTTPException(status_code=404, detail="未找到指定的矢量数据")

    def _plain_value(record, column):
        # Convert types that are awkward to serialise into plain Python values.
        raw = getattr(record, column.name)
        if isinstance(raw, Decimal):
            return float(raw)
        if isinstance(raw, datetime):
            return raw.isoformat()
        if str(column.type).startswith('geometry') and raw is not None:
            # Geometry values are exposed through their string representation.
            return str(raw)
        return raw

    return [
        {column.name: _plain_value(record, column) for column in record.__table__.columns}
        for record in records
    ]
async def import_vector_data(file: UploadFile, db: Session) -> dict:
    """Import an uploaded GeoJSON FeatureCollection into the database.

    Each entry of ``features`` whose ``type`` is ``"Feature"`` becomes one
    ``VectorData`` row.  Properties whose names match a table column (except
    ``id``) are copied onto the row, with dict/list values stored as JSON
    text.  The feature's ``geometry`` is stored as a JSON string in the
    ``geometry`` attribute; when it is missing, a ``geom`` property is used
    instead.

    Args:
        file: Uploaded file whose content is a GeoJSON document.
        db: Database session used to add and commit the new rows.

    Returns:
        dict with a human-readable ``message`` and the ``imported_count``.

    Raises:
        ValueError: If the content is not valid JSON, is not a non-empty
            FeatureCollection, or the database operation fails.  The session
            is rolled back for every failure other than a JSON syntax error.
    """
    try:
        content = await file.read()
        data = json.loads(content)

        # Validate the top-level GeoJSON structure.
        if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
            raise ValueError("无效的GeoJSON格式")
        features = data.get("features", [])
        if not features:
            raise ValueError("GeoJSON文件中没有要素数据")

        # Columns that may be filled from feature properties; id is never copied.
        writable_columns = [
            column.name
            for column in VectorData.__table__.columns
            if column.name != 'id'
        ]

        imported_count = 0
        for feature in features:
            if not isinstance(feature, dict) or feature.get("type") != "Feature":
                continue

            # GeoJSON allows "properties": null; treat anything that is not an
            # object as "no properties" instead of failing the whole import.
            properties = feature.get("properties")
            if not isinstance(properties, dict):
                properties = {}

            vector_data = VectorData()
            for column in writable_columns:
                if column not in properties:
                    continue
                value = properties[column]
                # Nested structures are stored as JSON text.
                if isinstance(value, (dict, list)):
                    value = json.dumps(value, ensure_ascii=False)
                setattr(vector_data, column, value)

            # Prefer the feature's own geometry; fall back to a "geom" property.
            geometry = feature.get("geometry")
            if geometry:
                setattr(vector_data, 'geometry', json.dumps(geometry, ensure_ascii=False))
            elif 'geom' in properties:
                setattr(vector_data, 'geometry', properties['geom'])

            try:
                db.add(vector_data)
            except Exception as e:
                # Best effort: skip this feature, but report why it was skipped.
                print(f"跳过无法添加的要素: {e}")
                continue
            imported_count += 1

        try:
            db.commit()
        except Exception as e:
            db.rollback()
            raise ValueError(f"数据库操作失败: {str(e)}") from e

        return {
            "message": f"成功导入 {imported_count} 条记录",
            "imported_count": imported_count
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"无效的JSON格式: {str(e)}") from e
    except Exception as e:
        # Validation and database errors raised above are also wrapped here.
        db.rollback()
        raise ValueError(f"导入失败: {str(e)}") from e
def export_vector_data(db: Session, vector_id: int):
    """Export the vector-data record with the given ID to a GeoJSON file.

    Args:
        db: Database session used to load the record.
        vector_id: ID of the record to export.

    Returns:
        The result of ``_export_vector_data_to_file`` for the single record.
    """
    record = get_vector_data(db, vector_id)
    base_filename = f"export_{vector_id}"
    return _export_vector_data_to_file([record], base_filename, "surveydata")
def export_vector_data_batch(db: Session, vector_ids: List[int]):
    """Export several vector-data records to a single GeoJSON file.

    Args:
        db: Database session used to load the records.
        vector_ids: IDs of the records to export.

    Returns:
        The result of ``_export_vector_data_to_file`` for the loaded records.
    """
    vector_data_list = get_vector_data_batch(db, vector_ids)

    # The IDs are part of the file name.  Joining an unbounded number of them
    # can exceed the file-name length limit of the file system and make the
    # export fail when the file is opened, so long lists are summarised by
    # their size instead.
    max_id_part_length = 100
    id_part = '_'.join(map(str, vector_ids))
    if len(id_part) > max_id_part_length:
        id_part = f"{len(vector_ids)}_records"

    return _export_vector_data_to_file(vector_data_list, f"export_batch_{id_part}", "surveydata")
def export_all_vector_data(db: Session, table_name: str = "surveydata"):
    """Export every row of a vector-data table to a GeoJSON file.

    Args:
        db (Session): Database session used to run the query.
        table_name (str): Name of the table to export, optionally qualified
            as ``schema.table``.  Defaults to ``'surveydata'``.

    Returns:
        dict: The result of ``_export_vector_data_to_file`` for all rows.

    Raises:
        HTTPException: 400 when ``table_name`` is not a plain identifier,
            404 when the table contains no rows.
    """
    # Table names cannot be sent as bound parameters, so the name ends up in
    # the SQL text.  Only accept plain (optionally schema-qualified)
    # identifiers to keep arbitrary SQL out of the statement.
    name_parts = table_name.split(".") if isinstance(table_name, str) else []
    if not 1 <= len(name_parts) <= 2 or not all(part.isidentifier() for part in name_parts):
        raise HTTPException(status_code=400, detail=f"无效的表名: {table_name}")

    query = text(f"SELECT * FROM {table_name}")
    vector_data_list = db.execute(query).fetchall()

    if not vector_data_list:
        raise HTTPException(status_code=404, detail=f"表 {table_name} 中没有矢量数据")

    return _export_vector_data_to_file(vector_data_list, f"export_{table_name}", table_name)
# WKT geometry type keywords recognised (but not parsed) by parse_geom_field.
_WKT_PREFIXES = (
    'POINT',
    'LINESTRING',
    'POLYGON',
    'MULTIPOINT',
    'MULTILINESTRING',
    'MULTIPOLYGON',
)


def _parse_srid_point_hex(geom_str):
    """Decode a hex string starting with ``0101000020`` into a GeoJSON Point.

    That prefix is a little-endian byte-order byte (``01``) followed by the
    type word ``01000020``; the coordinates are read as two little-endian
    doubles after the 4 bytes that follow the type word.  Returns ``None``
    when the data is too short.  ``bytes.fromhex`` raises ``ValueError`` for
    malformed hex.
    """
    raw = bytes.fromhex(geom_str)
    # 1 byte-order byte + 4 type bytes + 4 SRID bytes precede the coordinates.
    header_length = 9
    if len(raw) < header_length + 16:
        print(f"数据长度不足: {geom_str}. 长度: {len(raw)}")
        return None
    x, y = struct.unpack_from('<dd', raw, header_length)
    return {
        "type": "Point",
        "coordinates": [x, y]
    }


def _geojson_via_postgis(sql, params):
    """Run a one-row query exposing a ``geojson`` column and return it parsed, or ``None``."""
    from ..database import engine

    with engine.connect() as connection:
        row = connection.execute(text(sql), params).fetchone()
    if row and row.geojson:
        return json.loads(row.geojson)
    return None


def parse_geom_field(geom_value) -> dict:
    """Parse a geometry column value into a GeoJSON geometry dict.

    The string form of ``geom_value`` decides how it is handled:

    * hex starting with ``0101000020``: decoded locally as a point;
    * WKT / EWKT text: recognised and reported, but not converted;
    * any other hex string: parsed as WKB with shapely when available,
      otherwise (or if shapely fails) by the database;
    * anything else: handed to the database for conversion.

    Args:
        geom_value: The raw geometry value read from the database.

    Returns:
        A GeoJSON geometry dict, or ``None`` when the value cannot be parsed.
        Parsing problems are printed, never raised.
    """
    geom_str = None
    try:
        geom_str = str(geom_value)

        # Fast path that needs neither shapely nor the database.
        if geom_str.startswith('0101000020'):
            return _parse_srid_point_hex(geom_str)

        # WKT is only detected here; converting it is left to the database
        # (e.g. ST_AsGeoJSON) at query time.
        if geom_str.startswith(_WKT_PREFIXES):
            print(f"检测到WKT格式几何数据: {geom_str[:30]}...")
            return None

        # EWKT, e.g. "SRID=4326;POINT(...)".
        if geom_str.startswith('SRID='):
            print(f"检测到EWKT格式几何数据: {geom_str[:30]}...")
            return None

        # Generic hex-encoded WKB.
        if geom_str and all(c in '0123456789ABCDEFabcdef' for c in geom_str):
            print(f"检测到十六进制WKB格式几何数据: {geom_str[:30]}...")

            # Preferred: parse locally with shapely.
            if SHAPELY_AVAILABLE:
                try:
                    # An odd number of digits cannot be unhexlified; pad in front.
                    if len(geom_str) % 2 != 0:
                        geom_str = '0' + geom_str
                    shape = wkb.loads(binascii.unhexlify(geom_str))
                    return mapping(shape)
                except Exception as e:
                    print(f"Shapely解析WKB失败: {e}")

            # Fallback: let the database decode the WKB.  A named bind
            # parameter is used because text() does not bind "$1".
            try:
                return _geojson_via_postgis(
                    "SELECT ST_AsGeoJSON(ST_GeomFromWKB(decode(:wkb_hex, 'hex'))) AS geojson",
                    {"wkb_hex": geom_str},
                )
            except Exception as e:
                print(f"使用PostgreSQL解析WKB失败: {e}")
            return None

        # Unknown representation: ask the database to interpret the raw
        # value.  CAST() is used instead of "::" so the colons of the cast
        # cannot be confused with the bind parameter.
        try:
            geometry = _geojson_via_postgis(
                "SELECT ST_AsGeoJSON(ST_Force2D(CAST(:geom AS geometry))) AS geojson",
                {"geom": geom_value},
            )
            if geometry is not None:
                return geometry
        except Exception as e:
            print(f"使用ST_AsGeoJSON转换几何数据失败: {e}")

        print(f"未识别的几何数据格式: {geom_str[:50]}...")

    except (ValueError, IndexError, struct.error) as e:
        print(f"解析几何字段失败: {geom_str if geom_str is not None else geom_value}. 错误: {e}")

    return None
def _decode_json_text(value):
    """Return the parsed JSON for a string that looks like a JSON object or array.

    Any other value, and any string that fails to parse, is returned unchanged.
    """
    if isinstance(value, str) and value.startswith(('{', '[')):
        try:
            return json.loads(value)
        except (ValueError, RecursionError):
            return value
    return value


def _export_vector_data_to_file(vector_data_list, base_filename: str, table_name: str = "surveydata"):
    """Write a list of vector-data records to a GeoJSON file in a new temp directory.

    Args:
        vector_data_list: Records to export.  Objects with a ``__table__``
            attribute are read through the columns of the ORM model selected
            by ``table_name``; anything else is read as a mapping (dicts and
            SQLAlchemy result rows).
        base_filename: File name prefix; a timestamp and ``.geojson`` are appended.
        table_name: Selects the ORM model and geometry field name.  Unknown
            names fall back to the ``surveydata`` model.  Defaults to
            ``"surveydata"``.

    Returns:
        dict with ``message``, ``file_path``, ``temp_dir`` (so the caller can
        clean it up later) and ``data`` (the FeatureCollection that was written).
    """
    import shutil
    from datetime import date

    from ..models.orm_models import UnitCeil, Surveydatum, FiftyThousandSurveyDatum

    # Table name -> (ORM model, name of its geometry field).
    model_mapping = {
        "surveydata": (Surveydatum, "geom"),
        "unit_ceil": (UnitCeil, "geom"),
        "fifty_thousand_survey_data": (FiftyThousandSurveyDatum, "geom")
    }
    model_class, geom_field = model_mapping.get(table_name, (Surveydatum, "geom"))

    # The first element decides how every record is read.
    is_orm_object = len(vector_data_list) == 0 or hasattr(vector_data_list[0], '__table__')
    # The column list does not change between records, so build it once.
    orm_columns = (
        [column.name for column in model_class.__table__.columns]
        if is_orm_object else []
    )

    def to_property(value):
        # date / datetime values are not JSON serialisable; emit ISO-8601 text.
        if isinstance(value, date):
            return value.isoformat()
        return _decode_json_text(value)

    features = []
    for vector_data in vector_data_list:
        if is_orm_object:
            # The geometry field is excluded here and handled separately below.
            properties = {
                name: to_property(getattr(vector_data, name))
                for name in orm_columns
                if name != geom_field and hasattr(vector_data, name)
            }
            geom_value = getattr(vector_data, geom_field, None)
        else:
            # Newer SQLAlchemy rows expose key access through ``_mapping``;
            # dicts and older row objects are already mappings themselves.
            record = getattr(vector_data, "_mapping", vector_data)
            keys = list(record.keys())
            properties = {
                key: to_property(record[key])
                for key in keys
                if key != geom_field
            }
            geom_value = record[geom_field] if geom_field in keys else None

        geometry = parse_geom_field(geom_value) if geom_value else None
        features.append({
            "type": "Feature",
            "properties": properties,
            "geometry": geometry
        })

    geojson = {
        "type": "FeatureCollection",
        "features": features
    }

    temp_dir = tempfile.mkdtemp()
    # The timestamp keeps file names from successive exports distinct.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{base_filename}_{timestamp}.geojson"
    file_path = os.path.join(temp_dir, filename)

    try:
        # DecimalEncoder turns Decimal values into floats.
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
    except Exception:
        # Do not leave an orphaned temp directory behind when writing fails.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    return {
        "message": "数据导出成功",
        "file_path": file_path,
        "temp_dir": temp_dir,  # returned so the caller can clean it up later
        "data": geojson
    }
-
|