vector_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. from fastapi import HTTPException, UploadFile
  2. from sqlalchemy.orm import Session
  3. from ..models.vector import VectorData
  4. import json
  5. import os
  6. import datetime # 统一使用模块导入,避免冲突
  7. from decimal import Decimal
  8. from typing import List
  9. import uuid
  10. import tempfile
  11. import struct
  12. from sqlalchemy.sql import text
  13. import binascii
  14. # 导入shapely库用于解析WKB
  15. try:
  16. from shapely import wkb
  17. from shapely.geometry import mapping
  18. SHAPELY_AVAILABLE = True
  19. except ImportError:
  20. SHAPELY_AVAILABLE = False
  21. class DecimalEncoder(json.JSONEncoder):
  22. """处理Decimal类型的JSON编码器"""
  23. def default(self, obj):
  24. if isinstance(obj, Decimal):
  25. return float(obj)
  26. return super(DecimalEncoder, self).default(obj)
  27. def get_vector_data(db: Session, vector_id: int):
  28. """通过ID获取一条矢量数据记录"""
  29. vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
  30. if not vector_data:
  31. raise HTTPException(status_code=404, detail="矢量数据不存在")
  32. # 手动构建返回字典
  33. result = {}
  34. for column in vector_data.__table__.columns:
  35. value = getattr(vector_data, column.name)
  36. # 处理特殊类型
  37. if isinstance(value, Decimal):
  38. value = float(value)
  39. elif isinstance(value, datetime.datetime): # 修复datetime判断
  40. value = value.isoformat()
  41. elif str(column.type).startswith('geometry'):
  42. # 如果是几何类型,直接使用字符串表示
  43. if value is not None:
  44. value = str(value)
  45. result[column.name] = value
  46. return result
  47. def get_vector_data_batch(db: Session, vector_ids: List[int]):
  48. """批量获取矢量数据记录"""
  49. vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
  50. if not vector_data_list:
  51. raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
  52. result = []
  53. for vector_data in vector_data_list:
  54. item = {}
  55. for column in vector_data.__table__.columns:
  56. value = getattr(vector_data, column.name)
  57. # 处理特殊类型
  58. if isinstance(value, Decimal):
  59. value = float(value)
  60. elif isinstance(value, datetime.datetime): # 修复datetime判断
  61. value = value.isoformat()
  62. elif str(column.type).startswith('geometry'):
  63. if value is not None:
  64. value = str(value)
  65. item[column.name] = value
  66. result.append(item)
  67. return result
  68. async def import_vector_data(file: UploadFile, db: Session) -> dict:
  69. """导入GeoJSON文件到数据库"""
  70. try:
  71. # 读取文件内容
  72. content = await file.read()
  73. data = json.loads(content)
  74. # 验证GeoJSON格式
  75. if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
  76. raise ValueError("无效的GeoJSON格式")
  77. features = data.get("features", [])
  78. if not features:
  79. raise ValueError("GeoJSON文件中没有要素数据")
  80. # 获取表的所有列名
  81. columns = [column.name for column in VectorData.__table__.columns]
  82. # 导入每个要素
  83. imported_count = 0
  84. for feature in features:
  85. if not isinstance(feature, dict) or feature.get("type") != "Feature":
  86. continue
  87. # 获取属性
  88. properties = feature.get("properties", {})
  89. # 创建新记录
  90. vector_data = VectorData()
  91. # 设置每个字段的值(除了id)
  92. for column in columns:
  93. if column == 'id': # 跳过id字段
  94. continue
  95. if column in properties:
  96. value = properties[column]
  97. # 如果值是字典或列表,转换为JSON字符串
  98. if isinstance(value, (dict, list)):
  99. value = json.dumps(value, ensure_ascii=False)
  100. setattr(vector_data, column, value)
  101. # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
  102. geometry = feature.get("geometry")
  103. if geometry:
  104. geometry_str = json.dumps(geometry, ensure_ascii=False)
  105. setattr(vector_data, 'geom', geometry_str)
  106. elif 'geom' in properties:
  107. setattr(vector_data, 'geometry', properties['geom'])
  108. try:
  109. db.add(vector_data)
  110. imported_count += 1
  111. except Exception as e:
  112. continue
  113. # 提交事务
  114. try:
  115. db.commit()
  116. except Exception as e:
  117. db.rollback()
  118. raise ValueError(f"数据库操作失败: {str(e)}")
  119. return {
  120. "message": f"成功导入 {imported_count} 条记录",
  121. "imported_count": imported_count
  122. }
  123. except json.JSONDecodeError as e:
  124. raise ValueError(f"无效的JSON格式: {str(e)}")
  125. except Exception as e:
  126. db.rollback()
  127. raise ValueError(f"导入失败: {str(e)}")
  128. def export_vector_data(db: Session, vector_id: int):
  129. """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
  130. vector_data = get_vector_data(db, vector_id)
  131. return _export_vector_data_to_file([vector_data], f"export_{vector_id}", "surveydata")
  132. def export_vector_data_batch(db: Session, vector_ids: List[int]):
  133. """批量导出矢量数据为GeoJSON格式并保存到文件"""
  134. vector_data_list = get_vector_data_batch(db, vector_ids)
  135. return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}", "surveydata")
  136. def export_all_vector_data(
  137. db: Session,
  138. table_name: str = "surveydata",
  139. export_format: str = "geojson" # 新增:格式参数,默认geojson
  140. ):
  141. """导出指定表数据为GeoJSON或普通JSON格式"""
  142. try:
  143. # 1. 查询表数据(保持原有逻辑)
  144. query = text(f'SELECT * FROM "{table_name}"')
  145. result = db.execute(query)
  146. columns = [col.name for col in result.cursor.description]
  147. # 转换为字典列表(原生查询结果处理)
  148. data_list = [dict(zip(columns, row)) for row in result.fetchall()]
  149. # 2. 根据格式生成不同文件
  150. if export_format == "geojson":
  151. # 原有逻辑:生成GeoJSON(依赖你的build_geojson工具)
  152. return _export_vector_data_to_file(data_list, f"export_{table_name}", table_name)
  153. else:
  154. # 新增逻辑:生成普通JSON
  155. return _export_json_data_to_file(data_list, f"export_{table_name}", table_name)
  156. except Exception as e:
  157. raise HTTPException(
  158. status_code=500,
  159. detail=f"查询表{table_name}失败:{str(e)}"
  160. )
  161. # 新增函数:生成普通JSON文件
  162. def _export_json_data_to_file(data_list, base_filename, table_name):
  163. temp_dir = tempfile.mkdtemp()
  164. file_path = os.path.join(temp_dir, f"{base_filename}.json") # 后缀改为json
  165. # 直接写入原始数据(无需GeoJSON转换)
  166. with open(file_path, "w", encoding="utf-8") as f:
  167. json.dump(data_list, f, ensure_ascii=False, indent=2)
  168. return {"file_path": file_path, "temp_dir": temp_dir}
  169. def parse_geom_field(geom_value) -> dict:
  170. """解析 geom 字段为 GeoJSON 格式的 geometry"""
  171. try:
  172. geom_str = str(geom_value)
  173. # 处理PostGIS WKB格式的点数据
  174. if geom_str and geom_str.startswith('0101000020'):
  175. binary_geom = bytes.fromhex(geom_str[2:])
  176. byte_order = struct.unpack('B', binary_geom[:1])[0]
  177. endian = '<' if byte_order == 1 else '>'
  178. if len(binary_geom) >= 1 + 16:
  179. coord_bytes = binary_geom[-16:]
  180. x, y = struct.unpack(f'{endian}dd', coord_bytes)
  181. return {
  182. "type": "Point",
  183. "coordinates": [x, y]
  184. }
  185. else:
  186. print(f"数据长度不足: {geom_str}. 长度: {len(binary_geom)}")
  187. # 处理WKT格式
  188. elif geom_str and (geom_str.startswith('POINT') or
  189. geom_str.startswith('LINESTRING') or
  190. geom_str.startswith('POLYGON') or
  191. geom_str.startswith('MULTIPOINT') or
  192. geom_str.startswith('MULTILINESTRING') or
  193. geom_str.startswith('MULTIPOLYGON')):
  194. print(f"检测到WKT格式几何数据: {geom_str[:30]}...")
  195. return None
  196. # 处理EWKT格式
  197. elif geom_str and geom_str.startswith('SRID='):
  198. print(f"检测到EWKT格式几何数据: {geom_str[:30]}...")
  199. return None
  200. # 处理十六进制WKB格式
  201. elif geom_str and all(c in '0123456789ABCDEFabcdef' for c in geom_str):
  202. print(f"检测到十六进制WKB格式几何数据: {geom_str[:30]}...")
  203. if SHAPELY_AVAILABLE:
  204. try:
  205. if len(geom_str) % 2 != 0:
  206. geom_str = '0' + geom_str
  207. binary_data = binascii.unhexlify(geom_str)
  208. shape = wkb.loads(binary_data)
  209. return mapping(shape)
  210. except Exception as e:
  211. print(f"Shapely解析WKB失败: {e}")
  212. try:
  213. from ..database import engine
  214. with engine.connect() as connection:
  215. query = text("SELECT ST_AsGeoJSON(ST_GeomFromWKB(decode($1, 'hex'))) AS geojson")
  216. result = connection.execute(query, [geom_str]).fetchone()
  217. if result and result.geojson:
  218. return json.loads(result.geojson)
  219. except Exception as e:
  220. print(f"使用PostgreSQL解析WKB失败: {e}")
  221. return None
  222. else:
  223. from ..database import engine
  224. try:
  225. with engine.connect() as connection:
  226. query = text("SELECT ST_AsGeoJSON(ST_Force2D($1::geometry)) AS geojson")
  227. result = connection.execute(query, [geom_value]).fetchone()
  228. if result and result.geojson:
  229. return json.loads(result.geojson)
  230. except Exception as e:
  231. print(f"使用ST_AsGeoJSON转换几何数据失败: {e}")
  232. print(f"未识别的几何数据格式: {geom_str[:50]}...")
  233. except (ValueError, IndexError, struct.error) as e:
  234. print(f"解析几何字段失败: {geom_str if 'geom_str' in locals() else geom_value}. 错误: {e}")
  235. return None
  236. def _export_vector_data_to_file(vector_data_list, base_filename: str, table_name: str = "surveydata"):
  237. """将矢量数据列表导出为 GeoJSON 文件(修复后)"""
  238. try:
  239. features = []
  240. for data in vector_data_list:
  241. # 1. 处理空间信息(从longitude和latitude生成点坐标)
  242. longitude = data.get('longitude')
  243. latitude = data.get('latitude')
  244. # 跳过经纬度无效的数据
  245. if not (isinstance(longitude, (int, float)) and isinstance(latitude, (int, float))):
  246. print(f"跳过无效数据:经度={longitude}(类型{type(longitude)}),纬度={latitude}(类型{type(latitude)})")
  247. continue
  248. # 生成标准GeoJSON点 geometry
  249. geometry = {
  250. "type": "Point",
  251. "coordinates": [longitude, latitude] # [经度, 纬度]
  252. }
  253. # 2. 处理属性信息(处理特殊类型)
  254. properties = {}
  255. for key, value in data.items():
  256. # 处理日期时间类型
  257. if isinstance(value, datetime.datetime):
  258. properties[key] = value.strftime("%Y-%m-%d %H:%M:%S")
  259. # 处理'nan'特殊值
  260. elif value == 'nan':
  261. properties[key] = None
  262. # 处理Decimal类型
  263. elif isinstance(value, Decimal):
  264. properties[key] = float(value)
  265. # 处理其他值
  266. else:
  267. properties[key] = value
  268. # 3. 组合成GeoJSON要素
  269. feature = {
  270. "type": "Feature",
  271. "geometry": geometry,
  272. "properties": properties
  273. }
  274. features.append(feature)
  275. # 4. 生成完整GeoJSON
  276. geojson_data = {
  277. "type": "FeatureCollection",
  278. "features": features
  279. }
  280. # 创建临时目录并保存文件
  281. temp_dir = tempfile.mkdtemp()
  282. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  283. filename = f"{base_filename}_{timestamp}.geojson"
  284. file_path = os.path.join(temp_dir, filename)
  285. # 写入文件(使用自定义编码器处理Decimal)
  286. with open(file_path, "w", encoding="utf-8") as f:
  287. json.dump(geojson_data, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
  288. # 返回结果
  289. return {
  290. "message": "数据导出成功",
  291. "file_path": file_path,
  292. "temp_dir": temp_dir,
  293. "data": geojson_data
  294. }
  295. except Exception as e:
  296. error_data = data if 'data' in locals() else "未知数据"
  297. print(f"生成矢量数据时出错:{str(e)},出错数据:{error_data}")
  298. raise