Prechádzať zdrojové kódy

添加籽粒和秸秆移除Cd通量可视化接口,支持生成栅格地图和导出计算数据为CSV文件;更新服务逻辑以处理可视化和数据导出功能。

drggboy 1 týždeň pred
rodič
commit
2cb8131240

+ 0 - 1
.gitignore

@@ -13,4 +13,3 @@ myenv/
 .vscode/
 .vscode/
 *.log
 *.log
 Cd_Prediction_Integrated_System/output/raster/meanTemp.tif.aux.xml
 Cd_Prediction_Integrated_System/output/raster/meanTemp.tif.aux.xml
-config.env

+ 227 - 0
app/api/cd_flux_removal.py

@@ -4,9 +4,11 @@ Cd通量移除计算API接口
 """
 """
 
 
 from fastapi import APIRouter, HTTPException, Query, Path
 from fastapi import APIRouter, HTTPException, Query, Path
+from fastapi.responses import FileResponse
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 from typing import Dict, Any, Optional
 from typing import Dict, Any, Optional
 import logging
 import logging
+import os
 from ..services.cd_flux_removal_service import CdFluxRemovalService
 from ..services.cd_flux_removal_service import CdFluxRemovalService
 
 
 router = APIRouter()
 router = APIRouter()
@@ -26,6 +28,18 @@ class CdFluxRemovalResponse(BaseModel):
     data: Optional[Dict[str, Any]] = Field(None, description="计算结果数据")
     data: Optional[Dict[str, Any]] = Field(None, description="计算结果数据")
 
 
 
 
+class VisualizationResponse(BaseModel):
+    """
+    可视化结果响应模型
+    
+    @description: 绘图接口的响应格式
+    """
+    success: bool = Field(..., description="是否成功")
+    message: str = Field(..., description="响应消息")
+    data: Optional[Dict[str, Any]] = Field(None, description="可视化结果数据")
+    files: Optional[Dict[str, str]] = Field(None, description="生成的文件路径")
+
+
 # 设置日志
 # 设置日志
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -113,4 +127,217 @@ async def calculate_straw_removal(
         )
         )
 
 
 
 
