Browse Source

添加行政层级参数到籽粒和秸秆移除Cd通量可视化接口,增强服务逻辑以支持严格匹配层级,更新边界文件获取逻辑以确保精确匹配。

drggboy 1 week ago
parent
commit
f7b930054d
2 changed files with 99 additions and 57 deletions
  1. 16 4
      app/api/cd_flux_removal.py
  2. 83 53
      app/services/cd_flux_removal_service.py

+ 16 - 4
app/api/cd_flux_removal.py

@@ -136,6 +136,7 @@ async def calculate_straw_removal(
            description="计算籽粒移除Cd通量并生成栅格地图并返回图片文件")
 async def visualize_grain_removal(
     area: str = Query(..., description="地区名称,如:韶关"),
+    level: str = Query(..., description="行政层级,必须为: county | city | province"),
     colormap: str = Query("green_yellow_red_purple", description="色彩方案"),
     resolution_factor: float = Query(4.0, description="分辨率因子(默认4.0,更快)"),
     enable_interpolation: bool = Query(False, description="是否启用空间插值(默认关闭以提升性能)"),
@@ -154,9 +155,13 @@ async def visualize_grain_removal(
     """
     try:
         service = CdFluxRemovalService()
+
+        # 行政层级校验(不允许模糊)
+        if level not in ("county", "city", "province"):
+            raise HTTPException(status_code=400, detail="参数 level 必须为 'county' | 'city' | 'province'")
         
-        # 计算籽粒移除Cd通量
-        calc_result = service.calculate_grain_removal_by_area(area)
+        # 计算籽粒移除Cd通量(传入严格层级以便参数表查找做后缀标准化精确匹配)
+        calc_result = service.calculate_grain_removal_by_area(area, level=level)
         
         if not calc_result["success"]:
             raise HTTPException(
@@ -176,6 +181,7 @@ async def visualize_grain_removal(
         # 创建可视化
         visualization_files = service.create_flux_visualization(
             area=area,
+            level=level,
             calculation_type="grain_removal",
             results_with_coords=results_with_coords,
             colormap=colormap,
@@ -210,6 +216,7 @@ async def visualize_grain_removal(
            description="计算秸秆移除Cd通量并生成栅格地图并返回图片文件")
 async def visualize_straw_removal(
     area: str = Query(..., description="地区名称,如:韶关"),
+    level: str = Query(..., description="行政层级,必须为: county | city | province"),
     colormap: str = Query("green_yellow_red_purple", description="色彩方案"),
     resolution_factor: float = Query(4.0, description="分辨率因子(默认4.0,更快)"),
     enable_interpolation: bool = Query(False, description="是否启用空间插值(默认关闭以提升性能)"),
@@ -228,9 +235,13 @@ async def visualize_straw_removal(
     """
     try:
         service = CdFluxRemovalService()
+
+        # 行政层级校验(不允许模糊)
+        if level not in ("county", "city", "province"):
+            raise HTTPException(status_code=400, detail="参数 level 必须为 'county' | 'city' | 'province'")
         
-        # 计算秸秆移除Cd通量
-        calc_result = service.calculate_straw_removal_by_area(area)
+        # 计算秸秆移除Cd通量(传入严格层级以便参数表查找做后缀标准化精确匹配)
+        calc_result = service.calculate_straw_removal_by_area(area, level=level)
         
         if not calc_result["success"]:
             raise HTTPException(
@@ -250,6 +261,7 @@ async def visualize_straw_removal(
         # 创建可视化
         visualization_files = service.create_flux_visualization(
             area=area,
+            level=level,
             calculation_type="straw_removal",
             results_with_coords=results_with_coords,
             colormap=colormap,

+ 83 - 53
app/services/cd_flux_removal_service.py

@@ -19,6 +19,8 @@ from ..models.CropCd_output import CropCdOutputData
 from ..models.farmland import FarmlandData
 from ..utils.mapping_utils import MappingUtils
 from .admin_boundary_service import get_boundary_geojson_by_name
+import geopandas as gpd
+from shapely.geometry import shape, Point
 import tempfile
 import json
 
@@ -35,8 +37,9 @@ class CdFluxRemovalService:
         初始化Cd通量移除服务
         """
         self.logger = logging.getLogger(__name__)
+        # 严格匹配策略:不做名称变体或后缀映射
         
-    def calculate_grain_removal_by_area(self, area: str) -> Dict[str, Any]:
+    def calculate_grain_removal_by_area(self, area: str, level: Optional[str] = None) -> Dict[str, Any]:
         """
         根据地区计算籽粒移除Cd通量
         
@@ -47,7 +50,7 @@ class CdFluxRemovalService:
         """
         try:
             with SessionLocal() as db:
-                # 查询指定地区的参数
+                # 查询指定地区的参数(严格等号匹配,不做任何映射)
                 parameter = db.query(Parameters).filter(Parameters.area == area).first()
                 
                 if not parameter:
@@ -112,7 +115,7 @@ class CdFluxRemovalService:
                 "data": None
             }
     
-    def calculate_straw_removal_by_area(self, area: str) -> Dict[str, Any]:
+    def calculate_straw_removal_by_area(self, area: str, level: Optional[str] = None) -> Dict[str, Any]:
         """
         根据地区计算秸秆移除Cd通量
         
@@ -123,7 +126,7 @@ class CdFluxRemovalService:
         """
         try:
             with SessionLocal() as db:
-                # 查询指定地区的参数
+                # 查询指定地区的参数(严格等号匹配,不做任何映射)
                 parameter = db.query(Parameters).filter(Parameters.area == area).first()
                 
                 if not parameter:
@@ -305,6 +308,7 @@ class CdFluxRemovalService:
     
     def create_flux_visualization(self, area: str, calculation_type: str,
                                 results_with_coords: List[Dict[str, Any]],
+                                level: str = None,
                                 output_dir: str = "app/static/cd_flux",
                                 template_raster: str = "app/static/cd_flux/meanTemp.tif",
                                 boundary_shp: str = None,
@@ -332,6 +336,7 @@ class CdFluxRemovalService:
             
             # 确保输出目录存在
             os.makedirs(output_dir, exist_ok=True)
+            generated_files: Dict[str, str] = {}
             
             # 生成时间戳
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -365,17 +370,75 @@ class CdFluxRemovalService:
                 self.logger.warning(f"模板栅格文件不存在: {template_raster}")
                 template_raster = None
             
-            # 动态获取边界文件
-            boundary_shp = self._get_boundary_file_for_area(area)
+            # 动态获取边界文件(严格使用指定层级)
+            if not level:
+                raise ValueError("必须提供行政层级 level:county | city | province")
+            boundary_shp = self._get_boundary_file_for_area(area, level)
             if not boundary_shp:
                 self.logger.warning(f"未找到地区 '{area}' 的边界文件,将不使用边界裁剪")
+            else:
+                # 在绘图前进行样点边界包含性统计
+                try:
+                    boundary_gdf = gpd.read_file(boundary_shp)
+                    if boundary_gdf is not None and len(boundary_gdf) > 0:
+                        boundary_union = boundary_gdf.unary_union
+                        total_points = len(results_with_coords)
+                        inside_count = 0
+                        outside_points: List[Dict[str, Any]] = []
+                        for r in results_with_coords:
+                            pt = Point(float(r["longitude"]), float(r["latitude"]))
+                            if boundary_union.contains(pt) or boundary_union.touches(pt):
+                                inside_count += 1
+                            else:
+                                outside_points.append({
+                                    "farmland_id": r.get("farmland_id"),
+                                    "sample_id": r.get("sample_id"),
+                                    "longitude": r.get("longitude"),
+                                    "latitude": r.get("latitude"),
+                                    "flux_value": r.get("flux_value")
+                                })
+
+                        outside_count = total_points - inside_count
+                        inside_pct = (inside_count / total_points * 100.0) if total_points > 0 else 0.0
+
+                        self.logger.info(
+                            f"样点边界检查 - 总数: {total_points}, 边界内: {inside_count} ({inside_pct:.2f}%), 边界外: {outside_count}")
+                        if outside_count > 0:
+                            sample_preview = outside_points[:10]
+                            self.logger.warning(
+                                f"存在 {outside_count} 个样点位于边界之外,绘图时将被掩膜隐藏。示例(最多10条): {sample_preview}")
+
+                        report = {
+                            "area": area,
+                            "level": level,
+                            "calculation_type": calculation_type,
+                            "total_points": total_points,
+                            "inside_points": inside_count,
+                            "outside_points": outside_count,
+                            "inside_percentage": round(inside_pct, 2),
+                            "outside_samples": outside_points
+                        }
+                        os.makedirs(output_dir, exist_ok=True)
+                        report_path = os.path.join(
+                            output_dir,
+                            f"{calculation_type}_{area}_points_boundary_check_{timestamp}.json"
+                        )
+                        with open(report_path, "w", encoding="utf-8") as f:
+                            json.dump(report, f, ensure_ascii=False, indent=2)
+                        generated_files["point_boundary_report"] = report_path
+                    else:
+                        generated_files = {}
+                except Exception as check_err:
+                    self.logger.warning(f"样点边界包含性检查失败: {str(check_err)}")
+                    # 保持已有 generated_files,不覆盖
             
             # 创建shapefile
             shapefile_path = csv_path.replace('.csv', '_points.shp')
             mapper.csv_to_shapefile(csv_path, shapefile_path, 
                                   lon_col='longitude', lat_col='latitude', value_col='flux_value')
             
-            generated_files = {"csv": csv_path, "shapefile": shapefile_path}
+            # 合并已生成文件映射
+            generated_files.update({"csv": csv_path, "shapefile": shapefile_path})
             
             # 如果有模板栅格文件,创建栅格地图
             if template_raster:
@@ -482,7 +545,7 @@ class CdFluxRemovalService:
             except Exception:
                 pass
     
-    def _get_boundary_file_for_area(self, area: str) -> Optional[str]:
+    def _get_boundary_file_for_area(self, area: str, level: str) -> Optional[str]:
         """
         为指定地区获取边界文件
         
@@ -490,25 +553,8 @@ class CdFluxRemovalService:
         @returns: 边界文件路径或None
         """
         try:
-            # 首先尝试静态文件路径(只查找该地区专用的边界文件)
-            norm_area = area.strip()
-            base_name = norm_area.replace('市', '').replace('县', '')
-            name_variants = list(dict.fromkeys([
-                norm_area,
-                base_name,
-                f"{base_name}市",
-            ]))
-            static_boundary_paths = []
-            for name in name_variants:
-                static_boundary_paths.append(f"app/static/cd_flux/{name}.shp")
-            
-            for path in static_boundary_paths:
-                if os.path.exists(path):
-                    self.logger.info(f"找到边界文件: {path}")
-                    return path
-            
-            # 优先从数据库获取边界数据(对名称进行多变体匹配,如 “韶关/韶关市”)
-            boundary_path = self._create_boundary_from_database(area)
+            # 仅从数据库严格获取边界(按指定层级精确匹配)
+            boundary_path = self._create_boundary_from_database(area, level)
             if boundary_path:
                 return boundary_path
             
@@ -520,7 +566,7 @@ class CdFluxRemovalService:
             self.logger.error(f"获取边界文件失败: {str(e)}")
             return None
     
-    def _create_boundary_from_database(self, area: str) -> Optional[str]:
+    def _create_boundary_from_database(self, area: str, level: str) -> Optional[str]:
         """
         从数据库获取边界数据并创建临时shapefile
         
@@ -529,32 +575,16 @@ class CdFluxRemovalService:
         """
         try:
             with SessionLocal() as db:
-                # 生成名称变体,增强匹配鲁棒性
                 norm_area = area.strip()
-                base_name = norm_area.replace('市', '').replace('县', '')
-                candidates = list(dict.fromkeys([
-                    norm_area,
-                    base_name,
-                    f"{base_name}市",
-                ]))
-
-                for candidate in candidates:
-                    try:
-                        boundary_geojson = get_boundary_geojson_by_name(db, candidate, level="auto")
-                        if boundary_geojson:
-                            # 创建临时shapefile
-                            import geopandas as gpd
-                            from shapely.geometry import shape
-                            geometry = shape(boundary_geojson["geometry"])
-                            gdf = gpd.GeoDataFrame([boundary_geojson["properties"]], geometry=[geometry], crs="EPSG:4326")
-                            temp_dir = tempfile.mkdtemp()
-                            boundary_path = os.path.join(temp_dir, f"{candidate}_boundary.shp")
-                            gdf.to_file(boundary_path, driver="ESRI Shapefile")
-                            self.logger.info(f"从数据库创建边界文件: {boundary_path}")
-                            return boundary_path
-                    except Exception as _:
-                        # 尝试下一个候选名称
-                        continue
+                boundary_geojson = get_boundary_geojson_by_name(db, norm_area, level=level)
+                if boundary_geojson:
+                    geometry_obj = shape(boundary_geojson["geometry"])
+                    gdf = gpd.GeoDataFrame([boundary_geojson["properties"]], geometry=[geometry_obj], crs="EPSG:4326")
+                    temp_dir = tempfile.mkdtemp()
+                    boundary_path = os.path.join(temp_dir, f"{norm_area}_boundary.shp")
+                    gdf.to_file(boundary_path, driver="ESRI Shapefile")
+                    self.logger.info(f"从数据库创建边界文件: {boundary_path}")
+                    return boundary_path
                     
         except Exception as e:
             self.logger.warning(f"从数据库创建边界文件失败: {str(e)}")