# unit_grouping_service.py
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func, text
from sqlalchemy.orm import Session

from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil

logger = logging.getLogger(__name__)
  8. class UnitGroupingService:
  9. """
  10. 单元分组服务
  11. 提供基于点位数据的单元h_xtfx值计算功能
  12. """
  13. # 定义 h_xtfx 值的映射关系(数值用于插值计算)
  14. H_XTFX_MAPPING = {
  15. "优先保护类": 1,
  16. "安全利用类": 2,
  17. "严格管控类": 3
  18. }
  19. REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}
  20. def __init__(self, db_session: Session):
  21. """
  22. 初始化服务
  23. Args:
  24. db_session: 数据库会话
  25. """
  26. self.db_session = db_session
  27. def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
  28. """
  29. 核心逻辑:使用数据库端空间查询计算单元的 h_xtfx 值
  30. Returns:
  31. Dict[int, Optional[str]]: 单元ID到h_xtfx值的映射
  32. """
  33. try:
  34. # 直接在数据库中进行空间查询,获取每个单元包含的点位及其h_xtfx值
  35. spatial_query = text("""
  36. SELECT
  37. u.gid as unit_id,
  38. p.h_xtfx,
  39. ST_X(ST_Centroid(u.geom)) as unit_center_x,
  40. ST_Y(ST_Centroid(u.geom)) as unit_center_y,
  41. ST_X(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_x,
  42. ST_Y(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_y
  43. FROM unit_ceil u
  44. JOIN fifty_thousand_survey_data p ON ST_Contains(
  45. u.geom,
  46. ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
  47. )
  48. WHERE p.h_xtfx IS NOT NULL
  49. ORDER BY u.gid
  50. """)
  51. spatial_results = self.db_session.execute(spatial_query).fetchall()
  52. # 组织数据:单元ID -> 包含的点位列表
  53. unit_points = {}
  54. for result in spatial_results:
  55. unit_id = result.unit_id
  56. h_xtfx = result.h_xtfx
  57. point_x = result.point_x
  58. point_y = result.point_y
  59. if unit_id not in unit_points:
  60. unit_points[unit_id] = []
  61. unit_points[unit_id].append({
  62. 'h_xtfx': h_xtfx,
  63. 'x': point_x,
  64. 'y': point_y
  65. })
  66. # 计算每个单元的h_xtfx值
  67. result = {}
  68. # 获取所有单元的ID
  69. all_units = self.db_session.execute(text("SELECT gid FROM unit_ceil ORDER BY gid")).fetchall()
  70. for unit_row in all_units:
  71. unit_id = unit_row.gid
  72. points = unit_points.get(unit_id, [])
  73. if not points:
  74. result[unit_id] = None
  75. continue
  76. try:
  77. result[unit_id] = self._calculate_single_unit_h_xtfx_from_points(unit_id, points)
  78. except Exception as e:
  79. logger.error(f"计算单元 {unit_id} 的h_xtfx值失败: {e}")
  80. result[unit_id] = None
  81. logger.info(f"计算完成,共处理 {len(result)} 个单元,其中 {sum(1 for v in result.values() if v is not None)} 个有有效结果")
  82. return result
  83. except Exception as e:
  84. logger.error(f"计算单元h_xtfx值失败: {e}")
  85. return {}
  86. def _calculate_single_unit_h_xtfx_from_points(self, unit_id: int, points: List[Dict]) -> Optional[str]:
  87. """
  88. 基于点位列表计算单个单元的h_xtfx值
  89. Args:
  90. unit_id: 单元ID
  91. points: 单元内的点位列表,每个点位包含 {'h_xtfx': str, 'x': float, 'y': float}
  92. Returns:
  93. str: h_xtfx值
  94. """
  95. if not points:
  96. return None
  97. h_xtfx_list = [point['h_xtfx'] for point in points]
  98. has_strict_control = any(h_xtfx == "严格管控类" for h_xtfx in h_xtfx_list)
  99. if not has_strict_control:
  100. # 无严格管控类:先判断比例是否 ≥80%
  101. counter = Counter(h_xtfx_list)
  102. most_common, count = counter.most_common(1)[0]
  103. if count / len(points) >= 0.8:
  104. return most_common
  105. # 比例不达标:对优先保护类和安全利用类进行插值
  106. valid_points = [
  107. (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
  108. for point in points
  109. if point['h_xtfx'] in ["优先保护类", "安全利用类"]
  110. ]
  111. if len(valid_points) < 2:
  112. # 有效点位不足,取最常见值
  113. return most_common
  114. else:
  115. # 获取单元中心点坐标
  116. unit_center = self._get_unit_center(unit_id)
  117. if unit_center is None:
  118. return most_common
  119. interpolated = self._idw_interpolation_simple(valid_points, unit_center)
  120. if interpolated is None:
  121. return most_common
  122. return self._interpolated_value_to_category(interpolated)
  123. else:
  124. # 存在严格管控类:对所有点位进行插值
  125. all_points = [
  126. (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
  127. for point in points
  128. if point['h_xtfx'] in self.H_XTFX_MAPPING
  129. ]
  130. if not all_points:
  131. return None
  132. # 获取单元中心点坐标
  133. unit_center = self._get_unit_center(unit_id)
  134. if unit_center is None:
  135. return None
  136. interpolated = self._idw_interpolation_simple(all_points, unit_center)
  137. if interpolated is None:
  138. return None
  139. return self._interpolated_value_to_category(interpolated)
  140. def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
  141. """
  142. 获取单元中心点坐标
  143. Args:
  144. unit_id: 单元ID
  145. Returns:
  146. Tuple[float, float]: (x, y) 坐标
  147. """
  148. try:
  149. center_query = text("""
  150. SELECT
  151. ST_X(ST_Centroid(geom)) as center_x,
  152. ST_Y(ST_Centroid(geom)) as center_y
  153. FROM unit_ceil
  154. WHERE gid = :unit_id
  155. """)
  156. result = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
  157. if result:
  158. return (result.center_x, result.center_y)
  159. return None
  160. except Exception as e:
  161. logger.error(f"获取单元 {unit_id} 中心点失败: {e}")
  162. return None
  163. def _idw_interpolation_simple(self, points: List[Tuple], target_point: Tuple[float, float]) -> Optional[float]:
  164. """
  165. 简单的反距离加权插值函数
  166. Args:
  167. points: 点位坐标和值的列表 [(x, y, value), ...]
  168. target_point: 目标点位坐标 (x, y)
  169. Returns:
  170. float: 插值结果
  171. """
  172. if not points:
  173. return None
  174. total_weight = 0
  175. weighted_sum = 0
  176. power = 2 # 距离权重的幂
  177. target_x, target_y = target_point
  178. for point_x, point_y, value in points:
  179. # 计算欧几里得距离
  180. distance = ((point_x - target_x) ** 2 + (point_y - target_y) ** 2) ** 0.5
  181. if distance == 0:
  182. return value # 距离为0时直接返回该点值
  183. weight = 1 / (distance ** power)
  184. total_weight += weight
  185. weighted_sum += weight * value
  186. return weighted_sum / total_weight if total_weight != 0 else None
  187. def _interpolated_value_to_category(self, interpolated_value: float) -> str:
  188. """
  189. 将插值结果转换为类别
  190. Args:
  191. interpolated_value: 插值结果
  192. Returns:
  193. str: 类别名称
  194. """
  195. if interpolated_value <= 1.5:
  196. return "优先保护类"
  197. elif interpolated_value <= 2.5:
  198. return "安全利用类"
  199. else:
  200. return "严格管控类"
  201. def get_unit_count(self) -> int:
  202. """
  203. 获取单元总数
  204. Returns:
  205. int: 单元总数
  206. """
  207. return self.db_session.query(UnitCeil).count()
  208. def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
  209. """
  210. 获取不同h_xtfx类别的点位数量统计
  211. Returns:
  212. Dict[str, int]: 各类别点位数量
  213. """
  214. try:
  215. result = self.db_session.query(
  216. FiftyThousandSurveyDatum.h_xtfx,
  217. func.count(FiftyThousandSurveyDatum.h_xtfx).label('count')
  218. ).filter(
  219. FiftyThousandSurveyDatum.h_xtfx.isnot(None)
  220. ).group_by(
  221. FiftyThousandSurveyDatum.h_xtfx
  222. ).all()
  223. return {row.h_xtfx: row.count for row in result}
  224. except Exception as e:
  225. logger.error(f"获取点位统计失败: {e}")
  226. return {}
  227. def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
  228. """
  229. 批量获取单元信息
  230. Args:
  231. unit_ids: 单元ID列表
  232. Returns:
  233. List[UnitCeil]: 单元对象列表
  234. """
  235. return self.db_session.query(UnitCeil).filter(
  236. UnitCeil.gid.in_(unit_ids)
  237. ).all()
  238. def get_points_in_area(self, area_name: str = None) -> List[FiftyThousandSurveyDatum]:
  239. """
  240. 获取特定区域的点位数据
  241. Args:
  242. area_name: 区域名称(县名称)
  243. Returns:
  244. List[FiftyThousandSurveyDatum]: 点位数据列表
  245. """
  246. query = self.db_session.query(FiftyThousandSurveyDatum).filter(
  247. FiftyThousandSurveyDatum.h_xtfx.isnot(None)
  248. )
  249. if area_name:
  250. query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
  251. return query.all()
  252. def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
  253. """
  254. 优化的空间查询:找出包含指定点位的单元
  255. Args:
  256. point_ids: 点位ID列表
  257. Returns:
  258. List[Tuple[int, int]]: (unit_id, point_id) 对的列表
  259. """
  260. try:
  261. # 使用数据库端的空间查询
  262. spatial_query = text("""
  263. SELECT
  264. u.gid as unit_id,
  265. p.id as point_id
  266. FROM unit_ceil u
  267. JOIN fifty_thousand_survey_data p ON ST_Contains(
  268. u.geom,
  269. ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
  270. )
  271. WHERE p.id = ANY(:point_ids)
  272. """)
  273. result = self.db_session.execute(spatial_query, {"point_ids": point_ids}).fetchall()
  274. return [(row.unit_id, row.point_id) for row in result]
  275. except Exception as e:
  276. logger.error(f"优化空间查询失败: {e}")
  277. return []
  278. def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
  279. """
  280. 获取各区域的统计信息
  281. Returns:
  282. Dict[str, Dict[str, int]]: 区域名称到统计信息的映射
  283. """
  284. try:
  285. result = self.db_session.query(
  286. FiftyThousandSurveyDatum.xmc,
  287. FiftyThousandSurveyDatum.h_xtfx,
  288. func.count(FiftyThousandSurveyDatum.id).label('count')
  289. ).filter(
  290. FiftyThousandSurveyDatum.h_xtfx.isnot(None),
  291. FiftyThousandSurveyDatum.xmc.isnot(None)
  292. ).group_by(
  293. FiftyThousandSurveyDatum.xmc,
  294. FiftyThousandSurveyDatum.h_xtfx
  295. ).order_by(
  296. FiftyThousandSurveyDatum.xmc
  297. ).all()
  298. area_stats = {}
  299. for row in result:
  300. area_name = row.xmc
  301. h_xtfx = row.h_xtfx
  302. count = row.count
  303. if area_name not in area_stats:
  304. area_stats[area_name] = {}
  305. area_stats[area_name][h_xtfx] = count
  306. return area_stats
  307. except Exception as e:
  308. logger.error(f"获取区域统计信息失败: {e}")
  309. return {}
  310. def get_unit_h_xtfx_result(self) -> Dict[str, any]:
  311. """
  312. 获取单元h_xtfx结果的API方法
  313. Returns:
  314. Dict[str, any]: API响应数据
  315. """
  316. try:
  317. result = self.calculate_unit_h_xtfx_values()
  318. total_units = self.get_unit_count()
  319. units_with_data = sum(1 for v in result.values() if v is not None)
  320. # 按类别统计
  321. category_stats = {}
  322. for category in self.H_XTFX_MAPPING.keys():
  323. category_stats[category] = sum(1 for v in result.values() if v == category)
  324. point_stats = self.get_point_count_by_h_xtfx()
  325. return {
  326. "success": True,
  327. "data": result,
  328. "statistics": {
  329. "total_units": total_units,
  330. "units_with_data": units_with_data,
  331. "units_without_data": total_units - units_with_data,
  332. "category_distribution": category_stats,
  333. "point_distribution": point_stats
  334. }
  335. }
  336. except Exception as e:
  337. logger.error(f"获取单元h_xtfx结果失败: {e}")
  338. return {
  339. "success": False,
  340. "error": str(e),
  341. "data": None
  342. }