|
@@ -0,0 +1,423 @@
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+from sqlalchemy import func, text
|
|
|
+from collections import Counter
|
|
|
+from typing import Dict, List, Tuple, Optional
|
|
|
+import logging
|
|
|
+from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+class UnitGroupingService:
|
|
|
+ """
|
|
|
+ 单元分组服务
|
|
|
+
|
|
|
+ 提供基于点位数据的单元h_xtfx值计算功能
|
|
|
+ """
|
|
|
+
|
|
|
+ # 定义 h_xtfx 值的映射关系(数值用于插值计算)
|
|
|
+ H_XTFX_MAPPING = {
|
|
|
+ "优先保护类": 1,
|
|
|
+ "安全利用类": 2,
|
|
|
+ "严格管控类": 3
|
|
|
+ }
|
|
|
+
|
|
|
+ REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}
|
|
|
+
|
|
|
+ def __init__(self, db_session: Session):
|
|
|
+ """
|
|
|
+ 初始化服务
|
|
|
+
|
|
|
+ Args:
|
|
|
+ db_session: 数据库会话
|
|
|
+ """
|
|
|
+ self.db_session = db_session
|
|
|
+
|
|
|
+ def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
|
|
|
+ """
|
|
|
+ 核心逻辑:使用数据库端空间查询计算单元的 h_xtfx 值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Dict[int, Optional[str]]: 单元ID到h_xtfx值的映射
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 直接在数据库中进行空间查询,获取每个单元包含的点位及其h_xtfx值
|
|
|
+ spatial_query = text("""
|
|
|
+ SELECT
|
|
|
+ u.gid as unit_id,
|
|
|
+ p.h_xtfx,
|
|
|
+ ST_X(ST_Centroid(u.geom)) as unit_center_x,
|
|
|
+ ST_Y(ST_Centroid(u.geom)) as unit_center_y,
|
|
|
+ ST_X(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_x,
|
|
|
+ ST_Y(ST_Transform(ST_SetSRID(p.geom, 4490), 4490)) as point_y
|
|
|
+ FROM unit_ceil u
|
|
|
+ JOIN fifty_thousand_survey_data p ON ST_Contains(
|
|
|
+ u.geom,
|
|
|
+ ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
|
|
|
+ )
|
|
|
+ WHERE p.h_xtfx IS NOT NULL
|
|
|
+ ORDER BY u.gid
|
|
|
+ """)
|
|
|
+
|
|
|
+ spatial_results = self.db_session.execute(spatial_query).fetchall()
|
|
|
+
|
|
|
+ # 组织数据:单元ID -> 包含的点位列表
|
|
|
+ unit_points = {}
|
|
|
+ for result in spatial_results:
|
|
|
+ unit_id = result.unit_id
|
|
|
+ h_xtfx = result.h_xtfx
|
|
|
+ point_x = result.point_x
|
|
|
+ point_y = result.point_y
|
|
|
+
|
|
|
+ if unit_id not in unit_points:
|
|
|
+ unit_points[unit_id] = []
|
|
|
+
|
|
|
+ unit_points[unit_id].append({
|
|
|
+ 'h_xtfx': h_xtfx,
|
|
|
+ 'x': point_x,
|
|
|
+ 'y': point_y
|
|
|
+ })
|
|
|
+
|
|
|
+ # 计算每个单元的h_xtfx值
|
|
|
+ result = {}
|
|
|
+
|
|
|
+ # 获取所有单元的ID
|
|
|
+ all_units = self.db_session.execute(text("SELECT gid FROM unit_ceil ORDER BY gid")).fetchall()
|
|
|
+
|
|
|
+ for unit_row in all_units:
|
|
|
+ unit_id = unit_row.gid
|
|
|
+ points = unit_points.get(unit_id, [])
|
|
|
+
|
|
|
+ if not points:
|
|
|
+ result[unit_id] = None
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ result[unit_id] = self._calculate_single_unit_h_xtfx_from_points(unit_id, points)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"计算单元 {unit_id} 的h_xtfx值失败: {e}")
|
|
|
+ result[unit_id] = None
|
|
|
+
|
|
|
+ logger.info(f"计算完成,共处理 {len(result)} 个单元,其中 {sum(1 for v in result.values() if v is not None)} 个有有效结果")
|
|
|
+ return result
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"计算单元h_xtfx值失败: {e}")
|
|
|
+ return {}
|
|
|
+
|
|
|
+ def _calculate_single_unit_h_xtfx_from_points(self, unit_id: int, points: List[Dict]) -> Optional[str]:
|
|
|
+ """
|
|
|
+ 基于点位列表计算单个单元的h_xtfx值
|
|
|
+
|
|
|
+ Args:
|
|
|
+ unit_id: 单元ID
|
|
|
+ points: 单元内的点位列表,每个点位包含 {'h_xtfx': str, 'x': float, 'y': float}
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: h_xtfx值
|
|
|
+ """
|
|
|
+ if not points:
|
|
|
+ return None
|
|
|
+
|
|
|
+ h_xtfx_list = [point['h_xtfx'] for point in points]
|
|
|
+ has_strict_control = any(h_xtfx == "严格管控类" for h_xtfx in h_xtfx_list)
|
|
|
+
|
|
|
+ if not has_strict_control:
|
|
|
+ # 无严格管控类:先判断比例是否 ≥80%
|
|
|
+ counter = Counter(h_xtfx_list)
|
|
|
+ most_common, count = counter.most_common(1)[0]
|
|
|
+ if count / len(points) >= 0.8:
|
|
|
+ return most_common
|
|
|
+
|
|
|
+ # 比例不达标:对优先保护类和安全利用类进行插值
|
|
|
+ valid_points = [
|
|
|
+ (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
|
|
|
+ for point in points
|
|
|
+ if point['h_xtfx'] in ["优先保护类", "安全利用类"]
|
|
|
+ ]
|
|
|
+
|
|
|
+ if len(valid_points) < 2:
|
|
|
+ # 有效点位不足,取最常见值
|
|
|
+ return most_common
|
|
|
+ else:
|
|
|
+ # 获取单元中心点坐标
|
|
|
+ unit_center = self._get_unit_center(unit_id)
|
|
|
+ if unit_center is None:
|
|
|
+ return most_common
|
|
|
+
|
|
|
+ interpolated = self._idw_interpolation_simple(valid_points, unit_center)
|
|
|
+ if interpolated is None:
|
|
|
+ return most_common
|
|
|
+ return self._interpolated_value_to_category(interpolated)
|
|
|
+ else:
|
|
|
+ # 存在严格管控类:对所有点位进行插值
|
|
|
+ all_points = [
|
|
|
+ (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
|
|
|
+ for point in points
|
|
|
+ if point['h_xtfx'] in self.H_XTFX_MAPPING
|
|
|
+ ]
|
|
|
+
|
|
|
+ if not all_points:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 获取单元中心点坐标
|
|
|
+ unit_center = self._get_unit_center(unit_id)
|
|
|
+ if unit_center is None:
|
|
|
+ return None
|
|
|
+
|
|
|
+ interpolated = self._idw_interpolation_simple(all_points, unit_center)
|
|
|
+ if interpolated is None:
|
|
|
+ return None
|
|
|
+ return self._interpolated_value_to_category(interpolated)
|
|
|
+
|
|
|
+ def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
|
|
|
+ """
|
|
|
+ 获取单元中心点坐标
|
|
|
+
|
|
|
+ Args:
|
|
|
+ unit_id: 单元ID
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple[float, float]: (x, y) 坐标
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ center_query = text("""
|
|
|
+ SELECT
|
|
|
+ ST_X(ST_Centroid(geom)) as center_x,
|
|
|
+ ST_Y(ST_Centroid(geom)) as center_y
|
|
|
+ FROM unit_ceil
|
|
|
+ WHERE gid = :unit_id
|
|
|
+ """)
|
|
|
+
|
|
|
+ result = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
|
|
|
+ if result:
|
|
|
+ return (result.center_x, result.center_y)
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取单元 {unit_id} 中心点失败: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ def _idw_interpolation_simple(self, points: List[Tuple], target_point: Tuple[float, float]) -> Optional[float]:
|
|
|
+ """
|
|
|
+ 简单的反距离加权插值函数
|
|
|
+
|
|
|
+ Args:
|
|
|
+ points: 点位坐标和值的列表 [(x, y, value), ...]
|
|
|
+ target_point: 目标点位坐标 (x, y)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ float: 插值结果
|
|
|
+ """
|
|
|
+ if not points:
|
|
|
+ return None
|
|
|
+
|
|
|
+ total_weight = 0
|
|
|
+ weighted_sum = 0
|
|
|
+ power = 2 # 距离权重的幂
|
|
|
+
|
|
|
+ target_x, target_y = target_point
|
|
|
+
|
|
|
+ for point_x, point_y, value in points:
|
|
|
+ # 计算欧几里得距离
|
|
|
+ distance = ((point_x - target_x) ** 2 + (point_y - target_y) ** 2) ** 0.5
|
|
|
+
|
|
|
+ if distance == 0:
|
|
|
+ return value # 距离为0时直接返回该点值
|
|
|
+
|
|
|
+ weight = 1 / (distance ** power)
|
|
|
+ total_weight += weight
|
|
|
+ weighted_sum += weight * value
|
|
|
+
|
|
|
+ return weighted_sum / total_weight if total_weight != 0 else None
|
|
|
+
|
|
|
+ def _interpolated_value_to_category(self, interpolated_value: float) -> str:
|
|
|
+ """
|
|
|
+ 将插值结果转换为类别
|
|
|
+
|
|
|
+ Args:
|
|
|
+ interpolated_value: 插值结果
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: 类别名称
|
|
|
+ """
|
|
|
+ if interpolated_value <= 1.5:
|
|
|
+ return "优先保护类"
|
|
|
+ elif interpolated_value <= 2.5:
|
|
|
+ return "安全利用类"
|
|
|
+ else:
|
|
|
+ return "严格管控类"
|
|
|
+
|
|
|
+ def get_unit_count(self) -> int:
|
|
|
+ """
|
|
|
+ 获取单元总数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ int: 单元总数
|
|
|
+ """
|
|
|
+ return self.db_session.query(UnitCeil).count()
|
|
|
+
|
|
|
+ def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
|
|
|
+ """
|
|
|
+ 获取不同h_xtfx类别的点位数量统计
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Dict[str, int]: 各类别点位数量
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ result = self.db_session.query(
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx,
|
|
|
+ func.count(FiftyThousandSurveyDatum.h_xtfx).label('count')
|
|
|
+ ).filter(
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx.isnot(None)
|
|
|
+ ).group_by(
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx
|
|
|
+ ).all()
|
|
|
+
|
|
|
+ return {row.h_xtfx: row.count for row in result}
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取点位统计失败: {e}")
|
|
|
+ return {}
|
|
|
+
|
|
|
+ def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
|
|
|
+ """
|
|
|
+ 批量获取单元信息
|
|
|
+
|
|
|
+ Args:
|
|
|
+ unit_ids: 单元ID列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[UnitCeil]: 单元对象列表
|
|
|
+ """
|
|
|
+ return self.db_session.query(UnitCeil).filter(
|
|
|
+ UnitCeil.gid.in_(unit_ids)
|
|
|
+ ).all()
|
|
|
+
|
|
|
+ def get_points_in_area(self, area_name: str = None) -> List[FiftyThousandSurveyDatum]:
|
|
|
+ """
|
|
|
+ 获取特定区域的点位数据
|
|
|
+
|
|
|
+ Args:
|
|
|
+ area_name: 区域名称(县名称)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[FiftyThousandSurveyDatum]: 点位数据列表
|
|
|
+ """
|
|
|
+ query = self.db_session.query(FiftyThousandSurveyDatum).filter(
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx.isnot(None)
|
|
|
+ )
|
|
|
+
|
|
|
+ if area_name:
|
|
|
+ query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
|
|
|
+
|
|
|
+ return query.all()
|
|
|
+
|
|
|
+ def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
|
|
|
+ """
|
|
|
+ 优化的空间查询:找出包含指定点位的单元
|
|
|
+
|
|
|
+ Args:
|
|
|
+ point_ids: 点位ID列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[Tuple[int, int]]: (unit_id, point_id) 对的列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 使用数据库端的空间查询
|
|
|
+ spatial_query = text("""
|
|
|
+ SELECT
|
|
|
+ u.gid as unit_id,
|
|
|
+ p.id as point_id
|
|
|
+ FROM unit_ceil u
|
|
|
+ JOIN fifty_thousand_survey_data p ON ST_Contains(
|
|
|
+ u.geom,
|
|
|
+ ST_Transform(ST_SetSRID(p.geom, 4490), 4490)
|
|
|
+ )
|
|
|
+ WHERE p.id = ANY(:point_ids)
|
|
|
+ """)
|
|
|
+
|
|
|
+ result = self.db_session.execute(spatial_query, {"point_ids": point_ids}).fetchall()
|
|
|
+ return [(row.unit_id, row.point_id) for row in result]
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"优化空间查询失败: {e}")
|
|
|
+ return []
|
|
|
+
|
|
|
+ def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
|
|
|
+ """
|
|
|
+ 获取各区域的统计信息
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Dict[str, Dict[str, int]]: 区域名称到统计信息的映射
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ result = self.db_session.query(
|
|
|
+ FiftyThousandSurveyDatum.xmc,
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx,
|
|
|
+ func.count(FiftyThousandSurveyDatum.id).label('count')
|
|
|
+ ).filter(
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx.isnot(None),
|
|
|
+ FiftyThousandSurveyDatum.xmc.isnot(None)
|
|
|
+ ).group_by(
|
|
|
+ FiftyThousandSurveyDatum.xmc,
|
|
|
+ FiftyThousandSurveyDatum.h_xtfx
|
|
|
+ ).order_by(
|
|
|
+ FiftyThousandSurveyDatum.xmc
|
|
|
+ ).all()
|
|
|
+
|
|
|
+ area_stats = {}
|
|
|
+ for row in result:
|
|
|
+ area_name = row.xmc
|
|
|
+ h_xtfx = row.h_xtfx
|
|
|
+ count = row.count
|
|
|
+
|
|
|
+ if area_name not in area_stats:
|
|
|
+ area_stats[area_name] = {}
|
|
|
+
|
|
|
+ area_stats[area_name][h_xtfx] = count
|
|
|
+
|
|
|
+ return area_stats
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取区域统计信息失败: {e}")
|
|
|
+ return {}
|
|
|
+
|
|
|
+ def get_unit_h_xtfx_result(self) -> Dict[str, any]:
|
|
|
+ """
|
|
|
+ 获取单元h_xtfx结果的API方法
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Dict[str, any]: API响应数据
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ result = self.calculate_unit_h_xtfx_values()
|
|
|
+
|
|
|
+ total_units = self.get_unit_count()
|
|
|
+ units_with_data = sum(1 for v in result.values() if v is not None)
|
|
|
+
|
|
|
+ # 按类别统计
|
|
|
+ category_stats = {}
|
|
|
+ for category in self.H_XTFX_MAPPING.keys():
|
|
|
+ category_stats[category] = sum(1 for v in result.values() if v == category)
|
|
|
+
|
|
|
+ point_stats = self.get_point_count_by_h_xtfx()
|
|
|
+
|
|
|
+ return {
|
|
|
+ "success": True,
|
|
|
+ "data": result,
|
|
|
+ "statistics": {
|
|
|
+ "total_units": total_units,
|
|
|
+ "units_with_data": units_with_data,
|
|
|
+ "units_without_data": total_units - units_with_data,
|
|
|
+ "category_distribution": category_stats,
|
|
|
+ "point_distribution": point_stats
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取单元h_xtfx结果失败: {e}")
|
|
|
+ return {
|
|
|
+ "success": False,
|
|
|
+ "error": str(e),
|
|
|
+ "data": None
|
|
|
+ }
|