vector_service.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from fastapi import HTTPException, UploadFile
  2. from sqlalchemy.orm import Session
  3. from ..models.vector import VectorData
  4. import json
  5. import os
  6. from datetime import datetime
  7. from decimal import Decimal
  8. from typing import List
  9. import uuid
  10. import tempfile
  11. import struct
  12. class DecimalEncoder(json.JSONEncoder):
  13. def default(self, obj):
  14. if isinstance(obj, Decimal):
  15. return float(obj)
  16. return super(DecimalEncoder, self).default(obj)
  17. def get_vector_data(db: Session, vector_id: int):
  18. """通过ID获取一条矢量数据记录"""
  19. vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
  20. if not vector_data:
  21. raise HTTPException(status_code=404, detail="矢量数据不存在")
  22. return vector_data
  23. def get_vector_data_batch(db: Session, vector_ids: List[int]):
  24. """批量获取矢量数据记录"""
  25. vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
  26. if not vector_data_list:
  27. raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
  28. return vector_data_list
  29. async def import_vector_data(file: UploadFile, db: Session) -> dict:
  30. """导入GeoJSON文件到数据库"""
  31. try:
  32. # 读取文件内容
  33. content = await file.read()
  34. data = json.loads(content)
  35. # 验证GeoJSON格式
  36. if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
  37. raise ValueError("无效的GeoJSON格式")
  38. features = data.get("features", [])
  39. if not features:
  40. raise ValueError("GeoJSON文件中没有要素数据")
  41. # 获取表的所有列名
  42. columns = [column.name for column in VectorData.__table__.columns]
  43. # 导入每个要素
  44. imported_count = 0
  45. for feature in features:
  46. if not isinstance(feature, dict) or feature.get("type") != "Feature":
  47. continue
  48. # 获取属性
  49. properties = feature.get("properties", {})
  50. # 创建新记录
  51. vector_data = VectorData()
  52. # 设置每个字段的值(除了id)
  53. for column in columns:
  54. if column == 'id': # 跳过id字段
  55. continue
  56. if column in properties:
  57. value = properties[column]
  58. # 如果值是字典或列表,转换为JSON字符串
  59. if isinstance(value, (dict, list)):
  60. value = json.dumps(value, ensure_ascii=False)
  61. setattr(vector_data, column, value)
  62. # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
  63. geometry = feature.get("geometry")
  64. if geometry:
  65. geometry_str = json.dumps(geometry, ensure_ascii=False)
  66. setattr(vector_data, 'geometry', geometry_str)
  67. elif 'geom' in properties:
  68. setattr(vector_data, 'geometry', properties['geom'])
  69. try:
  70. db.add(vector_data)
  71. imported_count += 1
  72. except Exception as e:
  73. continue
  74. # 提交事务
  75. try:
  76. db.commit()
  77. except Exception as e:
  78. db.rollback()
  79. raise ValueError(f"数据库操作失败: {str(e)}")
  80. return {
  81. "message": f"成功导入 {imported_count} 条记录",
  82. "imported_count": imported_count
  83. }
  84. except json.JSONDecodeError as e:
  85. raise ValueError(f"无效的JSON格式: {str(e)}")
  86. except Exception as e:
  87. db.rollback()
  88. raise ValueError(f"导入失败: {str(e)}")
  89. def export_vector_data(db: Session, vector_id: int):
  90. """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
  91. vector_data = get_vector_data(db, vector_id)
  92. return _export_vector_data_to_file([vector_data], f"export_{vector_id}")
  93. def export_vector_data_batch(db: Session, vector_ids: List[int]):
  94. """批量导出矢量数据为GeoJSON格式并保存到文件"""
  95. vector_data_list = get_vector_data_batch(db, vector_ids)
  96. return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}")
  97. def export_all_vector_data(db: Session):
  98. """导出整个矢量数据表为GeoJSON格式并保存到文件"""
  99. # 查询所有矢量数据
  100. vector_data_list = db.query(VectorData).all()
  101. # 如果没有数据,抛出异常
  102. if not vector_data_list:
  103. raise HTTPException(status_code=404, detail="数据库中没有矢量数据")
  104. # 调用现有的导出函数
  105. return _export_vector_data_to_file(vector_data_list, "export_all")
  106. def parse_geom_field(geom_value) -> dict:
  107. """
  108. 解析 geom 字段为 GeoJSON 格式的 geometry。
  109. 如果解析失败,返回 None。
  110. """
  111. try:
  112. # 将 geom_value 转换为字符串
  113. geom_str = str(geom_value)
  114. if geom_str and geom_str.startswith('0101000020'):
  115. # 去掉前两个字符(字节序标记),并转换为字节对象
  116. binary_geom = bytes.fromhex(geom_str[2:])
  117. # 解析字节序(前1个字节)
  118. byte_order = struct.unpack('B', binary_geom[:1])[0]
  119. endian = '<' if byte_order == 1 else '>'
  120. # 检查数据长度是否足够解析坐标
  121. if len(binary_geom) >= 1 + 16:
  122. # 从数据末尾往前找 16 字节作为坐标
  123. coord_bytes = binary_geom[-16:]
  124. x, y = struct.unpack(f'{endian}dd', coord_bytes)
  125. return {
  126. "type": "Point",
  127. "coordinates": [x, y]
  128. }
  129. else:
  130. print(f"Insufficient data in geom value: {geom_str}. Length: {len(binary_geom)}")
  131. except (ValueError, IndexError, struct.error) as e:
  132. print(f"Failed to parse geom field: {geom_str if 'geom_str' in locals() else geom_value}. Error: {e}")
  133. return None
  134. def _export_vector_data_to_file(vector_data_list: List[VectorData], base_filename: str):
  135. """将矢量数据列表导出为 GeoJSON 文件"""
  136. features = []
  137. for vector_data in vector_data_list:
  138. # 获取表的所有列名
  139. columns = [column.name for column in VectorData.__table__.columns]
  140. # 构建包含所有列数据的字典
  141. data_dict = {}
  142. for column in columns:
  143. value = getattr(vector_data, column)
  144. # 如果值是字符串且可能是 JSON,尝试解析
  145. if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
  146. try:
  147. value = json.loads(value)
  148. except:
  149. pass
  150. # 跳过 geom 字段,后续单独处理
  151. if column != 'geom':
  152. data_dict[column] = value
  153. # 解析 geom 字段为 GeoJSON 格式的 geometry
  154. geometry = None
  155. if hasattr(vector_data, 'geom'):
  156. geom_value = getattr(vector_data, 'geom')
  157. geometry = parse_geom_field(geom_value)
  158. # 创建 Feature
  159. feature = {
  160. "type": "Feature",
  161. "properties": data_dict,
  162. "geometry": geometry
  163. }
  164. features.append(feature)
  165. # 创建 GeoJSON 对象
  166. geojson = {
  167. "type": "FeatureCollection",
  168. "features": features
  169. }
  170. # 创建临时目录
  171. temp_dir = tempfile.mkdtemp()
  172. # 生成文件名(使用时间戳避免重名)
  173. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  174. filename = f"{base_filename}_{timestamp}.geojson"
  175. file_path = os.path.join(temp_dir, filename)
  176. # 保存到文件,使用自定义编码器处理 Decimal 类型
  177. with open(file_path, "w", encoding="utf-8") as f:
  178. json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
  179. return {
  180. "message": "数据导出成功",
  181. "file_path": file_path,
  182. "temp_dir": temp_dir, # 返回临时目录路径,以便后续清理
  183. "data": geojson
  184. }