vector_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from fastapi import HTTPException, UploadFile
  2. from sqlalchemy.orm import Session
  3. from ..models.vector import VectorData
  4. import json
  5. import os
  6. import datetime # 统一使用模块导入,避免冲突
  7. from decimal import Decimal
  8. from typing import List
  9. import uuid
  10. import tempfile
  11. import struct
  12. from sqlalchemy.sql import text
  13. import binascii
  14. # 导入shapely库用于解析WKB
  15. try:
  16. from shapely import wkb
  17. from shapely.geometry import mapping
  18. SHAPELY_AVAILABLE = True
  19. except ImportError:
  20. SHAPELY_AVAILABLE = False
  21. class DecimalEncoder(json.JSONEncoder):
  22. """处理Decimal类型的JSON编码器"""
  23. def default(self, obj):
  24. if isinstance(obj, Decimal):
  25. return float(obj)
  26. return super(DecimalEncoder, self).default(obj)
  27. def get_vector_data(db: Session, vector_id: int):
  28. """通过ID获取一条矢量数据记录"""
  29. vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
  30. if not vector_data:
  31. raise HTTPException(status_code=404, detail="矢量数据不存在")
  32. # 手动构建返回字典
  33. result = {}
  34. for column in vector_data.__table__.columns:
  35. value = getattr(vector_data, column.name)
  36. # 处理特殊类型
  37. if isinstance(value, Decimal):
  38. value = float(value)
  39. elif isinstance(value, datetime.datetime): # 修复datetime判断
  40. value = value.isoformat()
  41. elif str(column.type).startswith('geometry'):
  42. # 如果是几何类型,直接使用字符串表示
  43. if value is not None:
  44. value = str(value)
  45. result[column.name] = value
  46. return result
  47. def get_vector_data_batch(db: Session, vector_ids: List[int]):
  48. """批量获取矢量数据记录"""
  49. vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
  50. if not vector_data_list:
  51. raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
  52. result = []
  53. for vector_data in vector_data_list:
  54. item = {}
  55. for column in vector_data.__table__.columns:
  56. value = getattr(vector_data, column.name)
  57. # 处理特殊类型
  58. if isinstance(value, Decimal):
  59. value = float(value)
  60. elif isinstance(value, datetime.datetime): # 修复datetime判断
  61. value = value.isoformat()
  62. elif str(column.type).startswith('geometry'):
  63. if value is not None:
  64. value = str(value)
  65. item[column.name] = value
  66. result.append(item)
  67. return result
  68. async def import_vector_data(file: UploadFile, db: Session) -> dict:
  69. """导入GeoJSON文件到数据库"""
  70. try:
  71. # 读取文件内容
  72. content = await file.read()
  73. data = json.loads(content)
  74. # 验证GeoJSON格式
  75. if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
  76. raise ValueError("无效的GeoJSON格式")
  77. features = data.get("features", [])
  78. if not features:
  79. raise ValueError("GeoJSON文件中没有要素数据")
  80. # 获取表的所有列名
  81. columns = [column.name for column in VectorData.__table__.columns]
  82. # 导入每个要素
  83. imported_count = 0
  84. for feature in features:
  85. if not isinstance(feature, dict) or feature.get("type") != "Feature":
  86. continue
  87. # 获取属性
  88. properties = feature.get("properties", {})
  89. # 创建新记录
  90. vector_data = VectorData()
  91. # 设置每个字段的值(除了id)
  92. for column in columns:
  93. if column == 'id': # 跳过id字段
  94. continue
  95. if column in properties:
  96. value = properties[column]
  97. # 如果值是字典或列表,转换为JSON字符串
  98. if isinstance(value, (dict, list)):
  99. value = json.dumps(value, ensure_ascii=False)
  100. setattr(vector_data, column, value)
  101. # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
  102. geometry = feature.get("geometry")
  103. if geometry:
  104. geometry_str = json.dumps(geometry, ensure_ascii=False)
  105. setattr(vector_data, 'geom', geometry_str)
  106. elif 'geom' in properties:
  107. setattr(vector_data, 'geometry', properties['geom'])
  108. try:
  109. db.add(vector_data)
  110. imported_count += 1
  111. except Exception as e:
  112. continue
  113. # 提交事务
  114. try:
  115. db.commit()
  116. except Exception as e:
  117. db.rollback()
  118. raise ValueError(f"数据库操作失败: {str(e)}")
  119. return {
  120. "message": f"成功导入 {imported_count} 条记录",
  121. "imported_count": imported_count
  122. }
  123. except json.JSONDecodeError as e:
  124. raise ValueError(f"无效的JSON格式: {str(e)}")
  125. except Exception as e:
  126. db.rollback()
  127. raise ValueError(f"导入失败: {str(e)}")
  128. def export_vector_data(db: Session, vector_id: int):
  129. """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
  130. vector_data = get_vector_data(db, vector_id)
  131. return _export_vector_data_to_file([vector_data], f"export_{vector_id}", "surveydata")
  132. def export_vector_data_batch(db: Session, vector_ids: List[int]):
  133. """批量导出矢量数据为GeoJSON格式并保存到文件"""
  134. vector_data_list = get_vector_data_batch(db, vector_ids)
  135. return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}", "surveydata")
  136. def export_all_vector_data(db: Session, table_name: str = "surveydata"):
  137. """导出指定的矢量数据表为GeoJSON格式并保存到文件"""
  138. try:
  139. query = text(f'SELECT * FROM "{table_name}"')
  140. result = db.execute(query)
  141. columns = [col.name for col in result.cursor.description]
  142. vector_data_list = [dict(zip(columns, row)) for row in result.fetchall()] # 类型修正
  143. return _export_vector_data_to_file(vector_data_list, f"export_{table_name}", table_name)
  144. except Exception as e:
  145. raise HTTPException(
  146. status_code=500,
  147. detail=f"查询表{table_name}失败:{str(e)}"
  148. )
  149. def parse_geom_field(geom_value) -> dict:
  150. """解析 geom 字段为 GeoJSON 格式的 geometry"""
  151. try:
  152. geom_str = str(geom_value)
  153. # 处理PostGIS WKB格式的点数据
  154. if geom_str and geom_str.startswith('0101000020'):
  155. binary_geom = bytes.fromhex(geom_str[2:])
  156. byte_order = struct.unpack('B', binary_geom[:1])[0]
  157. endian = '<' if byte_order == 1 else '>'
  158. if len(binary_geom) >= 1 + 16:
  159. coord_bytes = binary_geom[-16:]
  160. x, y = struct.unpack(f'{endian}dd', coord_bytes)
  161. return {
  162. "type": "Point",
  163. "coordinates": [x, y]
  164. }
  165. else:
  166. print(f"数据长度不足: {geom_str}. 长度: {len(binary_geom)}")
  167. # 处理WKT格式
  168. elif geom_str and (geom_str.startswith('POINT') or
  169. geom_str.startswith('LINESTRING') or
  170. geom_str.startswith('POLYGON') or
  171. geom_str.startswith('MULTIPOINT') or
  172. geom_str.startswith('MULTILINESTRING') or
  173. geom_str.startswith('MULTIPOLYGON')):
  174. print(f"检测到WKT格式几何数据: {geom_str[:30]}...")
  175. return None
  176. # 处理EWKT格式
  177. elif geom_str and geom_str.startswith('SRID='):
  178. print(f"检测到EWKT格式几何数据: {geom_str[:30]}...")
  179. return None
  180. # 处理十六进制WKB格式
  181. elif geom_str and all(c in '0123456789ABCDEFabcdef' for c in geom_str):
  182. print(f"检测到十六进制WKB格式几何数据: {geom_str[:30]}...")
  183. if SHAPELY_AVAILABLE:
  184. try:
  185. if len(geom_str) % 2 != 0:
  186. geom_str = '0' + geom_str
  187. binary_data = binascii.unhexlify(geom_str)
  188. shape = wkb.loads(binary_data)
  189. return mapping(shape)
  190. except Exception as e:
  191. print(f"Shapely解析WKB失败: {e}")
  192. try:
  193. from ..database import engine
  194. with engine.connect() as connection:
  195. query = text("SELECT ST_AsGeoJSON(ST_GeomFromWKB(decode($1, 'hex'))) AS geojson")
  196. result = connection.execute(query, [geom_str]).fetchone()
  197. if result and result.geojson:
  198. return json.loads(result.geojson)
  199. except Exception as e:
  200. print(f"使用PostgreSQL解析WKB失败: {e}")
  201. return None
  202. else:
  203. from ..database import engine
  204. try:
  205. with engine.connect() as connection:
  206. query = text("SELECT ST_AsGeoJSON(ST_Force2D($1::geometry)) AS geojson")
  207. result = connection.execute(query, [geom_value]).fetchone()
  208. if result and result.geojson:
  209. return json.loads(result.geojson)
  210. except Exception as e:
  211. print(f"使用ST_AsGeoJSON转换几何数据失败: {e}")
  212. print(f"未识别的几何数据格式: {geom_str[:50]}...")
  213. except (ValueError, IndexError, struct.error) as e:
  214. print(f"解析几何字段失败: {geom_str if 'geom_str' in locals() else geom_value}. 错误: {e}")
  215. return None
  216. def _export_vector_data_to_file(vector_data_list, base_filename: str, table_name: str = "surveydata"):
  217. """将矢量数据列表导出为 GeoJSON 文件(修复后)"""
  218. try:
  219. features = []
  220. for data in vector_data_list:
  221. # 1. 处理空间信息(从longitude和latitude生成点坐标)
  222. longitude = data.get('longitude')
  223. latitude = data.get('latitude')
  224. # 跳过经纬度无效的数据
  225. if not (isinstance(longitude, (int, float)) and isinstance(latitude, (int, float))):
  226. print(f"跳过无效数据:经度={longitude}(类型{type(longitude)}),纬度={latitude}(类型{type(latitude)})")
  227. continue
  228. # 生成标准GeoJSON点 geometry
  229. geometry = {
  230. "type": "Point",
  231. "coordinates": [longitude, latitude] # [经度, 纬度]
  232. }
  233. # 2. 处理属性信息(处理特殊类型)
  234. properties = {}
  235. for key, value in data.items():
  236. # 处理日期时间类型
  237. if isinstance(value, datetime.datetime):
  238. properties[key] = value.strftime("%Y-%m-%d %H:%M:%S")
  239. # 处理'nan'特殊值
  240. elif value == 'nan':
  241. properties[key] = None
  242. # 处理Decimal类型
  243. elif isinstance(value, Decimal):
  244. properties[key] = float(value)
  245. # 处理其他值
  246. else:
  247. properties[key] = value
  248. # 3. 组合成GeoJSON要素
  249. feature = {
  250. "type": "Feature",
  251. "geometry": geometry,
  252. "properties": properties
  253. }
  254. features.append(feature)
  255. # 4. 生成完整GeoJSON
  256. geojson_data = {
  257. "type": "FeatureCollection",
  258. "features": features
  259. }
  260. # 创建临时目录并保存文件
  261. temp_dir = tempfile.mkdtemp()
  262. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  263. filename = f"{base_filename}_{timestamp}.geojson"
  264. file_path = os.path.join(temp_dir, filename)
  265. # 写入文件(使用自定义编码器处理Decimal)
  266. with open(file_path, "w", encoding="utf-8") as f:
  267. json.dump(geojson_data, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
  268. # 返回结果
  269. return {
  270. "message": "数据导出成功",
  271. "file_path": file_path,
  272. "temp_dir": temp_dir,
  273. "data": geojson_data
  274. }
  275. except Exception as e:
  276. error_data = data if 'data' in locals() else "未知数据"
  277. print(f"生成矢量数据时出错:{str(e)},出错数据:{error_data}")
  278. raise