vector_service.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. from fastapi import HTTPException, UploadFile
  2. from sqlalchemy.orm import Session
  3. from ..models.vector import VectorData
  4. import json
  5. import os
  6. from datetime import datetime
  7. from decimal import Decimal
  8. from typing import List
  9. import uuid
  10. import tempfile
  11. import struct
  12. from sqlalchemy.sql import text
  13. import binascii
  14. # 导入shapely库用于解析WKB
  15. try:
  16. from shapely import wkb
  17. from shapely.geometry import mapping
  18. SHAPELY_AVAILABLE = True
  19. except ImportError:
  20. SHAPELY_AVAILABLE = False
  21. class DecimalEncoder(json.JSONEncoder):
  22. def default(self, obj):
  23. if isinstance(obj, Decimal):
  24. return float(obj)
  25. return super(DecimalEncoder, self).default(obj)
  26. def get_vector_data(db: Session, vector_id: int):
  27. """通过ID获取一条矢量数据记录"""
  28. vector_data = db.query(VectorData).filter(VectorData.id == vector_id).first()
  29. if not vector_data:
  30. raise HTTPException(status_code=404, detail="矢量数据不存在")
  31. # 手动构建返回字典
  32. result = {}
  33. for column in vector_data.__table__.columns:
  34. value = getattr(vector_data, column.name)
  35. # 处理特殊类型
  36. if isinstance(value, Decimal):
  37. value = float(value)
  38. elif isinstance(value, datetime):
  39. value = value.isoformat()
  40. elif str(column.type).startswith('geometry'):
  41. # 如果是几何类型,直接使用字符串表示
  42. if value is not None:
  43. value = str(value)
  44. result[column.name] = value
  45. return result
  46. def get_vector_data_batch(db: Session, vector_ids: List[int]):
  47. """批量获取矢量数据记录"""
  48. vector_data_list = db.query(VectorData).filter(VectorData.id.in_(vector_ids)).all()
  49. if not vector_data_list:
  50. raise HTTPException(status_code=404, detail="未找到指定的矢量数据")
  51. result = []
  52. for vector_data in vector_data_list:
  53. item = {}
  54. for column in vector_data.__table__.columns:
  55. value = getattr(vector_data, column.name)
  56. # 处理特殊类型
  57. if isinstance(value, Decimal):
  58. value = float(value)
  59. elif isinstance(value, datetime):
  60. value = value.isoformat()
  61. elif str(column.type).startswith('geometry'):
  62. # 如果是几何类型,直接使用字符串表示
  63. if value is not None:
  64. value = str(value)
  65. item[column.name] = value
  66. result.append(item)
  67. return result
  68. async def import_vector_data(file: UploadFile, db: Session) -> dict:
  69. """导入GeoJSON文件到数据库"""
  70. try:
  71. # 读取文件内容
  72. content = await file.read()
  73. data = json.loads(content)
  74. # 验证GeoJSON格式
  75. if not isinstance(data, dict) or data.get("type") != "FeatureCollection":
  76. raise ValueError("无效的GeoJSON格式")
  77. features = data.get("features", [])
  78. if not features:
  79. raise ValueError("GeoJSON文件中没有要素数据")
  80. # 获取表的所有列名
  81. columns = [column.name for column in VectorData.__table__.columns]
  82. # 导入每个要素
  83. imported_count = 0
  84. for feature in features:
  85. if not isinstance(feature, dict) or feature.get("type") != "Feature":
  86. continue
  87. # 获取属性
  88. properties = feature.get("properties", {})
  89. # 创建新记录
  90. vector_data = VectorData()
  91. # 设置每个字段的值(除了id)
  92. for column in columns:
  93. if column == 'id': # 跳过id字段
  94. continue
  95. if column in properties:
  96. value = properties[column]
  97. # 如果值是字典或列表,转换为JSON字符串
  98. if isinstance(value, (dict, list)):
  99. value = json.dumps(value, ensure_ascii=False)
  100. setattr(vector_data, column, value)
  101. # 设置几何数据(优先使用geometry字段,如果没有则使用geom字段)
  102. geometry = feature.get("geometry")
  103. if geometry:
  104. geometry_str = json.dumps(geometry, ensure_ascii=False)
  105. setattr(vector_data, 'geometry', geometry_str)
  106. elif 'geom' in properties:
  107. setattr(vector_data, 'geometry', properties['geom'])
  108. try:
  109. db.add(vector_data)
  110. imported_count += 1
  111. except Exception as e:
  112. continue
  113. # 提交事务
  114. try:
  115. db.commit()
  116. except Exception as e:
  117. db.rollback()
  118. raise ValueError(f"数据库操作失败: {str(e)}")
  119. return {
  120. "message": f"成功导入 {imported_count} 条记录",
  121. "imported_count": imported_count
  122. }
  123. except json.JSONDecodeError as e:
  124. raise ValueError(f"无效的JSON格式: {str(e)}")
  125. except Exception as e:
  126. db.rollback()
  127. raise ValueError(f"导入失败: {str(e)}")
  128. def export_vector_data(db: Session, vector_id: int):
  129. """导出指定ID的矢量数据为GeoJSON格式并保存到文件"""
  130. vector_data = get_vector_data(db, vector_id)
  131. return _export_vector_data_to_file([vector_data], f"export_{vector_id}", "surveydata")
  132. def export_vector_data_batch(db: Session, vector_ids: List[int]):
  133. """批量导出矢量数据为GeoJSON格式并保存到文件"""
  134. vector_data_list = get_vector_data_batch(db, vector_ids)
  135. return _export_vector_data_to_file(vector_data_list, f"export_batch_{'_'.join(map(str, vector_ids))}", "surveydata")
  136. def export_all_vector_data(db: Session, table_name: str = "surveydata"):
  137. """导出指定的矢量数据表为GeoJSON格式并保存到文件
  138. Args:
  139. db (Session): 数据库会话
  140. table_name (str): 要导出的矢量数据表名,默认为'surveydata'
  141. Returns:
  142. dict: 包含导出文件路径和临时目录的字典
  143. """
  144. # 使用动态表名查询
  145. query = text(f"SELECT * FROM {table_name}")
  146. vector_data_list = db.execute(query).fetchall()
  147. # 如果没有数据,抛出异常
  148. if not vector_data_list:
  149. raise HTTPException(status_code=404, detail=f"表 {table_name} 中没有矢量数据")
  150. # 调用现有的导出函数
  151. return _export_vector_data_to_file(vector_data_list, f"export_{table_name}", table_name)
  152. def parse_geom_field(geom_value) -> dict:
  153. """
  154. 解析 geom 字段为 GeoJSON 格式的 geometry。
  155. 如果解析失败,返回 None。
  156. """
  157. try:
  158. # 将 geom_value 转换为字符串
  159. geom_str = str(geom_value)
  160. # 处理PostGIS WKB格式的点数据
  161. if geom_str and geom_str.startswith('0101000020'):
  162. # 去掉前两个字符(字节序标记),并转换为字节对象
  163. binary_geom = bytes.fromhex(geom_str[2:])
  164. # 解析字节序(前1个字节)
  165. byte_order = struct.unpack('B', binary_geom[:1])[0]
  166. endian = '<' if byte_order == 1 else '>'
  167. # 检查数据长度是否足够解析坐标
  168. if len(binary_geom) >= 1 + 16:
  169. # 从数据末尾往前找 16 字节作为坐标
  170. coord_bytes = binary_geom[-16:]
  171. x, y = struct.unpack(f'{endian}dd', coord_bytes)
  172. return {
  173. "type": "Point",
  174. "coordinates": [x, y]
  175. }
  176. else:
  177. print(f"数据长度不足: {geom_str}. 长度: {len(binary_geom)}")
  178. # 处理PostgreSQL/PostGIS的Well-Known Text (WKT)格式
  179. elif geom_str and (geom_str.startswith('POINT') or
  180. geom_str.startswith('LINESTRING') or
  181. geom_str.startswith('POLYGON') or
  182. geom_str.startswith('MULTIPOINT') or
  183. geom_str.startswith('MULTILINESTRING') or
  184. geom_str.startswith('MULTIPOLYGON')):
  185. # 这里我们需要依赖PostgreSQL服务器将WKT转换为GeoJSON
  186. # 在实际部署中,应该使用数据库函数如ST_AsGeoJSON()
  187. print(f"检测到WKT格式几何数据: {geom_str[:30]}...")
  188. return None
  189. # 处理EWKT (Extended Well-Known Text)格式,如SRID=4326;POINT(...)
  190. elif geom_str and geom_str.startswith('SRID='):
  191. print(f"检测到EWKT格式几何数据: {geom_str[:30]}...")
  192. return None
  193. # 处理十六进制WKB格式
  194. elif geom_str and all(c in '0123456789ABCDEFabcdef' for c in geom_str):
  195. print(f"检测到十六进制WKB格式几何数据: {geom_str[:30]}...")
  196. # 使用Shapely库解析WKB (首选方法)
  197. if SHAPELY_AVAILABLE:
  198. try:
  199. # 如果字符串长度为奇数,可能需要在前面添加一个"0"
  200. if len(geom_str) % 2 != 0:
  201. geom_str = '0' + geom_str
  202. # 将十六进制字符串转换为二进制数据
  203. binary_data = binascii.unhexlify(geom_str)
  204. # 使用Shapely解析WKB并转换为GeoJSON
  205. shape = wkb.loads(binary_data)
  206. return mapping(shape)
  207. except Exception as e:
  208. print(f"Shapely解析WKB失败: {e}")
  209. # 使用PostGIS函数进行解析
  210. try:
  211. from ..database import engine
  212. with engine.connect() as connection:
  213. # 使用PostgreSQL/PostGIS的ST_GeomFromWKB函数和ST_AsGeoJSON函数
  214. query = text("SELECT ST_AsGeoJSON(ST_GeomFromWKB(decode($1, 'hex'))) AS geojson")
  215. result = connection.execute(query, [geom_str]).fetchone()
  216. if result and result.geojson:
  217. return json.loads(result.geojson)
  218. except Exception as e:
  219. print(f"使用PostgreSQL解析WKB失败: {e}")
  220. return None
  221. else:
  222. # 可能是使用PostGIS扩展的内部二进制格式
  223. from sqlalchemy.sql import text
  224. from ..database import engine
  225. try:
  226. # 使用PostgreSQL/PostGIS的ST_AsGeoJSON函数直接转换
  227. with engine.connect() as connection:
  228. # 尝试安全地传递geom_value
  229. # 注意:这种方法依赖于数据库连接和PostGIS扩展
  230. query = text("SELECT ST_AsGeoJSON(ST_Force2D($1::geometry)) AS geojson")
  231. result = connection.execute(query, [geom_value]).fetchone()
  232. if result and result.geojson:
  233. return json.loads(result.geojson)
  234. except Exception as e:
  235. print(f"使用ST_AsGeoJSON转换几何数据失败: {e}")
  236. print(f"未识别的几何数据格式: {geom_str[:50]}...")
  237. except (ValueError, IndexError, struct.error) as e:
  238. print(f"解析几何字段失败: {geom_str if 'geom_str' in locals() else geom_value}. 错误: {e}")
  239. return None
  240. def _export_vector_data_to_file(vector_data_list, base_filename: str, table_name: str = "surveydata"):
  241. """将矢量数据列表导出为 GeoJSON 文件
  242. Args:
  243. vector_data_list: 矢量数据列表,可能是ORM对象或SQLAlchemy行对象
  244. base_filename: 基础文件名
  245. table_name: 表名,用于判断应该使用哪个ORM模型,默认为"surveydata"
  246. """
  247. features = []
  248. # 导入所需的ORM模型
  249. from ..models.orm_models import UnitCeil, Surveydatum, FiftyThousandSurveyDatum
  250. # 根据表名获取对应的ORM模型和几何字段名
  251. model_mapping = {
  252. "surveydata": (Surveydatum, "geom"),
  253. "unit_ceil": (UnitCeil, "geom"),
  254. "fifty_thousand_survey_data": (FiftyThousandSurveyDatum, "geom")
  255. }
  256. # 获取对应的模型和几何字段名
  257. model_class, geom_field = model_mapping.get(table_name, (Surveydatum, "geom"))
  258. # 检查数据类型
  259. is_orm_object = len(vector_data_list) == 0 or hasattr(vector_data_list[0], '__table__')
  260. for vector_data in vector_data_list:
  261. # 构建包含所有列数据的字典
  262. data_dict = {}
  263. if is_orm_object:
  264. # 如果是ORM对象,使用模型的列获取数据
  265. columns = [column.name for column in model_class.__table__.columns]
  266. for column in columns:
  267. if hasattr(vector_data, column):
  268. value = getattr(vector_data, column)
  269. # 如果值是字符串且可能是 JSON,尝试解析
  270. if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
  271. try:
  272. value = json.loads(value)
  273. except:
  274. pass
  275. # 跳过几何字段,后续单独处理
  276. if column != geom_field:
  277. data_dict[column] = value
  278. else:
  279. # 如果是SQLAlchemy行对象,获取所有键
  280. for key in vector_data.keys():
  281. if key != geom_field:
  282. value = vector_data[key]
  283. # 如果值是字符串且可能是 JSON,尝试解析
  284. if isinstance(value, str) and (value.startswith('{') or value.startswith('[')):
  285. try:
  286. value = json.loads(value)
  287. except:
  288. pass
  289. data_dict[key] = value
  290. # 解析几何字段为GeoJSON格式的geometry
  291. geometry = None
  292. geom_value = None
  293. if is_orm_object and hasattr(vector_data, geom_field):
  294. geom_value = getattr(vector_data, geom_field)
  295. elif not is_orm_object and geom_field in vector_data.keys():
  296. geom_value = vector_data[geom_field]
  297. if geom_value:
  298. geometry = parse_geom_field(geom_value)
  299. # 创建Feature
  300. feature = {
  301. "type": "Feature",
  302. "properties": data_dict,
  303. "geometry": geometry
  304. }
  305. features.append(feature)
  306. # 创建GeoJSON对象
  307. geojson = {
  308. "type": "FeatureCollection",
  309. "features": features
  310. }
  311. # 创建临时目录
  312. temp_dir = tempfile.mkdtemp()
  313. # 生成文件名(使用时间戳避免重名)
  314. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  315. filename = f"{base_filename}_{timestamp}.geojson"
  316. file_path = os.path.join(temp_dir, filename)
  317. # 保存到文件,使用自定义编码器处理Decimal类型
  318. with open(file_path, "w", encoding="utf-8") as f:
  319. json.dump(geojson, f, ensure_ascii=False, indent=2, cls=DecimalEncoder)
  320. return {
  321. "message": "数据导出成功",
  322. "file_path": file_path,
  323. "temp_dir": temp_dir, # 返回临时目录路径,以便后续清理
  324. "data": geojson
  325. }