# unit_grouping_service.py
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func, text
from sqlalchemy.orm import Session

from app.log.logger import get_logger
from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil

# Module-level logger
logger = get_logger(__name__)
  10. class UnitGroupingService:
  11. """
  12. 单元分组服务
  13. 提供基于点位数据的单元h_xtfx值计算功能
  14. """
  15. # 定义 h_xtfx 值的映射关系(数值用于插值计算)
  16. H_XTFX_MAPPING = {
  17. "优先保护类": 1,
  18. "安全利用类": 2,
  19. "严格管控类": 3
  20. }
  21. REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}
  22. def __init__(self, db_session: Session):
  23. """
  24. 初始化服务
  25. Args:
  26. db_session: 数据库会话
  27. """
  28. self.db_session = db_session
  29. def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
  30. """
  31. 核心逻辑:使用数据库端空间查询计算单元的 h_xtfx 值
  32. Returns:
  33. Dict[int, Optional[str]]: 单元ID到h_xtfx值的映射
  34. """
  35. try:
  36. # 直接在数据库中进行空间查询,获取每个单元包含的点位及其h_xtfx值
  37. spatial_query = text("""
  38. SELECT
  39. u.gid as unit_id,
  40. p.h_xtfx,
  41. ST_X(ST_Centroid(u.geom)) as unit_center_x,
  42. ST_Y(ST_Centroid(u.geom)) as unit_center_y,
  43. ST_X(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_x,
  44. ST_Y(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_y
  45. FROM unit_ceil u
  46. JOIN fifty_thousand_survey_data p ON ST_Contains(
  47. u.geom,
  48. ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
  49. )
  50. WHERE p.h_xtfx IS NOT NULL
  51. ORDER BY u.gid
  52. """)
  53. spatial_results = self.db_session.execute(spatial_query).fetchall()
  54. # 组织数据:单元ID -> 包含的点位列表
  55. unit_points = {}
  56. for result in spatial_results:
  57. unit_id = result.unit_id
  58. h_xtfx = result.h_xtfx
  59. point_x = result.point_x
  60. point_y = result.point_y
  61. if unit_id not in unit_points:
  62. unit_points[unit_id] = []
  63. unit_points[unit_id].append({
  64. 'h_xtfx': h_xtfx,
  65. 'x': point_x,
  66. 'y': point_y
  67. })
  68. # 计算每个单元的h_xtfx值
  69. result = {}
  70. # 获取所有单元的ID
  71. all_units = self.db_session.execute(text("SELECT gid FROM unit_ceil ORDER BY gid")).fetchall()
  72. for unit_row in all_units:
  73. unit_id = unit_row.gid
  74. points = unit_points.get(unit_id, [])
  75. if not points:
  76. result[unit_id] = None
  77. continue
  78. try:
  79. result[unit_id] = self._calculate_single_unit_h_xtfx_from_points(unit_id, points)
  80. except Exception as e:
  81. logger.error(f"计算单元 {unit_id} 的h_xtfx值失败: {e}")
  82. result[unit_id] = None
  83. logger.info(f"计算完成,共处理 {len(result)} 个单元,其中 {sum(1 for v in result.values() if v is not None)} 个有有效结果")
  84. return result
  85. except Exception as e:
  86. logger.error(f"计算单元h_xtfx值失败: {e}")
  87. return {}
  88. def _calculate_single_unit_h_xtfx_from_points(self, unit_id: int, points: List[Dict]) -> Optional[str]:
  89. """
  90. 基于点位列表计算单个单元的h_xtfx值
  91. Args:
  92. unit_id: 单元ID
  93. points: 单元内的点位列表,每个点位包含 {'h_xtfx': str, 'x': float, 'y': float}
  94. Returns:
  95. str: h_xtfx值
  96. """
  97. if not points:
  98. return None
  99. h_xtfx_list = [point['h_xtfx'] for point in points]
  100. has_strict_control = any(h_xtfx == "严格管控类" for h_xtfx in h_xtfx_list)
  101. if not has_strict_control:
  102. # 无严格管控类:先判断比例是否 ≥80%
  103. counter = Counter(h_xtfx_list)
  104. most_common, count = counter.most_common(1)[0]
  105. if count / len(points) >= 0.8:
  106. return most_common
  107. # 比例不达标:对优先保护类和安全利用类进行插值
  108. valid_points = [
  109. (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
  110. for point in points
  111. if point['h_xtfx'] in ["优先保护类", "安全利用类"]
  112. ]
  113. if len(valid_points) < 2:
  114. # 有效点位不足,取最常见值
  115. return most_common
  116. else:
  117. # 获取单元中心点坐标
  118. unit_center = self._get_unit_center(unit_id)
  119. if unit_center is None:
  120. return most_common
  121. interpolated = self._idw_interpolation_simple(valid_points, unit_center)
  122. if interpolated is None:
  123. return most_common
  124. return self._interpolated_value_to_category(interpolated)
  125. else:
  126. # 存在严格管控类:对所有点位进行插值
  127. all_points = [
  128. (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
  129. for point in points
  130. if point['h_xtfx'] in self.H_XTFX_MAPPING
  131. ]
  132. if not all_points:
  133. return None
  134. # 获取单元中心点坐标
  135. unit_center = self._get_unit_center(unit_id)
  136. if unit_center is None:
  137. return None
  138. interpolated = self._idw_interpolation_simple(all_points, unit_center)
  139. if interpolated is None:
  140. return None
  141. return self._interpolated_value_to_category(interpolated)
  142. def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
  143. """
  144. 获取单元中心点坐标
  145. Args:
  146. unit_id: 单元ID
  147. Returns:
  148. Tuple[float, float]: (x, y) 坐标
  149. """
  150. try:
  151. center_query = text("""
  152. SELECT
  153. ST_X(ST_Centroid(geom)) as center_x,
  154. ST_Y(ST_Centroid(geom)) as center_y
  155. FROM unit_ceil
  156. WHERE gid = :unit_id
  157. """)
  158. result = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
  159. if result:
  160. return (result.center_x, result.center_y)
  161. return None
  162. except Exception as e:
  163. logger.error(f"获取单元 {unit_id} 中心点失败: {e}")
  164. return None
  165. def _idw_interpolation_simple(self, points: List[Tuple], target_point: Tuple[float, float]) -> Optional[float]:
  166. """
  167. 简单的反距离加权插值函数
  168. Args:
  169. points: 点位坐标和值的列表 [(x, y, value), ...]
  170. target_point: 目标点位坐标 (x, y)
  171. Returns:
  172. float: 插值结果
  173. """
  174. if not points:
  175. return None
  176. total_weight = 0
  177. weighted_sum = 0
  178. power = 2 # 距离权重的幂
  179. target_x, target_y = target_point
  180. for point_x, point_y, value in points:
  181. # 计算欧几里得距离
  182. distance = ((point_x - target_x) ** 2 + (point_y - target_y) ** 2) ** 0.5
  183. if distance == 0:
  184. return value # 距离为0时直接返回该点值
  185. weight = 1 / (distance ** power)
  186. total_weight += weight
  187. weighted_sum += weight * value
  188. return weighted_sum / total_weight if total_weight != 0 else None
  189. def _interpolated_value_to_category(self, interpolated_value: float) -> str:
  190. """
  191. 将插值结果转换为类别
  192. Args:
  193. interpolated_value: 插值结果
  194. Returns:
  195. str: 类别名称
  196. """
  197. if interpolated_value <= 1.5:
  198. return "优先保护类"
  199. elif interpolated_value <= 2.5:
  200. return "安全利用类"
  201. else:
  202. return "严格管控类"
  203. def get_unit_count(self) -> int:
  204. """
  205. 获取单元总数
  206. Returns:
  207. int: 单元总数
  208. """
  209. return self.db_session.query(UnitCeil).count()
  210. def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
  211. """
  212. 获取不同h_xtfx类别的点位数量统计
  213. Returns:
  214. Dict[str, int]: 各类别点位数量
  215. """
  216. try:
  217. result = self.db_session.query(
  218. FiftyThousandSurveyDatum.h_xtfx,
  219. func.count(FiftyThousandSurveyDatum.h_xtfx).label('count')
  220. ).filter(
  221. FiftyThousandSurveyDatum.h_xtfx.isnot(None)
  222. ).group_by(
  223. FiftyThousandSurveyDatum.h_xtfx
  224. ).all()
  225. return {row.h_xtfx: row.count for row in result}
  226. except Exception as e:
  227. logger.error(f"获取点位统计失败: {e}")
  228. return {}
  229. def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
  230. """
  231. 批量获取单元信息
  232. Args:
  233. unit_ids: 单元ID列表
  234. Returns:
  235. List[UnitCeil]: 单元对象列表
  236. """
  237. return self.db_session.query(UnitCeil).filter(
  238. UnitCeil.gid.in_(unit_ids)
  239. ).all()
  240. def get_points_in_area(self, area_name: str = None) -> List[FiftyThousandSurveyDatum]:
  241. """
  242. 获取特定区域的点位数据
  243. Args:
  244. area_name: 区域名称(县名称)
  245. Returns:
  246. List[FiftyThousandSurveyDatum]: 点位数据列表
  247. """
  248. query = self.db_session.query(FiftyThousandSurveyDatum).filter(
  249. FiftyThousandSurveyDatum.h_xtfx.isnot(None)
  250. )
  251. if area_name:
  252. query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
  253. return query.all()
  254. def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
  255. """
  256. 优化的空间查询:找出包含指定点位的单元
  257. Args:
  258. point_ids: 点位ID列表
  259. Returns:
  260. List[Tuple[int, int]]: (unit_id, point_id) 对的列表
  261. """
  262. try:
  263. # 使用数据库端的空间查询
  264. spatial_query = text("""
  265. SELECT
  266. u.gid as unit_id,
  267. p.id as point_id
  268. FROM unit_ceil u
  269. JOIN fifty_thousand_survey_data p ON ST_Contains(
  270. u.geom,
  271. ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
  272. )
  273. WHERE p.id = ANY(:point_ids)
  274. """)
  275. result = self.db_session.execute(spatial_query, {"point_ids": point_ids}).fetchall()
  276. return [(row.unit_id, row.point_id) for row in result]
  277. except Exception as e:
  278. logger.error(f"优化空间查询失败: {e}")
  279. return []
  280. def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
  281. """
  282. 获取各区域的统计信息
  283. Returns:
  284. Dict[str, Dict[str, int]]: 区域名称到统计信息的映射
  285. """
  286. try:
  287. result = self.db_session.query(
  288. FiftyThousandSurveyDatum.xmc,
  289. FiftyThousandSurveyDatum.h_xtfx,
  290. func.count(FiftyThousandSurveyDatum.id).label('count')
  291. ).filter(
  292. FiftyThousandSurveyDatum.h_xtfx.isnot(None),
  293. FiftyThousandSurveyDatum.xmc.isnot(None)
  294. ).group_by(
  295. FiftyThousandSurveyDatum.xmc,
  296. FiftyThousandSurveyDatum.h_xtfx
  297. ).order_by(
  298. FiftyThousandSurveyDatum.xmc
  299. ).all()
  300. area_stats = {}
  301. for row in result:
  302. area_name = row.xmc
  303. h_xtfx = row.h_xtfx
  304. count = row.count
  305. if area_name not in area_stats:
  306. area_stats[area_name] = {}
  307. area_stats[area_name][h_xtfx] = count
  308. return area_stats
  309. except Exception as e:
  310. logger.error(f"获取区域统计信息失败: {e}")
  311. return {}
  312. def get_unit_h_xtfx_result(self) -> Dict[str, any]:
  313. """
  314. 获取单元h_xtfx结果的API方法,同时返回gid和OBJECTID
  315. Returns:
  316. Dict[str, any]: API响应数据,包含单元的gid和OBJECTID
  317. """
  318. try:
  319. # 计算单元结果
  320. unit_values = self.calculate_unit_h_xtfx_values()
  321. # 获取所有单元的gid和OBJECTID
  322. unit_ids = list(unit_values.keys())
  323. units = self.db_session.query(UnitCeil.gid, UnitCeil.OBJECTID).filter(
  324. UnitCeil.gid.in_(unit_ids)
  325. ).all()
  326. # 构建包含gid和OBJECTID的结果列表
  327. result_with_ids = []
  328. for unit in units:
  329. result_with_ids.append({
  330. "gid": unit.gid,
  331. "OBJECTID": unit.OBJECTID,
  332. "h_xtfx": unit_values.get(unit.gid)
  333. })
  334. # 统计信息
  335. total_units = self.get_unit_count()
  336. units_with_data = sum(1 for item in result_with_ids if item['h_xtfx'] is not None)
  337. category_stats = {key: 0 for key in self.H_XTFX_MAPPING.keys()}
  338. for item in result_with_ids:
  339. if item['h_xtfx'] in category_stats:
  340. category_stats[item['h_xtfx']] += 1
  341. point_stats = self.get_point_count_by_h_xtfx()
  342. return {
  343. "success": True,
  344. "data": result_with_ids, # 包含gid和OBJECTID的数组
  345. "statistics": {
  346. "total_units": total_units,
  347. "units_with_data": units_with_data,
  348. "units_without_data": total_units - units_with_data,
  349. "category_distribution": category_stats,
  350. "point_distribution": point_stats
  351. }
  352. }
  353. except Exception as e:
  354. logger.error(f"获取单元h_xtfx结果失败: {e}")
  355. return {
  356. "success": False,
  357. "error": str(e),
  358. "data": None
  359. }