+# =============================================================================
+# Cd通量移除可视化接口
+# =============================================================================
+
+@router.get("/grain-removal/visualize",
+           summary="生成籽粒移除Cd通量可视化图表",
+           description="计算籽粒移除Cd通量并生成栅格地图并返回图片文件")
+async def visualize_grain_removal(
+    area: str = Query(..., description="地区名称,如:韶关"),
+    colormap: str = Query("green_yellow_red_purple", description="色彩方案"),
+    resolution_factor: float = Query(4.0, description="分辨率因子(默认4.0,更快)"),
+    enable_interpolation: bool = Query(False, description="是否启用空间插值(默认关闭以提升性能)"),
+    cleanup_intermediate: bool = Query(True, description="是否清理中间文件(默认是)")
+):
+    """
+    生成籽粒移除Cd通量可视化图表
+    
+    @param area: 地区名称
+    @returns: 栅格地图文件
+    
+    功能包括:
+    1. 计算籽粒移除Cd通量
+    2. 生成栅格地图
+    3. 直接返回图片文件
+    """
+    try:
+        service = CdFluxRemovalService()
+        
+        # 计算籽粒移除Cd通量
+        calc_result = service.calculate_grain_removal_by_area(area)
+        
+        if not calc_result["success"]:
+            raise HTTPException(
+                status_code=404, 
+                detail=calc_result["message"]
+            )
+        
+        # 获取包含坐标的结果数据
+        results_with_coords = service.get_coordinates_for_results(calc_result["data"])
+        
+        if not results_with_coords:
+            raise HTTPException(
+                status_code=404,
+                detail=f"未找到地区 '{area}' 的坐标数据,无法生成可视化"
+            )
+        
+        # 创建可视化
+        visualization_files = service.create_flux_visualization(
+            area=area,
+            calculation_type="grain_removal",
+            results_with_coords=results_with_coords,
+            colormap=colormap,
+            resolution_factor=resolution_factor,
+            enable_interpolation=enable_interpolation,
+            cleanup_intermediate=cleanup_intermediate
+        )
+        
+        # 检查地图文件是否生成成功
+        map_file = visualization_files.get("map")
+        if not map_file or not os.path.exists(map_file):
+            raise HTTPException(status_code=500, detail="地图文件生成失败")
+        
+        return FileResponse(
+            path=map_file,
+            filename=f"{area}_grain_removal_cd_flux_map.jpg",
+            media_type="image/jpeg"
+        )
+        
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"生成地区 '{area}' 的籽粒移除Cd通量可视化失败: {str(e)}")
+        raise HTTPException(
+            status_code=500, 
+            detail=f"可视化生成失败: {str(e)}"
+        )
+
+
+@router.get("/straw-removal/visualize",
+           summary="生成秸秆移除Cd通量可视化图表",
+           description="计算秸秆移除Cd通量并生成栅格地图并返回图片文件")
+async def visualize_straw_removal(
+    area: str = Query(..., description="地区名称,如:韶关"),
+    colormap: str = Query("green_yellow_red_purple", description="色彩方案"),
+    resolution_factor: float = Query(4.0, description="分辨率因子(默认4.0,更快)"),
+    enable_interpolation: bool = Query(False, description="是否启用空间插值(默认关闭以提升性能)"),
+    cleanup_intermediate: bool = Query(True, description="是否清理中间文件(默认是)")
+):
+    """
+    生成秸秆移除Cd通量可视化图表
+    
+    @param area: 地区名称
+    @returns: 栅格地图文件
+    
+    功能包括:
+    1. 计算秸秆移除Cd通量
+    2. 生成栅格地图
+    3. 直接返回图片文件
+    """
+    try:
+        service = CdFluxRemovalService()
+        
+        # 计算秸秆移除Cd通量
+        calc_result = service.calculate_straw_removal_by_area(area)
+        
+        if not calc_result["success"]:
+            raise HTTPException(
+                status_code=404, 
+                detail=calc_result["message"]
+            )
+        
+        # 获取包含坐标的结果数据
+        results_with_coords = service.get_coordinates_for_results(calc_result["data"])
+        
+        if not results_with_coords:
+            raise HTTPException(
+                status_code=404,
+                detail=f"未找到地区 '{area}' 的坐标数据,无法生成可视化"
+            )
+        
+        # 创建可视化
+        visualization_files = service.create_flux_visualization(
+            area=area,
+            calculation_type="straw_removal",
+            results_with_coords=results_with_coords,
+            colormap=colormap,
+            resolution_factor=resolution_factor,
+            enable_interpolation=enable_interpolation,
+            cleanup_intermediate=cleanup_intermediate
+        )
+        
+        # 检查地图文件是否生成成功
+        map_file = visualization_files.get("map")
+        if not map_file or not os.path.exists(map_file):
+            raise HTTPException(status_code=500, detail="地图文件生成失败")
+        
+        return FileResponse(
+            path=map_file,
+            filename=f"{area}_straw_removal_cd_flux_map.jpg",
+            media_type="image/jpeg"
+        )
+        
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"生成地区 '{area}' 的秸秆移除Cd通量可视化失败: {str(e)}")
+        raise HTTPException(
+            status_code=500, 
+            detail=f"可视化生成失败: {str(e)}"
+        )
+
+
+@router.get("/export-data",
+           summary="导出Cd通量移除计算数据",
+           description="导出籽粒移除或秸秆移除的计算结果为CSV文件",
+           response_model=CdFluxRemovalResponse)
+async def export_flux_data(
+    area: str = Query(..., description="地区名称,如:韶关"),
+    calculation_type: str = Query(..., description="计算类型:grain_removal 或 straw_removal")
+) -> Dict[str, Any]:
+    """
+    导出Cd通量移除计算数据
+    
+    @param area: 地区名称
+    @param calculation_type: 计算类型(grain_removal 或 straw_removal)
+    @returns: 导出结果和文件路径
+    """
+    try:
+        if calculation_type not in ["grain_removal", "straw_removal"]:
+            raise HTTPException(
+                status_code=400,
+                detail="计算类型必须是 'grain_removal' 或 'straw_removal'"
+            )
+        
+        service = CdFluxRemovalService()
+        
+        # 根据类型计算相应的Cd通量
+        if calculation_type == "grain_removal":
+            calc_result = service.calculate_grain_removal_by_area(area)
+        else:
+            calc_result = service.calculate_straw_removal_by_area(area)
+        
+        if not calc_result["success"]:
+            raise HTTPException(
+                status_code=404, 
+                detail=calc_result["message"]
+            )
+        
+        # 导出数据
+        csv_path = service.export_results_to_csv(calc_result["data"])
+        
+        return {
+            "success": True,
+            "message": f"地区 '{area}' 的 {calculation_type} 数据导出成功",
+            "data": {
+                "area": area,
+                "calculation_type": calculation_type,
+                "exported_file": csv_path,
+                "total_records": len(calc_result["data"]["results"]),
+                "statistics": calc_result["data"]["statistics"]
+            }
+        }
+        
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"导出地区 '{area}' 的 {calculation_type} 数据失败: {str(e)}")
+        raise HTTPException(
+            status_code=500, 
+            detail=f"数据导出失败: {str(e)}"
+        )
+
+
  
  

