from sqlalchemy.orm import Session
from sqlalchemy import func, text
from collections import Counter
from typing import Any, Dict, List, Tuple, Optional
import logging
import math

from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil

logger = logging.getLogger(__name__)


class UnitGroupingService:
    """
    Unit grouping service.

    Derives an ``h_xtfx`` category for every ``unit_ceil`` polygon from the
    survey points (``fifty_thousand_survey_data``) spatially contained in it.
    """

    # Ordinal encoding of the h_xtfx categories (priority protection /
    # safe utilization / strict control). The numeric values are what the
    # inverse-distance-weighted interpolation averages.
    H_XTFX_MAPPING = {
        "优先保护类": 1,
        "安全利用类": 2,
        "严格管控类": 3
    }
    REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}

    def __init__(self, db_session: Session):
        """
        Initialize the service.

        Args:
            db_session: SQLAlchemy session used for every query issued here.
        """
        self.db_session = db_session

    def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
        """
        Compute the h_xtfx category of every unit via a database-side spatial join.

        Returns:
            Dict[int, Optional[str]]: unit gid -> h_xtfx category. Every gid in
            ``unit_ceil`` is present. Units containing no points, or whose
            per-unit computation raised, map to ``None``. If the queries
            themselves fail, an empty dict is returned and the error is logged.
        """
        try:
            unit_points, unit_centers = self._fetch_unit_points()

            result: Dict[int, Optional[str]] = {}
            all_units = self.db_session.execute(
                text("SELECT gid FROM unit_ceil ORDER BY gid")
            ).fetchall()

            for unit_row in all_units:
                unit_id = unit_row.gid
                points = unit_points.get(unit_id)
                if not points:
                    result[unit_id] = None
                    continue

                try:
                    # Pass the centroid already fetched by the join so the
                    # per-unit computation does not issue an extra query.
                    result[unit_id] = self._calculate_single_unit_h_xtfx_from_points(
                        unit_id, points, unit_center=unit_centers.get(unit_id)
                    )
                except Exception as e:
                    # Best effort: one bad unit must not abort the whole run.
                    logger.error(f"计算单元 {unit_id} 的h_xtfx值失败: {e}")
                    result[unit_id] = None

            logger.info(f"计算完成,共处理 {len(result)} 个单元,其中 {sum(1 for v in result.values() if v is not None)} 个有有效结果")
            return result

        except Exception as e:
            # NOTE(review): on PostgreSQL a failed statement leaves the
            # transaction aborted; callers may need to roll back the session.
            logger.exception(f"计算单元h_xtfx值失败: {e}")
            return {}

    def _fetch_unit_points(
        self,
    ) -> Tuple[Dict[int, List[Dict]], Dict[int, Tuple[float, float]]]:
        """
        Run the unit/point spatial join once.

        Returns:
            A pair ``(unit_points, unit_centers)``:
            - unit_points: unit gid -> list of ``{'h_xtfx', 'x', 'y'}`` dicts
              for every point with non-NULL h_xtfx contained in the unit.
            - unit_centers: unit gid -> ``(x, y)`` of ``ST_Centroid(u.geom)``.
        """
        # ST_SetSRID forces the point geometries to SRID 4490; the following
        # ST_Transform to that same SRID is effectively a no-op.
        spatial_query = text("""
            SELECT
                u.gid as unit_id,
                p.h_xtfx,
                ST_X(ST_Centroid(u.geom)) as unit_center_x,
                ST_Y(ST_Centroid(u.geom)) as unit_center_y,
                ST_X(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_x,
                ST_Y(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_y
            FROM unit_ceil u
            JOIN fifty_thousand_survey_data p
                ON ST_Contains(
                    u.geom,
                    ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
                )
            WHERE p.h_xtfx IS NOT NULL
            ORDER BY u.gid
        """)

        unit_points: Dict[int, List[Dict]] = {}
        unit_centers: Dict[int, Tuple[float, float]] = {}
        for row in self.db_session.execute(spatial_query).fetchall():
            unit_points.setdefault(row.unit_id, []).append({
                'h_xtfx': row.h_xtfx,
                'x': row.point_x,
                'y': row.point_y
            })
            # Every row of a unit carries the same centroid; keep the first.
            unit_centers.setdefault(row.unit_id, (row.unit_center_x, row.unit_center_y))
        return unit_points, unit_centers

    def _calculate_single_unit_h_xtfx_from_points(
        self,
        unit_id: int,
        points: List[Dict],
        unit_center: Optional[Tuple[float, float]] = None,
    ) -> Optional[str]:
        """
        Compute the h_xtfx category of one unit from the points it contains.

        Rules:
        - No "严格管控类" point: if the most common category covers >= 80% of
          the points, return it. Otherwise IDW-interpolate the "优先保护类" /
          "安全利用类" points at the unit centroid. Fall back to the most
          common category when fewer than 2 such points exist or the
          interpolation yields nothing.
        - At least one "严格管控类" point: IDW-interpolate all mapped points
          at the unit centroid; return None if that is impossible.

        Args:
            unit_id: Unit gid (used to look up the centroid when not given).
            points: Points in the unit, each ``{'h_xtfx': str, 'x': float, 'y': float}``.
            unit_center: Optional precomputed ``(x, y)`` centroid; when None it
                is fetched with ``_get_unit_center``.

        Returns:
            Optional[str]: the category, or None.
        """
        if not points:
            return None

        h_xtfx_list = [point['h_xtfx'] for point in points]

        if "严格管控类" not in h_xtfx_list:
            most_common, count = Counter(h_xtfx_list).most_common(1)[0]
            if count / len(points) >= 0.8:
                return most_common

            # Dominance threshold not met: interpolate the two milder classes.
            valid_points = [
                (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
                for point in points
                if point['h_xtfx'] in ["优先保护类", "安全利用类"]
            ]
            if len(valid_points) < 2:
                return most_common

            interpolated = self._interpolate_at_unit_center(unit_id, valid_points, unit_center)
            if interpolated is None:
                return most_common
            return self._interpolated_value_to_category(interpolated)

        # A strict-control point is present: interpolate over every mapped point.
        all_points = [
            (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
            for point in points
            if point['h_xtfx'] in self.H_XTFX_MAPPING
        ]
        if not all_points:
            return None

        interpolated = self._interpolate_at_unit_center(unit_id, all_points, unit_center)
        if interpolated is None:
            return None
        return self._interpolated_value_to_category(interpolated)

    def _interpolate_at_unit_center(
        self,
        unit_id: int,
        points: List[Tuple],
        unit_center: Optional[Tuple[float, float]],
    ) -> Optional[float]:
        """IDW-interpolate ``points`` at the unit centroid; None if no centroid or no result."""
        center = unit_center if unit_center is not None else self._get_unit_center(unit_id)
        if center is None:
            return None
        return self._idw_interpolation_simple(points, center)

    def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
        """
        Fetch the centroid of one unit from the database.

        Args:
            unit_id: Unit gid.

        Returns:
            Optional[Tuple[float, float]]: ``(x, y)`` of ``ST_Centroid(geom)``,
            or None when the unit is missing or the query fails.
        """
        try:
            center_query = text("""
                SELECT
                    ST_X(ST_Centroid(geom)) as center_x,
                    ST_Y(ST_Centroid(geom)) as center_y
                FROM unit_ceil
                WHERE gid = :unit_id
            """)
            result = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
            if result:
                return (result.center_x, result.center_y)
            return None
        except Exception as e:
            logger.error(f"获取单元 {unit_id} 中心点失败: {e}")
            return None

    def _idw_interpolation_simple(
        self,
        points: List[Tuple],
        target_point: Tuple[float, float],
        power: float = 2,
    ) -> Optional[float]:
        """
        Inverse-distance-weighted interpolation.

        Args:
            points: ``[(x, y, value), ...]`` sample points.
            target_point: ``(x, y)`` location to estimate.
            power: Distance exponent of the weights (default 2).

        Returns:
            Optional[float]: the estimate, the exact sample value if a point
            coincides with the target, or None when ``points`` is empty.
        """
        if not points:
            return None

        target_x, target_y = target_point
        total_weight = 0.0
        weighted_sum = 0.0

        for point_x, point_y, value in points:
            distance = math.hypot(point_x - target_x, point_y - target_y)
            if distance == 0:
                # Target sits exactly on a sample point: use its value as-is.
                return value
            weight = 1 / (distance ** power)
            total_weight += weight
            weighted_sum += weight * value

        return weighted_sum / total_weight if total_weight != 0 else None

    def _interpolated_value_to_category(self, interpolated_value: float) -> str:
        """
        Map an interpolated ordinal value back to its category name.

        Args:
            interpolated_value: Value on the 1..3 scale of ``H_XTFX_MAPPING``.

        Returns:
            str: ``<= 1.5`` -> "优先保护类", ``<= 2.5`` -> "安全利用类",
            otherwise "严格管控类".
        """
        if interpolated_value <= 1.5:
            return "优先保护类"
        elif interpolated_value <= 2.5:
            return "安全利用类"
        else:
            return "严格管控类"

    def get_unit_count(self) -> int:
        """
        Count the units.

        Returns:
            int: number of ``UnitCeil`` rows.
        """
        return self.db_session.query(UnitCeil).count()

    def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
        """
        Count points per h_xtfx category (NULL categories excluded).

        Returns:
            Dict[str, int]: category -> point count; empty dict on failure.
        """
        try:
            result = self.db_session.query(
                FiftyThousandSurveyDatum.h_xtfx,
                func.count(FiftyThousandSurveyDatum.h_xtfx).label('count')
            ).filter(
                FiftyThousandSurveyDatum.h_xtfx.isnot(None)
            ).group_by(
                FiftyThousandSurveyDatum.h_xtfx
            ).all()

            return {row.h_xtfx: row.count for row in result}
        except Exception as e:
            logger.exception(f"获取点位统计失败: {e}")
            return {}

    def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
        """
        Load several units by gid.

        Args:
            unit_ids: Unit gids to load.

        Returns:
            List[UnitCeil]: matching units (empty list for empty input).
        """
        if not unit_ids:
            return []
        return self.db_session.query(UnitCeil).filter(
            UnitCeil.gid.in_(unit_ids)
        ).all()

    def get_points_in_area(self, area_name: Optional[str] = None) -> List[FiftyThousandSurveyDatum]:
        """
        Load points with a non-NULL h_xtfx, optionally restricted to one area.

        Args:
            area_name: Area (county) name matched against ``xmc``; no
                restriction when None or empty.

        Returns:
            List[FiftyThousandSurveyDatum]: matching points.
        """
        query = self.db_session.query(FiftyThousandSurveyDatum).filter(
            FiftyThousandSurveyDatum.h_xtfx.isnot(None)
        )
        if area_name:
            query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
        return query.all()

    def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
        """
        Find which units contain the given points (database-side spatial join).

        Args:
            point_ids: Point ids to locate.

        Returns:
            List[Tuple[int, int]]: ``(unit_id, point_id)`` pairs; empty list
            for empty input or on failure.
        """
        if not point_ids:
            return []
        try:
            spatial_query = text("""
                SELECT u.gid as unit_id, p.id as point_id
                FROM unit_ceil u
                JOIN fifty_thousand_survey_data p
                    ON ST_Contains(
                        u.geom,
                        ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
                    )
                WHERE p.id = ANY(:point_ids)
            """)
            result = self.db_session.execute(spatial_query, {"point_ids": point_ids}).fetchall()
            return [(row.unit_id, row.point_id) for row in result]
        except Exception as e:
            logger.exception(f"优化空间查询失败: {e}")
            return []

    def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
        """
        Count points per (area, h_xtfx category), skipping NULL area/category.

        Returns:
            Dict[str, Dict[str, int]]: area name -> {category -> count};
            empty dict on failure.
        """
        try:
            result = self.db_session.query(
                FiftyThousandSurveyDatum.xmc,
                FiftyThousandSurveyDatum.h_xtfx,
                func.count(FiftyThousandSurveyDatum.id).label('count')
            ).filter(
                FiftyThousandSurveyDatum.h_xtfx.isnot(None),
                FiftyThousandSurveyDatum.xmc.isnot(None)
            ).group_by(
                FiftyThousandSurveyDatum.xmc,
                FiftyThousandSurveyDatum.h_xtfx
            ).order_by(
                FiftyThousandSurveyDatum.xmc
            ).all()

            area_stats: Dict[str, Dict[str, int]] = {}
            for row in result:
                area_stats.setdefault(row.xmc, {})[row.h_xtfx] = row.count
            return area_stats
        except Exception as e:
            logger.exception(f"获取区域统计信息失败: {e}")
            return {}

    def get_unit_h_xtfx_result(self) -> Dict[str, Any]:
        """
        API entry point: unit h_xtfx results with both gid and OBJECTID.

        Returns:
            Dict[str, Any]: on success ``{"success": True, "data": [...],
            "statistics": {...}}`` where each data item is
            ``{"gid", "OBJECTID", "h_xtfx"}`` ordered by gid; on failure
            ``{"success": False, "error": str, "data": None}``.
        """
        try:
            unit_values = self.calculate_unit_h_xtfx_values()

            unit_ids = list(unit_values.keys())
            units = []
            if unit_ids:
                units = self.db_session.query(UnitCeil.gid, UnitCeil.OBJECTID).filter(
                    UnitCeil.gid.in_(unit_ids)
                ).order_by(UnitCeil.gid).all()

            result_with_ids = [
                {
                    "gid": unit.gid,
                    "OBJECTID": unit.OBJECTID,
                    "h_xtfx": unit_values.get(unit.gid)
                }
                for unit in units
            ]

            total_units = self.get_unit_count()
            units_with_data = sum(1 for item in result_with_ids if item['h_xtfx'] is not None)

            category_stats = {key: 0 for key in self.H_XTFX_MAPPING}
            for item in result_with_ids:
                if item['h_xtfx'] in category_stats:
                    category_stats[item['h_xtfx']] += 1

            point_stats = self.get_point_count_by_h_xtfx()

            return {
                "success": True,
                "data": result_with_ids,
                "statistics": {
                    "total_units": total_units,
                    "units_with_data": units_with_data,
                    "units_without_data": total_units - units_with_data,
                    "category_distribution": category_stats,
                    "point_distribution": point_stats
                }
            }
        except Exception as e:
            logger.exception(f"获取单元h_xtfx结果失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "data": None
            }