123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423 |
import logging
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func, text
from sqlalchemy.orm import Session

from ..models.orm_models import FiftyThousandSurveyDatum, UnitCeil
# Module-level logger; every method of UnitGroupingService below logs through it.
logger: logging.Logger = logging.getLogger(name=__name__)
class UnitGroupingService:
    """
    Unit grouping service.

    Derives an ``h_xtfx`` category for every unit in ``unit_ceil`` from the
    survey points in ``fifty_thousand_survey_data`` that fall inside it, and
    exposes a few related statistics queries.
    """

    # Category -> numeric code. The codes are used as values for
    # inverse-distance-weighted interpolation, so their order matters.
    H_XTFX_MAPPING = {
        "优先保护类": 1,
        "安全利用类": 2,
        "严格管控类": 3,
    }

    # Numeric code -> category (inverse of H_XTFX_MAPPING).
    REVERSE_H_XTFX_MAPPING = {v: k for k, v in H_XTFX_MAPPING.items()}

    # When a unit has no "严格管控类" point, a single category covering at least
    # this share of the unit's points is taken as-is, without interpolation.
    DOMINANT_RATIO = 0.8

    def __init__(self, db_session: Session):
        """
        Initialise the service.

        Args:
            db_session: SQLAlchemy session used for every query issued here.
        """
        self.db_session = db_session

    def calculate_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
        """
        Compute the h_xtfx category of every unit via a database-side spatial join.

        Query failures are logged and reported as an empty dict (this method's
        historical contract). Callers that must tell "no data" apart from
        "query failed" should use ``get_unit_h_xtfx_result``.

        Returns:
            Dict[int, Optional[str]]: unit gid -> category, or ``None`` for a unit
            that contains no point or whose own calculation failed. An empty
            dict if the queries themselves failed.
        """
        try:
            return self._compute_unit_h_xtfx_values()
        except Exception as e:
            logger.exception("计算单元h_xtfx值失败: %s", e)
            return {}

    def _compute_unit_h_xtfx_values(self) -> Dict[int, Optional[str]]:
        """Same as calculate_unit_h_xtfx_values, but lets query errors propagate."""
        unit_points, unit_centers = self._fetch_points_by_unit()

        all_units = self.db_session.execute(
            text("SELECT gid FROM unit_ceil ORDER BY gid")
        ).fetchall()

        result: Dict[int, Optional[str]] = {}
        for unit_row in all_units:
            unit_id = unit_row.gid
            points = unit_points.get(unit_id)
            if not points:
                result[unit_id] = None
                continue

            try:
                result[unit_id] = self._calculate_single_unit_h_xtfx_from_points(
                    unit_id, points, unit_center=unit_centers.get(unit_id)
                )
            except Exception as e:
                # One bad unit must not abort the whole run.
                logger.error("计算单元 %s 的h_xtfx值失败: %s", unit_id, e)
                result[unit_id] = None

        logger.info(
            "计算完成,共处理 %d 个单元,其中 %d 个有有效结果",
            len(result),
            sum(1 for v in result.values() if v is not None),
        )
        return result

    def _fetch_points_by_unit(
        self,
    ) -> Tuple[Dict[int, List[Dict]], Dict[int, Tuple[Optional[float], Optional[float]]]]:
        """
        Run the unit/point spatial join and group the matched points per unit.

        Returns:
            (unit_points, unit_centers):
              - unit gid -> list of ``{'h_xtfx', 'x', 'y'}`` dicts, ordered by point id;
              - unit gid -> (x, y) of the unit centroid as returned by PostGIS
                (either coordinate may be None, e.g. for an empty geometry).
        """
        # ST_SetSRID only relabels the point SRID so ST_Contains compares like
        # with like; a further ST_Transform to the same SRID would be a no-op,
        # and ST_X/ST_Y are independent of the SRID label.
        # Ordering by p.id makes Counter tie-breaking deterministic downstream.
        spatial_query = text("""
            SELECT
                u.gid AS unit_id,
                p.h_xtfx,
                ST_X(ST_Centroid(u.geom)) AS unit_center_x,
                ST_Y(ST_Centroid(u.geom)) AS unit_center_y,
                ST_X(p.geom) AS point_x,
                ST_Y(p.geom) AS point_y
            FROM unit_ceil u
            JOIN fifty_thousand_survey_data p
              ON ST_Contains(u.geom, ST_SetSRID(p.geom, 4490))
            WHERE p.h_xtfx IS NOT NULL
            ORDER BY u.gid, p.id
        """)

        unit_points: Dict[int, List[Dict]] = defaultdict(list)
        unit_centers: Dict[int, Tuple[Optional[float], Optional[float]]] = {}
        for row in self.db_session.execute(spatial_query).fetchall():
            unit_points[row.unit_id].append({
                'h_xtfx': row.h_xtfx,
                'x': row.point_x,
                'y': row.point_y,
            })
            # Centroid columns are identical on every row of a unit; keep the first.
            unit_centers.setdefault(row.unit_id, (row.unit_center_x, row.unit_center_y))

        return dict(unit_points), unit_centers

    def _calculate_single_unit_h_xtfx_from_points(
        self,
        unit_id: int,
        points: List[Dict],
        unit_center: Optional[Tuple[Optional[float], Optional[float]]] = None,
    ) -> Optional[str]:
        """
        Derive one unit's h_xtfx category from the points it contains.

        Rules:
          * No "严格管控类" point: if one category covers >= DOMINANT_RATIO of the
            points, return it. Otherwise IDW-interpolate the "优先保护类" and
            "安全利用类" points at the unit centroid; fall back to the most common
            category if fewer than two such points exist or the centroid is
            unavailable.
          * Otherwise: IDW-interpolate all mappable points at the centroid;
            return None if that is impossible.

        Args:
            unit_id: unit gid (used only to query the centroid when none is given).
            points: the unit's points, each ``{'h_xtfx': str, 'x': float, 'y': float}``.
            unit_center: centroid (x, y) if already known; when None it is
                fetched from the database.

        Returns:
            Optional[str]: the category, or None when no value can be derived.
        """
        if not points:
            return None

        categories = [point['h_xtfx'] for point in points]

        if "严格管控类" not in categories:
            most_common, count = Counter(categories).most_common(1)[0]
            if count / len(points) >= self.DOMINANT_RATIO:
                return most_common

            # Ratio not reached: interpolate between the two lower categories.
            samples = [
                (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
                for point in points
                if point['h_xtfx'] in ("优先保护类", "安全利用类")
            ]
            if len(samples) < 2:
                return most_common

            interpolated = self._interpolate_at_unit_center(unit_id, samples, unit_center)
            if interpolated is None:
                return most_common
            return self._interpolated_value_to_category(interpolated)

        # A strict-control point is present: interpolate over every known category.
        samples = [
            (point['x'], point['y'], self.H_XTFX_MAPPING[point['h_xtfx']])
            for point in points
            if point['h_xtfx'] in self.H_XTFX_MAPPING
        ]
        if not samples:
            return None

        interpolated = self._interpolate_at_unit_center(unit_id, samples, unit_center)
        if interpolated is None:
            return None
        return self._interpolated_value_to_category(interpolated)

    def _interpolate_at_unit_center(
        self,
        unit_id: int,
        samples: List[Tuple],
        unit_center: Optional[Tuple[Optional[float], Optional[float]]] = None,
    ) -> Optional[float]:
        """IDW-interpolate ``samples`` at the unit centroid; None if the centroid is unknown."""
        center = unit_center if unit_center is not None else self._get_unit_center(unit_id)
        # A NULL centroid coordinate (e.g. empty geometry) cannot be used as a target.
        if center is None or None in center:
            return None
        return self._idw_interpolation_simple(samples, center)

    def _get_unit_center(self, unit_id: int) -> Optional[Tuple[float, float]]:
        """
        Query the centroid of one unit.

        Args:
            unit_id: unit gid.

        Returns:
            Optional[Tuple[float, float]]: (x, y), or None if the unit does not
            exist or the query failed (the failure is logged).
        """
        try:
            center_query = text("""
                SELECT
                    ST_X(ST_Centroid(geom)) AS center_x,
                    ST_Y(ST_Centroid(geom)) AS center_y
                FROM unit_ceil
                WHERE gid = :unit_id
            """)
            row = self.db_session.execute(center_query, {"unit_id": unit_id}).fetchone()
        except Exception as e:
            logger.error("获取单元 %s 中心点失败: %s", unit_id, e)
            return None

        if row is None:
            return None
        return (row.center_x, row.center_y)

    def _idw_interpolation_simple(
        self,
        points: List[Tuple],
        target_point: Tuple[float, float],
        power: float = 2,
    ) -> Optional[float]:
        """
        Inverse-distance-weighted interpolation.

        Args:
            points: samples as ``[(x, y, value), ...]``.
            target_point: (x, y) at which to interpolate.
            power: exponent applied to the distance in the weight ``1 / d**power``.

        Returns:
            Optional[float]: the interpolated value; the sample's own value if the
            target coincides with a sample; None for an empty sample list.
        """
        if not points:
            return None

        target_x, target_y = target_point
        total_weight = 0.0
        weighted_sum = 0.0

        for point_x, point_y, value in points:
            # Euclidean distance in the coordinate units of the geometries.
            distance = ((point_x - target_x) ** 2 + (point_y - target_y) ** 2) ** 0.5
            if distance == 0:
                # Exact hit: the weight would be infinite, so use the sample directly.
                return value

            weight = 1 / (distance ** power)
            total_weight += weight
            weighted_sum += weight * value

        return weighted_sum / total_weight if total_weight != 0 else None

    def _interpolated_value_to_category(self, interpolated_value: float) -> str:
        """
        Map an interpolated numeric code back to a category.

        Boundaries are inclusive on the lower category: <=1.5 -> "优先保护类",
        <=2.5 -> "安全利用类", otherwise "严格管控类".
        """
        if interpolated_value <= 1.5:
            return "优先保护类"
        if interpolated_value <= 2.5:
            return "安全利用类"
        return "严格管控类"

    def get_unit_count(self) -> int:
        """Return the total number of units."""
        return self.db_session.query(UnitCeil).count()

    def get_point_count_by_h_xtfx(self) -> Dict[str, int]:
        """
        Count survey points per h_xtfx category (NULL categories excluded).

        Returns:
            Dict[str, int]: category -> point count; empty dict if the query failed.
        """
        try:
            rows = self.db_session.query(
                FiftyThousandSurveyDatum.h_xtfx,
                func.count(FiftyThousandSurveyDatum.h_xtfx).label('count')
            ).filter(
                FiftyThousandSurveyDatum.h_xtfx.isnot(None)
            ).group_by(
                FiftyThousandSurveyDatum.h_xtfx
            ).all()
        except Exception as e:
            logger.exception("获取点位统计失败: %s", e)
            return {}

        return {row.h_xtfx: row.count for row in rows}

    def get_units_by_ids(self, unit_ids: List[int]) -> List[UnitCeil]:
        """
        Fetch the units with the given gids.

        Args:
            unit_ids: unit gids.

        Returns:
            List[UnitCeil]: matching units (no particular order); empty for empty input.
        """
        if not unit_ids:
            # Avoid a pointless round trip (and SQLAlchemy's empty-IN handling).
            return []
        return self.db_session.query(UnitCeil).filter(
            UnitCeil.gid.in_(unit_ids)
        ).all()

    def get_points_in_area(self, area_name: Optional[str] = None) -> List[FiftyThousandSurveyDatum]:
        """
        Fetch points that have an h_xtfx value, optionally restricted to one county.

        Args:
            area_name: county name matched against ``xmc``; falsy means no filter.

        Returns:
            List[FiftyThousandSurveyDatum]: matching points.
        """
        query = self.db_session.query(FiftyThousandSurveyDatum).filter(
            FiftyThousandSurveyDatum.h_xtfx.isnot(None)
        )
        if area_name:
            query = query.filter(FiftyThousandSurveyDatum.xmc == area_name)
        return query.all()

    def get_units_containing_points_optimized(self, point_ids: List[int]) -> List[Tuple[int, int]]:
        """
        Find which unit contains each of the given points (database-side spatial join).

        Args:
            point_ids: point ids.

        Returns:
            List[Tuple[int, int]]: (unit_id, point_id) pairs; empty for empty input
            or if the query failed (the failure is logged).
        """
        if not point_ids:
            return []

        # See _fetch_points_by_unit: transforming to the SRID just set is a no-op.
        spatial_query = text("""
            SELECT
                u.gid AS unit_id,
                p.id AS point_id
            FROM unit_ceil u
            JOIN fifty_thousand_survey_data p
              ON ST_Contains(u.geom, ST_SetSRID(p.geom, 4490))
            WHERE p.id = ANY(:point_ids)
        """)

        try:
            # A list (not a tuple) is what the DB driver adapts to an SQL ARRAY for ANY().
            rows = self.db_session.execute(
                spatial_query, {"point_ids": list(point_ids)}
            ).fetchall()
        except Exception as e:
            logger.exception("优化空间查询失败: %s", e)
            return []

        return [(row.unit_id, row.point_id) for row in rows]

    def get_area_statistics(self) -> Dict[str, Dict[str, int]]:
        """
        Count points per (county, h_xtfx category).

        Returns:
            Dict[str, Dict[str, int]]: county -> {category: count}; empty dict if
            the query failed.
        """
        try:
            rows = self.db_session.query(
                FiftyThousandSurveyDatum.xmc,
                FiftyThousandSurveyDatum.h_xtfx,
                func.count(FiftyThousandSurveyDatum.id).label('count')
            ).filter(
                FiftyThousandSurveyDatum.h_xtfx.isnot(None),
                FiftyThousandSurveyDatum.xmc.isnot(None)
            ).group_by(
                FiftyThousandSurveyDatum.xmc,
                FiftyThousandSurveyDatum.h_xtfx
            ).order_by(
                FiftyThousandSurveyDatum.xmc
            ).all()
        except Exception as e:
            logger.exception("获取区域统计信息失败: %s", e)
            return {}

        area_stats: Dict[str, Dict[str, int]] = {}
        for row in rows:
            area_stats.setdefault(row.xmc, {})[row.h_xtfx] = row.count
        return area_stats

    def get_unit_h_xtfx_result(self) -> Dict[str, Any]:
        """
        API entry point: per-unit categories plus summary statistics.

        Unlike ``calculate_unit_h_xtfx_values``, a failed computation is not
        disguised as "no data": it yields ``success: False``.

        Returns:
            Dict[str, Any]: ``{"success": True, "data": {gid: category | None},
            "statistics": {...}}`` on success, otherwise
            ``{"success": False, "error": str, "data": None}``.
        """
        try:
            result = self._compute_unit_h_xtfx_values()

            total_units = self.get_unit_count()
            units_with_data = sum(1 for v in result.values() if v is not None)

            value_counts = Counter(result.values())
            category_stats = {
                category: value_counts[category] for category in self.H_XTFX_MAPPING
            }

            point_stats = self.get_point_count_by_h_xtfx()

            return {
                "success": True,
                "data": result,
                "statistics": {
                    "total_units": total_units,
                    "units_with_data": units_with_data,
                    "units_without_data": total_units - units_with_data,
                    "category_distribution": category_stats,
                    "point_distribution": point_stats
                }
            }

        except Exception as e:
            logger.exception("获取单元h_xtfx结果失败: %s", e)
            return {
                "success": False,
                "error": str(e),
                "data": None
            }
|