+ 363 - 0
app/services/cd_flux_removal_service.py

@@ -7,12 +7,20 @@ Cd通量移除计算服务
 
 
 import logging
 import logging
 import math
 import math
+import os
+import pandas as pd
+from datetime import datetime
 from typing import Dict, Any, List, Optional
 from typing import Dict, Any, List, Optional
 from sqlalchemy.orm import sessionmaker, Session
 from sqlalchemy.orm import sessionmaker, Session
 from sqlalchemy import create_engine, and_
 from sqlalchemy import create_engine, and_
 from ..database import SessionLocal, engine
 from ..database import SessionLocal, engine
 from ..models.parameters import Parameters
 from ..models.parameters import Parameters
 from ..models.CropCd_output import CropCdOutputData
 from ..models.CropCd_output import CropCdOutputData
+from ..models.farmland import FarmlandData
+from ..utils.mapping_utils import MappingUtils
+from .admin_boundary_service import get_boundary_geojson_by_name
+import tempfile
+import json
 
 
 
 
 class CdFluxRemovalService:
 class CdFluxRemovalService:
@@ -198,4 +206,359 @@ class CdFluxRemovalService:
                 "data": None
                 "data": None
             }
             }
     
     
+    def export_results_to_csv(self, results_data: Dict[str, Any], output_dir: str = "app/static/cd_flux") -> str:
+        """
+        将计算结果导出为CSV文件
+        
+        @param results_data: 计算结果数据
+        @param output_dir: 输出目录
+        @returns: CSV文件路径
+        """
+        try:
+            # 确保输出目录存在
+            os.makedirs(output_dir, exist_ok=True)
+            
+            # 生成时间戳
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            
+            # 生成文件名
+            calculation_type = results_data.get("calculation_type", "flux_removal")
+            area = results_data.get("area", "unknown")
+            filename = f"{calculation_type}_{area}_{timestamp}.csv"
+            csv_path = os.path.join(output_dir, filename)
+            
+            # 转换为DataFrame
+            results = results_data.get("results", [])
+            if not results:
+                raise ValueError("没有结果数据可导出")
+            
+            df = pd.DataFrame(results)
+            
+            # 保存CSV文件
+            df.to_csv(csv_path, index=False, encoding='utf-8-sig')
+            
+            self.logger.info(f"✓ 成功导出结果到: {csv_path}")
+            return csv_path
+            
+        except Exception as e:
+            self.logger.error(f"导出CSV文件失败: {str(e)}")
+            raise
+    
+    def get_coordinates_for_results(self, results_data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        获取结果数据对应的坐标信息
+        
+        @param results_data: 计算结果数据
+        @returns: 包含坐标的结果列表
+        """
+        try:
+            results = results_data.get("results", [])
+            if not results:
+                return []
+
+            # 提取成对键,避免 N 次数据库查询
+            farmland_sample_pairs = [(r["farmland_id"], r["sample_id"]) for r in results]
+
+            with SessionLocal() as db:
+                # 使用 farmland_id 分片查询,避免复合 IN 导致的兼容性与参数数量问题
+                wanted_pairs = set(farmland_sample_pairs)
+                unique_farmland_ids = sorted({fid for fid, _ in wanted_pairs})
+
+                def chunk_list(items: List[int], chunk_size: int = 500) -> List[List[int]]:
+                    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
+
+                rows: List[FarmlandData] = []
+                for id_chunk in chunk_list(unique_farmland_ids, 500):
+                    rows.extend(
+                        db.query(FarmlandData)
+                          .filter(FarmlandData.farmland_id.in_(id_chunk))
+                          .all()
+                    )
+
+                pair_to_farmland = {
+                    (row.farmland_id, row.sample_id): row for row in rows
+                }
+
+                coordinates_results: List[Dict[str, Any]] = []
+                for r in results:
+                    key = (r["farmland_id"], r["sample_id"])
+                    farmland = pair_to_farmland.get(key)
+                    if farmland is None:
+                        continue
+
+                    coord_result = {
+                        "farmland_id": r["farmland_id"],
+                        "sample_id": r["sample_id"],
+                        "longitude": farmland.lon,
+                        "latitude": farmland.lan,
+                        "flux_value": r.get("grain_removal_flux") or r.get("straw_removal_flux")
+                    }
+                    coord_result.update(r)
+                    coordinates_results.append(coord_result)
+
+                self.logger.info(f"✓ 成功获取 {len(coordinates_results)} 个样点的坐标信息(分片批量查询)")
+                return coordinates_results
+                
+        except Exception as e:
+            self.logger.error(f"获取坐标信息失败: {str(e)}")
+            raise
+    
+    def create_flux_visualization(self, area: str, calculation_type: str,
+                                results_with_coords: List[Dict[str, Any]],
+                                output_dir: str = "app/static/cd_flux",
+                                template_raster: str = "app/static/cd_flux/meanTemp.tif",
+                                boundary_shp: str = None,
+                                colormap: str = "green_yellow_red_purple",
+                                resolution_factor: float = 4.0,
+                                enable_interpolation: bool = False,
+                                cleanup_intermediate: bool = True) -> Dict[str, str]:
+        """
+        创建Cd通量移除可视化图表
+        
+        @param area: 地区名称
+        @param calculation_type: 计算类型(grain_removal 或 straw_removal)
+        @param results_with_coords: 包含坐标的结果数据
+        @param output_dir: 输出目录
+        @param template_raster: 模板栅格文件路径
+        @param boundary_shp: 边界shapefile路径
+        @param colormap: 色彩方案
+        @param resolution_factor: 分辨率因子
+        @param enable_interpolation: 是否启用空间插值
+        @returns: 生成的图片文件路径字典
+        """
+        try:
+            if not results_with_coords:
+                raise ValueError("没有包含坐标的结果数据")
+            
+            # 确保输出目录存在
+            os.makedirs(output_dir, exist_ok=True)
+            
+            # 生成时间戳
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            
+            # 创建CSV文件用于绘图
+            csv_filename = f"{calculation_type}_{area}_temp_{timestamp}.csv"
+            csv_path = os.path.join(output_dir, csv_filename)
+            
+            # 准备绘图数据
+            plot_data = []
+            for result in results_with_coords:
+                plot_data.append({
+                    "longitude": result["longitude"],
+                    "latitude": result["latitude"],
+                    "flux_value": result["flux_value"]
+                })
+            
+            # 保存为CSV
+            df = pd.DataFrame(plot_data)
+            df.to_csv(csv_path, index=False, encoding='utf-8-sig')
+            
+            # 初始化绘图工具
+            mapper = MappingUtils()
+            
+            # 生成输出文件路径
+            map_output = os.path.join(output_dir, f"{calculation_type}_{area}_map_{timestamp}")
+            histogram_output = os.path.join(output_dir, f"{calculation_type}_{area}_histogram_{timestamp}")
+            
+            # 检查模板文件是否存在
+            if not os.path.exists(template_raster):
+                self.logger.warning(f"模板栅格文件不存在: {template_raster}")
+                template_raster = None
+            
+            # 动态获取边界文件
+            boundary_shp = self._get_boundary_file_for_area(area)
+            if not boundary_shp:
+                self.logger.warning(f"未找到地区 '{area}' 的边界文件,将不使用边界裁剪")
+            
+            # 创建shapefile
+            shapefile_path = csv_path.replace('.csv', '_points.shp')
+            mapper.csv_to_shapefile(csv_path, shapefile_path, 
+                                  lon_col='longitude', lat_col='latitude', value_col='flux_value')
+            
+            generated_files = {"csv": csv_path, "shapefile": shapefile_path}
+            
+            # 如果有模板栅格文件,创建栅格地图
+            if template_raster:
+                try:
+                    # 创建栅格
+                    raster_path = csv_path.replace('.csv', '_raster.tif')
+                    raster_path, stats = mapper.vector_to_raster(
+                        shapefile_path, template_raster, raster_path, 'flux_value',
+                        resolution_factor=resolution_factor, boundary_shp=boundary_shp,
+                        interpolation_method='nearest', enable_interpolation=enable_interpolation
+                    )
+                    generated_files["raster"] = raster_path
+                    
+                    # 创建栅格地图 - 使用英文标题避免中文乱码
+                    title_mapping = {
+                        "grain_removal": "Grain Removal Cd Flux",
+                        "straw_removal": "Straw Removal Cd Flux"
+                    }
+                    map_title = title_mapping.get(calculation_type, "Cd Flux Removal")
+                    
+                    map_file = mapper.create_raster_map(
+                        boundary_shp if boundary_shp else None,
+                        raster_path,
+                        map_output,
+                        colormap=colormap,
+                        title=map_title,
+                        output_size=12,
+                        dpi=300,
+                        resolution_factor=4.0,
+                        enable_interpolation=False,
+                        interpolation_method='nearest'
+                    )
+                    generated_files["map"] = map_file
+                    
+                    # 创建直方图 - 使用英文标题避免中文乱码
+                    histogram_title_mapping = {
+                        "grain_removal": "Grain Removal Cd Flux Distribution",
+                        "straw_removal": "Straw Removal Cd Flux Distribution"
+                    }
+                    histogram_title = histogram_title_mapping.get(calculation_type, "Cd Flux Distribution")
+                    
+                    histogram_file = mapper.create_histogram(
+                        raster_path,
+                        f"{histogram_output}.jpg",
+                        title=histogram_title,
+                        xlabel='Cd Flux (g/ha/a)',
+                        ylabel='Frequency Density'
+                    )
+                    generated_files["histogram"] = histogram_file
+                    
+                except Exception as viz_error:
+                    self.logger.warning(f"栅格可视化创建失败: {str(viz_error)}")
+                    # 即使栅格可视化失败,也返回已生成的文件
+            
+            # 清理中间文件(默认开启,仅保留最终可视化)
+            if cleanup_intermediate:
+                try:
+                    self._cleanup_intermediate_files(generated_files, boundary_shp)
+                except Exception as cleanup_err:
+                    self.logger.warning(f"中间文件清理失败: {str(cleanup_err)}")
+
+            self.logger.info(f"✓ 成功创建 {calculation_type} 可视化,生成文件: {list(generated_files.keys())}")
+            return generated_files
+            
+        except Exception as e:
+            self.logger.error(f"创建可视化失败: {str(e)}")
+            raise
+
+    def _cleanup_intermediate_files(self, generated_files: Dict[str, str], boundary_shp: Optional[str]) -> None:
+        """
+        清理中间文件:CSV、Shapefile 及其配套文件、栅格TIFF;若边界为临时目录,则一并删除
+        """
+        import shutil
+        import tempfile
+
+        def _safe_remove(path: str) -> None:
+            try:
+                if path and os.path.exists(path) and os.path.isfile(path):
+                    os.remove(path)
+            except Exception:
+                pass
+
+        # 删除 CSV
+        _safe_remove(generated_files.get("csv"))
+
+        # 删除栅格
+        _safe_remove(generated_files.get("raster"))
+
+        # 删除 Shapefile 全家桶
+        shp_path = generated_files.get("shapefile")
+        if shp_path:
+            base, _ = os.path.splitext(shp_path)
+            for ext in (".shp", ".shx", ".dbf", ".prj", ".cpg"):
+                _safe_remove(base + ext)
+
+        # 如果边界文件来自系统临时目录,删除其所在目录
+        if boundary_shp:
+            temp_root = tempfile.gettempdir()
+            try:
+                if os.path.commonprefix([os.path.abspath(boundary_shp), temp_root]) == temp_root:
+                    temp_dir = os.path.dirname(os.path.abspath(boundary_shp))
+                    if os.path.isdir(temp_dir):
+                        shutil.rmtree(temp_dir, ignore_errors=True)
+            except Exception:
+                pass
+    
+    def _get_boundary_file_for_area(self, area: str) -> Optional[str]:
+        """
+        为指定地区获取边界文件
+        
+        @param area: 地区名称
+        @returns: 边界文件路径或None
+        """
+        try:
+            # 首先尝试静态文件路径(只查找该地区专用的边界文件)
+            norm_area = area.strip()
+            base_name = norm_area.replace('市', '').replace('县', '')
+            name_variants = list(dict.fromkeys([
+                norm_area,
+                base_name,
+                f"{base_name}市",
+            ]))
+            static_boundary_paths = []
+            for name in name_variants:
+                static_boundary_paths.append(f"app/static/cd_flux/{name}.shp")
+            
+            for path in static_boundary_paths:
+                if os.path.exists(path):
+                    self.logger.info(f"找到边界文件: {path}")
+                    return path
+            
+            # 优先从数据库获取边界数据(对名称进行多变体匹配,如 “韶关/韶关市”)
+            boundary_path = self._create_boundary_from_database(area)
+            if boundary_path:
+                return boundary_path
+            
+            # 如果都没有找到,记录警告但不使用默认文件
+            self.logger.warning(f"未找到地区 '{area}' 的专用边界文件,也无法从数据库获取")
+            return None
+            
+        except Exception as e:
+            self.logger.error(f"获取边界文件失败: {str(e)}")
+            return None
+    
+    def _create_boundary_from_database(self, area: str) -> Optional[str]:
+        """
+        从数据库获取边界数据并创建临时shapefile
+        
+        @param area: 地区名称
+        @returns: 临时边界文件路径或None
+        """
+        try:
+            with SessionLocal() as db:
+                # 生成名称变体,增强匹配鲁棒性
+                norm_area = area.strip()
+                base_name = norm_area.replace('市', '').replace('县', '')
+                candidates = list(dict.fromkeys([
+                    norm_area,
+                    base_name,
+                    f"{base_name}市",
+                ]))
+
+                for candidate in candidates:
+                    try:
+                        boundary_geojson = get_boundary_geojson_by_name(db, candidate, level="auto")
+                        if boundary_geojson:
+                            # 创建临时shapefile
+                            import geopandas as gpd
+                            from shapely.geometry import shape
+                            geometry = shape(boundary_geojson["geometry"])
+                            gdf = gpd.GeoDataFrame([boundary_geojson["properties"]], geometry=[geometry], crs="EPSG:4326")
+                            temp_dir = tempfile.mkdtemp()
+                            boundary_path = os.path.join(temp_dir, f"{candidate}_boundary.shp")
+                            gdf.to_file(boundary_path, driver="ESRI Shapefile")
+                            self.logger.info(f"从数据库创建边界文件: {boundary_path}")
+                            return boundary_path
+                    except Exception as _:
+                        # 尝试下一个候选名称
+                        continue
+                    
+        except Exception as e:
+            self.logger.warning(f"从数据库创建边界文件失败: {str(e)}")
+            
+        return None
+    
  
  

+ 100 - 3
app/utils/mapping_utils.py

@@ -12,10 +12,12 @@ import pandas as pd
 import geopandas as gpd
 import geopandas as gpd
 from shapely.geometry import Point
 from shapely.geometry import Point
 import rasterio
 import rasterio
-from rasterio.features import rasterize
+from rasterio.features import rasterize, geometry_mask
 from rasterio.transform import from_origin
 from rasterio.transform import from_origin
 from rasterio.mask import mask
 from rasterio.mask import mask
 from rasterio.plot import show
 from rasterio.plot import show
+from rasterio.warp import transform_bounds, reproject
+from rasterio.enums import Resampling
 import numpy as np
 import numpy as np
 import os
 import os
 import json
 import json
@@ -259,7 +261,8 @@ class MappingUtils:
                 dtype='float32'
                 dtype='float32'
             )
             )
             
             
-            # 应用边界掩膜(如果提供)
+            # 预备掩膜:优先使用行政区边界;若未提供边界,则使用点集凸包限制绘制范围
+            boundary_mask = None
             if boundary_shp and os.path.exists(boundary_shp):
             if boundary_shp and os.path.exists(boundary_shp):
                 self.logger.info(f"应用边界掩膜: {boundary_shp}")
                 self.logger.info(f"应用边界掩膜: {boundary_shp}")
                 boundary_gdf = gpd.read_file(boundary_shp)
                 boundary_gdf = gpd.read_file(boundary_shp)
@@ -267,11 +270,24 @@ class MappingUtils:
                     boundary_gdf = boundary_gdf.to_crs(crs)
                     boundary_gdf = boundary_gdf.to_crs(crs)
                 boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
                 boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
                 raster[~boundary_mask] = np.nan
                 raster[~boundary_mask] = np.nan
+            else:
+                try:
+                    # 使用点集凸包作为默认掩膜,避免边界外着色
+                    hull = gdf.unary_union.convex_hull
+                    hull_gdf = gpd.GeoDataFrame(geometry=[hull], crs=crs)
+                    boundary_mask = self.create_boundary_mask(raster, transform, hull_gdf)
+                    raster[~boundary_mask] = np.nan
+                    self.logger.info("已使用点集凸包限制绘制范围")
+                except Exception as hull_err:
+                    self.logger.warning(f"生成点集凸包掩膜失败,可能会出现边界外着色: {str(hull_err)}")
             
             
             # 使用插值方法填充NaN值(如果启用)
             # 使用插值方法填充NaN值(如果启用)
             if enable_interpolation and np.isnan(raster).any():
             if enable_interpolation and np.isnan(raster).any():
                 self.logger.info(f"使用 {interpolation_method} 方法进行插值...")
                 self.logger.info(f"使用 {interpolation_method} 方法进行插值...")
                 raster = self.interpolate_nan_values(raster, method=interpolation_method)
                 raster = self.interpolate_nan_values(raster, method=interpolation_method)
+                # 关键修正:插值后再次应用掩膜,确保边界外不被填充
+                if boundary_mask is not None:
+                    raster[~boundary_mask] = np.nan
             elif not enable_interpolation and np.isnan(raster).any():
             elif not enable_interpolation and np.isnan(raster).any():
                 self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")
                 self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")
             
             
@@ -373,7 +389,13 @@ class MappingUtils:
             # 应用分辨率因子重采样
             # 应用分辨率因子重采样
             if resolution_factor != 1.0:
             if resolution_factor != 1.0:
                 self.logger.info(f"应用分辨率因子重采样: {resolution_factor}")
                 self.logger.info(f"应用分辨率因子重采样: {resolution_factor}")
-                raster, out_transform = self._resample_raster(raster, out_transform, resolution_factor, original_crs)
+                raster, out_transform = self._resample_raster(
+                    raster=raster,
+                    transform=out_transform,
+                    resolution_factor=resolution_factor,
+                    crs=original_crs,
+                    resampling='nearest'
+                )
 
 
             # 应用空间插值(如果启用)
             # 应用空间插值(如果启用)
             if enable_interpolation and np.isnan(raster).any():
             if enable_interpolation and np.isnan(raster).any():
@@ -408,6 +430,23 @@ class MappingUtils:
 
 
             # 绘图
             # 绘图
             fig, ax = plt.subplots(figsize=fig_size)
             fig, ax = plt.subplots(figsize=fig_size)
+            
+            # 如果有边界文件,需要进一步mask边界外的区域;否则使用栅格有效范围
+            if gdf is not None:
+                try:
+                    height, width = raster.shape
+                    transform = out_transform
+                    geom_mask = geometry_mask(
+                        [json.loads(gdf.to_json())["features"][0]["geometry"]], 
+                        out_shape=(height, width), 
+                        transform=transform,
+                        invert=True
+                    )
+                    raster = np.where(geom_mask, raster, np.nan)
+                except Exception as mask_err:
+                    self.logger.warning(f"边界掩膜应用失败,将继续绘制已裁剪栅格: {str(mask_err)}")
+            
+            # 显示栅格数据
             show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
             show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
 
 
             # 添加矢量边界
             # 添加矢量边界
@@ -449,6 +488,64 @@ class MappingUtils:
         except Exception as e:
         except Exception as e:
             self.logger.error(f"栅格地图创建失败: {str(e)}")
             self.logger.error(f"栅格地图创建失败: {str(e)}")
             raise
             raise
+
+    def _resample_raster(self, raster, transform, resolution_factor: float, crs, resampling: str = 'nearest'):
+        """
+        按分辨率因子对二维栅格进行重采样,并返回新栅格与更新后的仿射变换
+        
+        @param raster: 2D numpy 数组
+        @param transform: 输入栅格的仿射变换
+        @param resolution_factor: 分辨率因子 (>1 增加像元密度)
+        @param crs: 坐标参考系
+        @param resampling: 重采样方式 ('nearest' | 'bilinear' | 'cubic')
+        @return: (resampled_raster, new_transform)
+        """
+        try:
+            if resolution_factor == 1.0:
+                return raster, transform
+
+            rows, cols = raster.shape
+            new_rows = max(1, int(rows * resolution_factor))
+            new_cols = max(1, int(cols * resolution_factor))
+
+            # 更新变换(像元变小)
+            new_transform = rasterio.Affine(
+                transform.a / resolution_factor,
+                transform.b,
+                transform.c,
+                transform.d,
+                transform.e / resolution_factor,
+                transform.f
+            )
+
+            # 选择重采样算法
+            resampling_map = {
+                'nearest': Resampling.nearest,
+                'bilinear': Resampling.bilinear,
+                'cubic': Resampling.cubic,
+            }
+            resampling_enum = resampling_map.get(resampling, Resampling.nearest)
+
+            destination = np.full((new_rows, new_cols), np.nan, dtype='float32')
+
+            # 使用 reproject 做重采样(坐标系不变,仅分辨率变化)
+            reproject(
+                source=raster,
+                destination=destination,
+                src_transform=transform,
+                src_crs=crs,
+                dst_transform=new_transform,
+                dst_crs=crs,
+                src_nodata=np.nan,
+                dst_nodata=np.nan,
+                resampling=resampling_enum
+            )
+
+            return destination, new_transform
+        except Exception as e:
+            self.logger.error(f"重采样失败: {str(e)}")
+            # 失败则返回原始数据,避免中断
+            return raster, transform
     
     
     def create_histogram(self, file_path, save_path=None, figsize=(10, 6),
     def create_histogram(self, file_path, save_path=None, figsize=(10, 6),
                         xlabel='像元值', ylabel='频率密度', title='数值分布图',
                         xlabel='像元值', ylabel='频率密度', title='数值分布图',