123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- # 矢量数据服务
- from fastapi import HTTPException, UploadFile
- from sqlalchemy.orm import Session
- from ..models.vector import VectorData
- import json
- import os
- from datetime import datetime
- from decimal import Decimal
- from typing import List
- import uuid
- import tempfile
- class DecimalEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, Decimal):
- return float(obj)
- return super(DecimalEncoder, self).default(obj)
- def get_vector_data(db: Session, vector_id: int):
- """通过ID获取一条矢量数据记录"""
- vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
- if not vector_data:
- raise HTTPException(status_code=404, detail="矢量数据不存在")
-
- # 手动构建返回字典
- result = {}
- for column in vector_data.__table__.columns:
- value = getattr(vector_data, column.name)
- # 处理特殊类型
- if isinstance(value, Decimal):
- value = float(value)
- elif isinstance(value, datetime):
- value = value.isoformat()
- elif str(column.type).startswith('geometry'):
- # 如果是几何类型,直接使用字符串表示
- if value is not None:
- value = str(value)
- result[column.name] = value
-
- return result
- def get_vector_data_batch(db: Session, vector_ids: List[int]):
- """批量获取矢量数据记录"""
- vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
- if not vector_data_list:
- raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
-
- result = []
- for vector_data in vector_data_list:
- item = {}
- for column in vector_data.__table__.columns:
- value = getattr(vector_data, column.name)
- # 处理特殊类型
- if isinstance(value, Decimal):
- value = float(value)
- elif isinstance(value, datetime):
- value = value.isoformat()
- elif str(column.type).startswith('geometry'):
- # 如果是几何类型,直接使用字符串表示
- if value is not None:
- value = str(value)
- item[column.name] = value
- result.append(item)
-
- return result
- async def import_vector_data(file: UploadFile, db: Session) -> dict:
- """导入GeoJSON文件到数据库"""
- try:
- # 读取文件内容
- content = await file.read()
- data = json.loads(content)
-
- # 验证GeoJSON格式
- if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
- raise ValueError("无效的GeoJSON格式")
-
- features = data.get("features", [])
- if not features:
- raise ValueError("GeoJSON文件中没有要素数据")
-
- # 获取表的所有列名
- columns = [column.name for column in VectorData.__table__.columns]
-
- # 导入每个要素
- imported_count = 0
- for feature in features:
- if not isinstance(feature, dict) or feature.get("type") != "Feature":
- continue
-
- # 获取属性
- properties = feature.get("properties", {})
-
- # 创建新记录
- vector_data = VectorData()
-
- # 设置每个字段的值(除了id)
- for column in columns:
- if column == 'id': # 跳过id字段
- continue
- if column in properties:
- value = properties[column]
- # 如果值是字典或列表,转换为JSON字符串
- if isinstance(value, (dict, list)):
- value = json.dumps(value, ensure_ascii=False)
- setattr(vector_data, column, value)
-
- # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
- geometry = feature.get("geometry")
- if geometry:
- geometry_str = json.dumps(geometry, ensure_ascii=False)
- setattr(vector_data, 'geometry', geometry_str)
- elif 'geom' in properties:
- setattr(vector_data, 'geometry', properties['geom'])
-
- try:
- db.add(vector_data)
- imported_count += 1
- except Exception as e:
- continue
-
- # 提交事务
- try:
- db.commit()
- except Exception as e:
- db.rollback()
- raise ValueError(f"数据库操作失败: {str(e)}")
-
- return {
- "message": f"成功导入 {imported_count} 条记录",
- "imported_count": imported_count
- }
-
- except json.JSONDecodeError as e:
- raise ValueError(f"无效的JSON格式: {str(e)}")
- except Exception as e:
- db.rollback()
- raise ValueError(f"导入失败: {str(e)}")
- def export_vector_data(db: Session, vector_id: int):
- """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
- vector_data = get_vector_data(db, vector_id)
- return _export_vector_data_to_file([vector_data], f"export_{vector_id}")
- def export_vector_data_batch(db: Session, vector_ids: List[int]):
- """批量导出矢量数据为GeoJSON格式并保存到文件"""
- vector_data_list = get_vector_data_batch(db, vector_ids)
- return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}")
- def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str):
- """将矢量数据列表导出为GeoJSON文件"""
- features = []
-
- for vector_data in vector_data_list:
- # 获取表的所有列名
- columns = [column.name for column in VectorData.__table__.columns]
-
- # 构建包含所有列数据的字典
- data_dict = {}
- for column in columns:
- value = getattr(vector_data, column)
- # 如果值是字符串且可能是JSON,尝试解析
- if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
- try:
- value = json.loads(value)
- except:
- pass
- data_dict[column] = value
-
- # 创建Feature
- feature = {
- "type": "Feature",
- "properties": data_dict,
- "geometry": json.loads(vector_data.geometry) if hasattr(vector_data, 'geometry') else None
- }
- features.append(feature)
-
- # 创建GeoJSON对象
- geojson = {
- "type": "FeatureCollection",
- "features": features
- }
-
- # 创建临时目录
- temp_dir = tempfile.mkdtemp()
-
- # 生成文件名(使用时间戳避免重名)
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- filename = f"{base_filename}_{timestamp}.geojson"
- file_path = os.path.join(temp_dir, filename)
-
- # 保存到文件,使用自定义编码器处理Decimal类型
- with open(file_path, "w", encoding="utf-8") as f:
- json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
-
- return {
- "message": "数据导出成功",
- "file_path": file_path,
- "temp_dir": temp_dir, # 返回临时目录路径,以便后续清理
- "data": geojson
- }
|