unit_grouping_service.py 14 KB


import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func, text
from sqlalchemy.orm import Session

from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil
  7. logger = logging.getLogger(__name__)
  8. class UnitGroupingService:
  9. """
  10. 单元分组服务
  11. 提供基于点位数据的单元h_xtfx值计算功能
  12. """
  13. # 定义 h_xtfx 值的映射关系(数值用于插值计算)
  14. H_XTFX_MAPPING = {
  15. "优先保护类": 1,
  16. "安全利用类": 2,
  17. "严格管控类": 3
  18. }
  19. REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}
  20. def __init__(self, db_session: Session):
  21. """
  22. 初始化服务
  23. Args:
  24. db_session: 数据库会话
  25. """
  26. self.db_session = db_session
  27. def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
  28. """
  29. 核心逻辑:使用数据库端空间查询计算单元的 h_xtfx 值
  30. Returns:
  31. Dict[int, Optional[str]]: 单元ID到h_xtfx值的映射
  32. """
  33. try:
  34. # 直接在数据库中进行空间查询,获取每个单元包含的点位及其h_xtfx值
  35. spatial_query = text("""
  36. SELECT
  37. u.gid as unit_id,
  38. p.h_xtfx,
  39. ST_X(ST_Centroid(u.geom)) as unit_center_x,
  40. ST_Y(ST_Centroid(u.geom)) as unit_center_y,
  41. ST_X(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_x,
  42. ST_Y(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_y
  43. FROM unit_ceil u
  44. JOIN fifty_thousand_survey_data p ON ST_Contains(
  45. u.geom,
  46. ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
  47. )
  48. WHERE p.h_xtfx IS NOT NULL
  49. ORDER BY u.gid
  50. """)
  51. spatial_results = self.db_session.execute(spatial_query).fetchall()
  52. # 组织数据:单元ID -> 包含的点位列表
  53. unit_points = {}
  54. for result in spatial_results:
  55. unit_id = result.unit_id
  56. h_xtfx = result.h_xtfx
  57. point_x = result.point_x
  58. point_y = result.point_y
  59. if unit_id not in unit_points:
  60. unit_points[unit_id] = []
  61. unit_points[unit_id].append({
  62. 'h_xtfx': h_xtfx,
  63. 'x': point_x,
  64. 'y': point_y
  65. })
  66. # 计算每个单元的h_xtfx值
  67. result = {}
  68. # 获取所有单元的ID
  69. all_units = self.db_session.execute(text("SELECT gid FROM unit_ceil ORDER BY gid")).fetchall()
  70. for unit_row in all_units:
  71. unit_id = unit_row.gid
  72. points = unit_points.get(unit_id, [])
  73. if not points:
  74. result[unit_id] = None
  75. continue
  76. try:
  77. result[unit_id] = self._calculate_single_unit_h_xtfx_from_points(unit_id, points)
  78. except Exception as e:
  79. logger.error(f"计算单元 {unit_id} 的h_xtfx值失败: {e}")
  80. result[unit_id] = None
  81. logger.info(f"计算完成,共处理 {len(result)} 个单元,其中 {sum(1 for v in result.values() if v is not None)} 个有有效结果")
  82. return result
  83. except Exception as e:
  84. logger.error(f"计算单元h_xtfx值失败: {e}")
  85. return {}
  86. def _calculate_single_unit_h_xtfx_from_points(self, unit_id: int, points: List[Dict]) -> Optional[str]:
  87. """
  88. 基于点位列表计算单个单元的h_xtfx值
  89. Args:
  90. unit_id: 单元ID
  91. points: 单元内的点位列表,每个点位包含 {'h_xtfx': str, 'x': float, 'y': float}
  92. Returns:
  93. str: h_xtfx值
  94. """
  95. if not points:
  96. return None
  97. h_xtfx_list = [point['h_xtfx'] for point in points]
  98. has_strict_control = any(h_xtfx == "严格管控类" for h_xtfx in h_xtfx_list)
  99. if not has_strict_control:
  100. # 无严格管控类:先判断比例是否 ≥80%
  101. counter = Counter(h_xtfx_list)
  102. most_common, count = counter.most_common(1)[0]
  103. if count / len(points) >= 0.8:
  104. return most_common
  105. # 比例不达标:对优先保护类和安全利用类进行插值
  106. valid_points = [
  107. (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
  108. for point in points
  109. if point['h_xtfx'] in ["优先保护类", "安全利用类"]
  110. ]
  111. if len(valid_points) < 2:
  112. # 有效点位不足,取最常见值
  113. return most_common
  114. else:
  115. # 获取单元中心点坐标
  116. unit_center = self._get_unit_center(unit_id)
  117. if unit_center is None:
  118. return most_common
  119. interpolated = self._idw_interpolation_simple(valid_points, unit_center)
  120. if interpolated is None:
  121. return most_common
  122. return self._interpolated_value_to_category(interpolated)
  123. else:
  124. # 存在严格管控类:对所有点位进行插值
  125. all_points = [
  126. (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
  127. for point in points
  128. if point['h_xtfx'] in self.H_XTFX_MAPPING
  129. ]
  130. if not all_points:
  131. return None
  132. # 获取单元中心点坐标
  133. unit_center = self._get_unit_center(unit_id)
  134. if unit_center is None:
  135. return None
  136. interpolated = self._idw_interpolation_simple(all_points, unit_center)
  137. if interpolated is None:
  138. return None
  139. return self._interpolated_value_to_category(interpolated)
  140. def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
  141. """
  142. 获取单元中心点坐标
  143. Args:
  144. unit_id: 单元ID
  145. Returns:
  146. Tuple[float, float]: (x, y) 坐标
  147. """
  148. try:
  149. center_query = text("""
  150. SELECT
  151. ST_X(ST_Centroid(geom)) as center_x,
  152. ST_Y(ST_Centroid(geom)) as center_y
  153. FROM unit_ceil
  154. WHERE gid = :unit_id
  155. """)
  156. result = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
  157. if result:
  158. return (result.center_x, result.center_y)
  159. return None
  160. except Exception as e:
  161. logger.error(f"获取单元 {unit_id} 中心点失败: {e}")
  162. return None
  163. def _idw_interpolation_simple(self, points: List[Tuple], target_point: Tuple[float, float]) -> Optional[float]:
  164. """
  165. 简单的反距离加权插值函数
  166. Args:
  167. points: 点位坐标和值的列表 [(x, y, value), ...]
  168. target_point: 目标点位坐标 (x, y)
  169. Returns:
  170. float: 插值结果
  171. """
  172. if not points:
  173. return None
  174. total_weight = 0
  175. weighted_sum = 0
  176. power = 2 # 距离权重的幂
  177. target_x, target_y = target_point
  178. for point_x, point_y, value in points:
  179. # 计算欧几里得距离
  180. distance = ((point_x - target_x) ** 2 + (point_y - target_y) ** 2) ** 0.5
  181. if distance == 0:
  182. return value # 距离为0时直接返回该点值
  183. weight = 1 / (distance ** power)
  184. total_weight += weight
  185. weighted_sum += weight * value
  186. return weighted_sum / total_weight if total_weight != 0 else None
  187. def _interpolated_value_to_category(self, interpolated_value: float) -> str:
  188. """
  189. 将插值结果转换为类别
  190. Args:
  191. interpolated_value: 插值结果
  192. Returns:
  193. str: 类别名称
  194. """
  195. if interpolated_value <= 1.5:
  196. return "优先保护类"
  197. elif interpolated_value <= 2.5:
  198. return "安全利用类"
  199. else:
  200. return "严格管控类"
  201. def get_unit_count(self) -> int:
  202. """
  203. 获取单元总数
  204. Returns:
  205. int: 单元总数
  206. """
  207. return self.db_session.query(UnitCeil).count()
  208. def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
  209. """
  210. 获取不同h_xtfx类别的点位数量统计
  211. Returns:
  212. Dict[str, int]: 各类别点位数量
  213. """
  214. try:
  215. result = self.db_session.query(
  216. FiftyThousandSurveyDatum.h_xtfx,
  217. func.count(FiftyThousandSurveyDatum.h_xtfx).label('count')
  218. ).filter(
  219. FiftyThousandSurveyDatum.h_xtfx.isnot(None)
  220. ).group_by(
  221. FiftyThousandSurveyDatum.h_xtfx
  222. ).all()
  223. return {row.h_xtfx: row.count for row in result}
  224. except Exception as e:
  225. logger.error(f"获取点位统计失败: {e}")
  226. return {}
  227. def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
  228. """
  229. 批量获取单元信息
  230. Args:
  231. unit_ids: 单元ID列表
  232. Returns:
  233. List[UnitCeil]: 单元对象列表
  234. """
  235. return self.db_session.query(UnitCeil).filter(
  236. UnitCeil.gid.in_(unit_ids)
  237. ).all()
  238. def get_points_in_area(self, area_name: str = None) -> List[FiftyThousandSurveyDatum]:
  239. """
  240. 获取特定区域的点位数据
  241. Args:
  242. area_name: 区域名称(县名称)
  243. Returns:
  244. List[FiftyThousandSurveyDatum]: 点位数据列表
  245. """
  246. query = self.db_session.query(FiftyThousandSurveyDatum).filter(
  247. FiftyThousandSurveyDatum.h_xtfx.isnot(None)
  248. )
  249. if area_name:
  250. query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
  251. return query.all()
  252. def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
  253. """
  254. 优化的空间查询:找出包含指定点位的单元
  255. Args:
  256. point_ids: 点位ID列表
  257. Returns:
  258. List[Tuple[int, int]]: (unit_id, point_id) 对的列表
  259. """
  260. try:
  261. # 使用数据库端的空间查询
  262. spatial_query = text("""
  263. SELECT
  264. u.gid as unit_id,
  265. p.id as point_id
  266. FROM unit_ceil u
  267. JOIN fifty_thousand_survey_data p ON ST_Contains(
  268. u.geom,
  269. ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
  270. )
  271. WHERE p.id = ANY(:point_ids)
  272. """)
  273. result = self.db_session.execute(spatial_query, {"point_ids": point_ids}).fetchall()
  274. return [(row.unit_id, row.point_id) for row in result]
  275. except Exception as e:
  276. logger.error(f"优化空间查询失败: {e}")
  277. return []
  278. def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
  279. """
  280. 获取各区域的统计信息
  281. Returns:
  282. Dict[str, Dict[str, int]]: 区域名称到统计信息的映射
  283. """
  284. try:
  285. result = self.db_session.query(
  286. FiftyThousandSurveyDatum.xmc,
  287. FiftyThousandSurveyDatum.h_xtfx,
  288. func.count(FiftyThousandSurveyDatum.id).label('count')
  289. ).filter(
  290. FiftyThousandSurveyDatum.h_xtfx.isnot(None),
  291. FiftyThousandSurveyDatum.xmc.isnot(None)
  292. ).group_by(
  293. FiftyThousandSurveyDatum.xmc,
  294. FiftyThousandSurveyDatum.h_xtfx
  295. ).order_by(
  296. FiftyThousandSurveyDatum.xmc
  297. ).all()
  298. area_stats = {}
  299. for row in result:
  300. area_name = row.xmc
  301. h_xtfx = row.h_xtfx
  302. count = row.count
  303. if area_name not in area_stats:
  304. area_stats[area_name] = {}
  305. area_stats[area_name][h_xtfx] = count
  306. return area_stats
  307. except Exception as e:
  308. logger.error(f"获取区域统计信息失败: {e}")
  309. return {}
  310. def get_unit_h_xtfx_result(self) -> Dict[str, any]:
  311. """
  312. 获取单元h_xtfx结果的API方法
  313. Returns:
  314. Dict[str, any]: API响应数据
  315. """
  316. try:
  317. result = self.calculate_unit_h_xtfx_values()
  318. total_units = self.get_unit_count()
  319. units_with_data = sum(1 for v in result.values() if v is not None)
  320. # 按类别统计
  321. category_stats = {}
  322. for category in self.H_XTFX_MAPPING.keys():
  323. category_stats[category] = sum(1 for v in result.values() if v == category)
  324. point_stats = self.get_point_count_by_h_xtfx()
  325. return {
  326. "success": True,
  327. "data": result,
  328. "statistics": {
  329. "total_units": total_units,
  330. "units_with_data": units_with_data,
  331. "units_without_data": total_units - units_with_data,
  332. "category_distribution": category_stats,
  333. "point_distribution": point_stats
  334. }
  335. }
  336. except Exception as e:
  337. logger.error(f"获取单元h_xtfx结果失败: {e}")
  338. return {
  339. "success": False,
  340. "error": str(e),
  341. "data": None
  342. }