vector_service.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # 矢量数据服务
  2. from fastapi import HTTPException, UploadFile
  3. from sqlalchemy.orm import Session
  4. from ..models.vector import VectorData
  5. import json
  6. import os
  7. from datetime import datetime
  8. from decimal import Decimal
  9. from typing import List
  10. import uuid
  11. import tempfile
  12. class DecimalEncoder(json.JSONEncoder):
  13. def default(self, obj):
  14. if isinstance(obj, Decimal):
  15. return float(obj)
  16. return super(DecimalEncoder, self).default(obj)
  17. def get_vector_data(db: Session, vector_id: int):
  18. """通过ID获取一条矢量数据记录"""
  19. vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
  20. if not vector_data:
  21. raise HTTPException(status_code=404, detail="矢量数据不存在")
  22. return vector_data
  23. def get_vector_data_batch(db: Session, vector_ids: List[int]):
  24. """批量获取矢量数据记录"""
  25. vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
  26. if not vector_data_list:
  27. raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
  28. return vector_data_list
  29. async def import_vector_data(file: UploadFile, db: Session) -> dict:
  30. """导入GeoJSON文件到数据库"""
  31. try:
  32. # 读取文件内容
  33. content = await file.read()
  34. data = json.loads(content)
  35. # 验证GeoJSON格式
  36. if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
  37. raise ValueError("无效的GeoJSON格式")
  38. features = data.get("features", [])
  39. if not features:
  40. raise ValueError("GeoJSON文件中没有要素数据")
  41. # 获取表的所有列名
  42. columns = [column.name for column in VectorData.__table__.columns]
  43. # 导入每个要素
  44. imported_count = 0
  45. for feature in features:
  46. if not isinstance(feature, dict) or feature.get("type") != "Feature":
  47. continue
  48. # 获取属性
  49. properties = feature.get("properties", {})
  50. # 创建新记录
  51. vector_data = VectorData()
  52. # 设置每个字段的值(除了id)
  53. for column in columns:
  54. if column == 'id': # 跳过id字段
  55. continue
  56. if column in properties:
  57. value = properties[column]
  58. # 如果值是字典或列表,转换为JSON字符串
  59. if isinstance(value, (dict, list)):
  60. value = json.dumps(value, ensure_ascii=False)
  61. setattr(vector_data, column, value)
  62. # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
  63. geometry = feature.get("geometry")
  64. if geometry:
  65. geometry_str = json.dumps(geometry, ensure_ascii=False)
  66. setattr(vector_data, 'geometry', geometry_str)
  67. elif 'geom' in properties:
  68. setattr(vector_data, 'geometry', properties['geom'])
  69. try:
  70. db.add(vector_data)
  71. imported_count += 1
  72. except Exception as e:
  73. continue
  74. # 提交事务
  75. try:
  76. db.commit()
  77. except Exception as e:
  78. db.rollback()
  79. raise ValueError(f"数据库操作失败: {str(e)}")
  80. return {
  81. "message": f"成功导入 {imported_count} 条记录",
  82. "imported_count": imported_count
  83. }
  84. except json.JSONDecodeError as e:
  85. raise ValueError(f"无效的JSON格式: {str(e)}")
  86. except Exception as e:
  87. db.rollback()
  88. raise ValueError(f"导入失败: {str(e)}")
  89. def export_vector_data(db: Session, vector_id: int):
  90. """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
  91. vector_data = get_vector_data(db, vector_id)
  92. return _export_vector_data_to_file([vector_data], f"export_{vector_id}")
  93. def export_vector_data_batch(db: Session, vector_ids: List[int]):
  94. """批量导出矢量数据为GeoJSON格式并保存到文件"""
  95. vector_data_list = get_vector_data_batch(db, vector_ids)
  96. return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}")
  97. def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str):
  98. """将矢量数据列表导出为GeoJSON文件"""
  99. features = []
  100. for vector_data in vector_data_list:
  101. # 获取表的所有列名
  102. columns = [column.name for column in VectorData.__table__.columns]
  103. # 构建包含所有列数据的字典
  104. data_dict = {}
  105. for column in columns:
  106. value = getattr(vector_data, column)
  107. # 如果值是字符串且可能是JSON,尝试解析
  108. if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
  109. try:
  110. value = json.loads(value)
  111. except:
  112. pass
  113. data_dict[column] = value
  114. # 创建Feature
  115. feature = {
  116. "type": "Feature",
  117. "properties": data_dict,
  118. "geometry": json.loads(vector_data.geometry) if hasattr(vector_data, 'geometry') else None
  119. }
  120. features.append(feature)
  121. # 创建GeoJSON对象
  122. geojson = {
  123. "type": "FeatureCollection",
  124. "features": features
  125. }
  126. # 创建临时目录
  127. temp_dir = tempfile.mkdtemp()
  128. # 生成文件名(使用时间戳避免重名)
  129. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  130. filename = f"{base_filename}_{timestamp}.geojson"
  131. file_path = os.path.join(temp_dir, filename)
  132. # 保存到文件,使用自定义编码器处理Decimal类型
  133. with open(file_path, "w", encoding="utf-8") as f:
  134. json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
  135. return {
  136. "message": "数据导出成功",
  137. "file_path": file_path,
  138. "temp_dir": temp_dir, # 返回临时目录路径,以便后续清理
  139. "data": geojson
  140. }