# unit_grouping_service.py (15 KB)


  import logging
  from collections import Counter, defaultdict
  from typing import Any, Dict, List, Optional, Tuple

  from sqlalchemy import func, text
  from sqlalchemy.orm import Session

  from app.log.logger import get_logger

  from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil

  # Module-level logger obtained from the project's logging helper.
  logger = get_logger(__name__)
  class UnitGroupingService:
      """Unit grouping service.

      Derives an ``h_xtfx`` category for every row of ``unit_ceil`` from the
      ``fifty_thousand_survey_data`` points that fall inside the unit, and
      offers a few read-only statistics helpers on the same tables.
      """

      # Numeric codes of the h_xtfx categories; the codes are what gets
      # interpolated, the names are what is stored and returned.
      H_XTFX_MAPPING = {
          "优先保护类": 1,
          "安全利用类": 2,
          "严格管控类": 3,
      }
      # Inverse lookup: numeric code -> category name.
      REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}

      def __init__(self, db_session: Session):
          """Store the database session used by every query of the service.

          Args:
              db_session: SQLAlchemy session bound to the spatial database.
          """
          self.db_session = db_session

      def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
          """Compute the ``h_xtfx`` category of every unit.

          A single database-side spatial join collects, for each unit, the
          points it contains together with the unit centroid. Units without
          any point (or whose computation fails) map to ``None``.

          Returns:
              Mapping of unit ``gid`` to category name or ``None``. An empty
              dict is returned when the queries themselves fail (the error is
              logged).
          """
          try:
              # One spatial join returns every (unit, contained point) pair
              # along with the unit centroid, so no per-unit query is needed.
              spatial_query = text("""
                  SELECT
                      u.gid as unit_id,
                      p.h_xtfx,
                      ST_X(ST_Centroid(u.geom)) as unit_center_x,
                      ST_Y(ST_Centroid(u.geom)) as unit_center_y,
                      ST_X(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_x,
                      ST_Y(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_y
                  FROM unit_ceil u
                  JOIN fifty_thousand_survey_data p ON ST_Contains(
                      u.geom,
                      ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
                  )
                  WHERE p.h_xtfx IS NOT NULL
                  ORDER BY u.gid
              """)
              spatial_rows = self.db_session.execute(spatial_query).fetchall()

              # Group the join result: unit id -> points, unit id -> centroid.
              unit_points: Dict[int, List[Dict]] = defaultdict(list)
              unit_centers: Dict[int, Tuple[float, float]] = {}
              for row in spatial_rows:
                  unit_points[row.unit_id].append({
                      'h_xtfx': row.h_xtfx,
                      'x': row.point_x,
                      'y': row.point_y,
                  })
                  if (
                      row.unit_id not in unit_centers
                      and row.unit_center_x is not None
                      and row.unit_center_y is not None
                  ):
                      unit_centers[row.unit_id] = (row.unit_center_x, row.unit_center_y)

              # Every unit appears in the output, including those with no points.
              all_units = self.db_session.execute(
                  text("SELECT gid FROM unit_ceil ORDER BY gid")
              ).fetchall()

              unit_values: Dict[int, Optional[str]] = {}
              for unit_row in all_units:
                  unit_id = unit_row.gid
                  points = unit_points.get(unit_id)
                  if not points:
                      unit_values[unit_id] = None
                      continue
                  try:
                      unit_values[unit_id] = self._calculate_single_unit_h_xtfx_from_points(
                          unit_id, points, unit_center=unit_centers.get(unit_id)
                      )
                  except Exception as e:
                      # Best effort: one bad unit must not abort the whole run.
                      logger.error(f"计算单元 {unit_id} 的h_xtfx值失败: {e}")
                      unit_values[unit_id] = None

              logger.info(f"计算完成,共处理 {len(unit_values)} 个单元,其中 {sum(1 for v in unit_values.values() if v is not None)} 个有有效结果")
              return unit_values
          except Exception as e:
              logger.error(f"计算单元h_xtfx值失败: {e}")
              return {}

      def _calculate_single_unit_h_xtfx_from_points(
          self,
          unit_id: int,
          points: List[Dict],
          unit_center: Optional[Tuple[float, float]] = None,
      ) -> Optional[str]:
          """Derive the category of one unit from the points inside it.

          Rules:
            * No "严格管控类" point: if the most common category covers at least
              80% of the points it wins outright; otherwise the "优先保护类" /
              "安全利用类" points are interpolated at the unit centre, falling
              back to the most common category whenever interpolation is not
              possible.
            * At least one "严格管控类" point: every point with a known category
              is interpolated at the unit centre; ``None`` is returned when
              interpolation is not possible.

          Args:
              unit_id: Unit ``gid``; only used to look up the centre when
                  ``unit_center`` is not supplied.
              points: Dicts of the form ``{'h_xtfx': str, 'x': float, 'y': float}``.
              unit_center: Optional pre-fetched ``(x, y)`` centroid of the unit.
                  When omitted the centroid is queried from the database.

          Returns:
              The category name, or ``None`` when it cannot be determined.
          """
          if not points:
              return None

          categories = [point['h_xtfx'] for point in points]

          if "严格管控类" in categories:
              usable_categories = self.H_XTFX_MAPPING
              min_samples = 1
              fallback = None
          else:
              most_common, count = Counter(categories).most_common(1)[0]
              if count / len(points) >= 0.8:
                  return most_common
              usable_categories = ("优先保护类", "安全利用类")
              min_samples = 2
              fallback = most_common

          samples = [
              (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
              for point in points
              if point['h_xtfx'] in usable_categories
          ]
          if len(samples) < min_samples:
              return fallback

          # Prefer the centroid handed in by the caller; query only as a fallback.
          center = unit_center if unit_center is not None else self._get_unit_center(unit_id)
          if center is None:
              return fallback

          interpolated = self._idw_interpolation_simple(samples, center)
          if interpolated is None:
              return fallback
          return self._interpolated_value_to_category(interpolated)

      def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
          """Query the centroid of one unit.

          Args:
              unit_id: Unit ``gid``.

          Returns:
              ``(x, y)`` of ``ST_Centroid(geom)``, or ``None`` when the unit
              does not exist or the query fails (the error is logged).
          """
          try:
              center_query = text("""
                  SELECT
                      ST_X(ST_Centroid(geom)) as center_x,
                      ST_Y(ST_Centroid(geom)) as center_y
                  FROM unit_ceil
                  WHERE gid = :unit_id
              """)
              row = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
              if row:
                  return (row.center_x, row.center_y)
              return None
          except Exception as e:
              logger.error(f"获取单元 {unit_id} 中心点失败: {e}")
              return None

      def _idw_interpolation_simple(
          self,
          points: List[Tuple],
          target_point: Tuple[float, float],
          power: float = 2,
      ) -> Optional[float]:
          """Inverse-distance-weighted interpolation at ``target_point``.

          Args:
              points: Samples as ``[(x, y, value), ...]``.
              target_point: ``(x, y)`` of the location to interpolate at.
              power: Exponent applied to the distance when weighting.

          Returns:
              The weighted mean of the sample values, the exact sample value
              when a sample coincides with the target, or ``None`` when there
              are no samples.
          """
          if not points:
              return None

          target_x, target_y = target_point
          total_weight = 0
          weighted_sum = 0
          for point_x, point_y, value in points:
              distance = ((point_x - target_x) ** 2 + (point_y - target_y) ** 2) ** 0.5
              if distance == 0:
                  # A sample exactly at the target has infinite weight.
                  return value
              weight = 1 / (distance ** power)
              total_weight += weight
              weighted_sum += weight * value

          return weighted_sum / total_weight if total_weight != 0 else None

      def _interpolated_value_to_category(self, interpolated_value: float) -> str:
          """Map an interpolated numeric value back to a category name.

          Args:
              interpolated_value: Value on the 1..3 scale of ``H_XTFX_MAPPING``.

          Returns:
              "优先保护类" for values up to 1.5, "安全利用类" for values up to 2.5,
              "严格管控类" otherwise.
          """
          if interpolated_value <= 1.5:
              return "优先保护类"
          if interpolated_value <= 2.5:
              return "安全利用类"
          return "严格管控类"

      def get_unit_count(self) -> int:
          """Return the total number of units."""
          return self.db_session.query(UnitCeil).count()

      def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
          """Count the points of each ``h_xtfx`` category.

          Returns:
              Mapping of category name to point count; empty when the query
              fails (the error is logged).
          """
          try:
              rows = self.db_session.query(
                  FiftyThousandSurveyDatum.h_xtfx,
                  func.count(FiftyThousandSurveyDatum.h_xtfx).label('count'),
              ).filter(
                  FiftyThousandSurveyDatum.h_xtfx.isnot(None)
              ).group_by(
                  FiftyThousandSurveyDatum.h_xtfx
              ).all()
              # Unpack positionally: attribute access by the label ``count``
              # can resolve to the row's tuple-style ``count`` method instead
              # of the column value.
              return {category: total for category, total in rows}
          except Exception as e:
              logger.error(f"获取点位统计失败: {e}")
              return {}

      def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
          """Fetch several units at once.

          Args:
              unit_ids: Unit ``gid`` values.

          Returns:
              The matching unit objects; empty when ``unit_ids`` is empty.
          """
          if not unit_ids:
              # Nothing can match an empty id list; skip the round trip.
              return []
          return self.db_session.query(UnitCeil).filter(
              UnitCeil.gid.in_(unit_ids)
          ).all()

      def get_points_in_area(self, area_name: Optional[str] = None) -> List[FiftyThousandSurveyDatum]:
          """Fetch the points that have an ``h_xtfx`` value, optionally for one area.

          Args:
              area_name: Area (county) name matched against ``xmc``; no area
                  filter is applied when it is empty or ``None``.

          Returns:
              The matching point objects.
          """
          query = self.db_session.query(FiftyThousandSurveyDatum).filter(
              FiftyThousandSurveyDatum.h_xtfx.isnot(None)
          )
          if area_name:
              query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
          return query.all()

      def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
          """Find, with a database-side spatial query, the units containing given points.

          Args:
              point_ids: Point ``id`` values.

          Returns:
              ``(unit_id, point_id)`` pairs; empty when ``point_ids`` is empty
              or the query fails (the error is logged).
          """
          if not point_ids:
              return []
          try:
              spatial_query = text("""
                  SELECT
                      u.gid as unit_id,
                      p.id as point_id
                  FROM unit_ceil u
                  JOIN fifty_thousand_survey_data p ON ST_Contains(
                      u.geom,
                      ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
                  )
                  WHERE p.id = ANY(:point_ids)
              """)
              rows = self.db_session.execute(spatial_query, {"point_ids": point_ids}).fetchall()
              return [(row.unit_id, row.point_id) for row in rows]
          except Exception as e:
              logger.error(f"优化空间查询失败: {e}")
              return []

      def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
          """Count points per area and ``h_xtfx`` category.

          Returns:
              Mapping of area name to ``{category: count}``; empty when the
              query fails (the error is logged).
          """
          try:
              rows = self.db_session.query(
                  FiftyThousandSurveyDatum.xmc,
                  FiftyThousandSurveyDatum.h_xtfx,
                  func.count(FiftyThousandSurveyDatum.id).label('count'),
              ).filter(
                  FiftyThousandSurveyDatum.h_xtfx.isnot(None),
                  FiftyThousandSurveyDatum.xmc.isnot(None)
              ).group_by(
                  FiftyThousandSurveyDatum.xmc,
                  FiftyThousandSurveyDatum.h_xtfx
              ).order_by(
                  FiftyThousandSurveyDatum.xmc
              ).all()

              area_stats: Dict[str, Dict[str, int]] = {}
              # Positional unpacking avoids the ``count`` label clashing with
              # the row's tuple-style ``count`` method.
              for area_name, h_xtfx, total in rows:
                  area_stats.setdefault(area_name, {})[h_xtfx] = total
              return area_stats
          except Exception as e:
              logger.error(f"获取区域统计信息失败: {e}")
              return {}

      def get_unit_h_xtfx_result(self) -> Dict[str, Any]:
          """Build the API payload with per-unit results and summary statistics.

          Returns:
              On success ``{"success": True, "data": [...], "statistics": {...}}``
              where each data item carries ``gid``, ``OBJECTID`` and ``h_xtfx``.
              On failure ``{"success": False, "error": str, "data": None}``.
          """
          try:
              unit_values = self.calculate_unit_h_xtfx_values()

              # The computed mapping covers every unit, so select all id pairs
              # and filter in Python rather than sending a huge IN (...) list.
              if unit_values:
                  id_rows = self.db_session.query(UnitCeil.gid, UnitCeil.OBJECTID).all()
              else:
                  id_rows = []

              result_with_ids = [
                  {"gid": gid, "OBJECTID": object_id, "h_xtfx": unit_values[gid]}
                  for gid, object_id in id_rows
                  if gid in unit_values
              ]

              total_units = self.get_unit_count()
              units_with_data = sum(1 for item in result_with_ids if item['h_xtfx'] is not None)

              category_stats = dict.fromkeys(self.H_XTFX_MAPPING, 0)
              for item in result_with_ids:
                  if item['h_xtfx'] in category_stats:
                      category_stats[item['h_xtfx']] += 1

              point_stats = self.get_point_count_by_h_xtfx()

              return {
                  "success": True,
                  "data": result_with_ids,  # list of dicts with gid and OBJECTID
                  "statistics": {
                      "total_units": total_units,
                      "units_with_data": units_with_data,
                      "units_without_data": total_units - units_with_data,
                      "category_distribution": category_stats,
                      "point_distribution": point_stats
                  }
              }
          except Exception as e:
              logger.error(f"获取单元h_xtfx结果失败: {e}")
              return {
                  "success": False,
                  "error": str(e),
                  "data": None
              }
  357. "error": str(e),
  358. "data": None
  359. }