3 Commits e7fe25ddeb ... 614a33cc07

Author SHA1 Message Date
  drggboy 614a33cc07 Update the database backup file; update the database backup & import scripts 2 weeks ago
  drggboy f7b930054d Add an administrative-level parameter to the grain and straw removal Cd flux visualization endpoints; strengthen the service logic to support strict level matching; update boundary-file retrieval to guarantee exact matches. 2 weeks ago
  drggboy 2cb8131240 Add grain and straw removal Cd flux visualization endpoints that render raster maps and export calculation data to CSV; update the service logic to handle visualization and data export. 2 weeks ago

+ 0 - 1
.gitignore

@@ -13,4 +13,3 @@ myenv/
 .vscode/
 *.log
 Cd_Prediction_Integrated_System/output/raster/meanTemp.tif.aux.xml
-config.env

+ 1 - 1
PROJECT_RULES.md

@@ -571,7 +571,7 @@ pg_dump -U postgres soilgd > soilgd.sql
 
 #### Import procedure
 ```bash
-# Restore a custom-format backup
+# Restore a custom-format backup; if soilgd already exists, drop and recreate it first
 createdb -U postgres soilgd
 pg_restore -U postgres -d soilgd soilgd.dump
 

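The drop-and-recreate flow referenced in the updated comment can also be scripted. A minimal sketch in Python, assuming the PostgreSQL client tools (`dropdb`, `createdb`, `pg_restore`) are on `PATH` and local authentication is configured:

```python
# Minimal sketch of the drop-and-recreate restore flow; assumes the
# PostgreSQL CLI tools are on PATH and local authentication works.
import subprocess

DB, USER = "soilgd", "postgres"

# Drop the database if it exists, then recreate it and restore the dump.
subprocess.run(["dropdb", "-U", USER, "--if-exists", DB], check=True)
subprocess.run(["createdb", "-U", USER, DB], check=True)
subprocess.run(["pg_restore", "-U", USER, "-d", DB, "soilgd.dump"], check=True)
```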
+ 239 - 0
app/api/cd_flux_removal.py

@@ -4,9 +4,11 @@ Cd flux removal calculation API endpoints
 """
 
 from fastapi import APIRouter, HTTPException, Query, Path
+from fastapi.responses import FileResponse
 from pydantic import BaseModel, Field
 from typing import Dict, Any, Optional
 import logging
+import os
 from ..services.cd_flux_removal_service import CdFluxRemovalService
 
 router = APIRouter()
@@ -26,6 +28,18 @@ class CdFluxRemovalResponse(BaseModel):
     data: Optional[Dict[str, Any]] = Field(None, description="Calculation result data")
 
 
+class VisualizationResponse(BaseModel):
+    """
+    Visualization result response model
+    
+    @description: response format for the plotting endpoints
+    """
+    success: bool = Field(..., description="Whether the request succeeded")
+    message: str = Field(..., description="Response message")
+    data: Optional[Dict[str, Any]] = Field(None, description="Visualization result data")
+    files: Optional[Dict[str, str]] = Field(None, description="Paths of the generated files")
+
+
 # Set up logging
 logger = logging.getLogger(__name__)
 
@@ -113,4 +127,229 @@ async def calculate_straw_removal(
         )
 
 
+# =============================================================================
+# Cd flux removal visualization endpoints
+# =============================================================================
+
+@router.get("/grain-removal/visualize",
+           summary="生成籽粒移除Cd通量可视化图表",
+           description="计算籽粒移除Cd通量并生成栅格地图并返回图片文件")
+async def visualize_grain_removal(
+    area: str = Query(..., description="地区名称,如:韶关"),
+    level: str = Query(..., description="行政层级,必须为: county | city | province"),
+    colormap: str = Query("green_yellow_red_purple", description="色彩方案"),
+    resolution_factor: float = Query(4.0, description="分辨率因子(默认4.0,更快)"),
+    enable_interpolation: bool = Query(False, description="是否启用空间插值(默认关闭以提升性能)"),
+    cleanup_intermediate: bool = Query(True, description="是否清理中间文件(默认是)")
+):
+    """
+    生成籽粒移除Cd通量可视化图表
+    
+    @param area: 地区名称
+    @returns: 栅格地图文件
+    
+    功能包括:
+    1. 计算籽粒移除Cd通量
+    2. 生成栅格地图
+    3. 直接返回图片文件
+    """
+    try:
+        service = CdFluxRemovalService()
+
+        # Validate the administrative level (no fuzzy matching allowed)
+        if level not in ("county", "city", "province"):
+            raise HTTPException(status_code=400, detail="Parameter 'level' must be 'county' | 'city' | 'province'")
+        
+        # Compute the grain removal Cd flux (pass the strict level so the parameter-table lookup can do suffix-normalized exact matching)
+        calc_result = service.calculate_grain_removal_by_area(area, level=level)
+        
+        if not calc_result["success"]:
+            raise HTTPException(
+                status_code=404, 
+                detail=calc_result["message"]
+            )
+        
+        # Fetch the result data with coordinates
+        results_with_coords = service.get_coordinates_for_results(calc_result["data"])
+        
+        if not results_with_coords:
+            raise HTTPException(
+                status_code=404,
+                detail=f"未找到地区 '{area}' 的坐标数据,无法生成可视化"
+            )
+        
+        # Create the visualization
+        visualization_files = service.create_flux_visualization(
+            area=area,
+            level=level,
+            calculation_type="grain_removal",
+            results_with_coords=results_with_coords,
+            colormap=colormap,
+            resolution_factor=resolution_factor,
+            enable_interpolation=enable_interpolation,
+            cleanup_intermediate=cleanup_intermediate
+        )
+        
+        # Verify that the map file was generated
+        map_file = visualization_files.get("map")
+        if not map_file or not os.path.exists(map_file):
+            raise HTTPException(status_code=500, detail="地图文件生成失败")
+        
+        return FileResponse(
+            path=map_file,
+            filename=f"{area}_grain_removal_cd_flux_map.jpg",
+            media_type="image/jpeg"
+        )
+        
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"生成地区 '{area}' 的籽粒移除Cd通量可视化失败: {str(e)}")
+        raise HTTPException(
+            status_code=500, 
+            detail=f"可视化生成失败: {str(e)}"
+        )
+
+
+@router.get("/straw-removal/visualize",
+           summary="生成秸秆移除Cd通量可视化图表",
+           description="计算秸秆移除Cd通量并生成栅格地图并返回图片文件")
+async def visualize_straw_removal(
+    area: str = Query(..., description="地区名称,如:韶关"),
+    level: str = Query(..., description="行政层级,必须为: county | city | province"),
+    colormap: str = Query("green_yellow_red_purple", description="色彩方案"),
+    resolution_factor: float = Query(4.0, description="分辨率因子(默认4.0,更快)"),
+    enable_interpolation: bool = Query(False, description="是否启用空间插值(默认关闭以提升性能)"),
+    cleanup_intermediate: bool = Query(True, description="是否清理中间文件(默认是)")
+):
+    """
+    生成秸秆移除Cd通量可视化图表
+    
+    @param area: 地区名称
+    @returns: 栅格地图文件
+    
+    功能包括:
+    1. 计算秸秆移除Cd通量
+    2. 生成栅格地图
+    3. 直接返回图片文件
+    """
+    try:
+        service = CdFluxRemovalService()
+
+        # Validate the administrative level (no fuzzy matching allowed)
+        if level not in ("county", "city", "province"):
+            raise HTTPException(status_code=400, detail="Parameter 'level' must be 'county' | 'city' | 'province'")
+        
+        # Compute the straw removal Cd flux (pass the strict level so the parameter-table lookup can do suffix-normalized exact matching)
+        calc_result = service.calculate_straw_removal_by_area(area, level=level)
+        
+        if not calc_result["success"]:
+            raise HTTPException(
+                status_code=404, 
+                detail=calc_result["message"]
+            )
+        
+        # Fetch the result data with coordinates
+        results_with_coords = service.get_coordinates_for_results(calc_result["data"])
+        
+        if not results_with_coords:
+            raise HTTPException(
+                status_code=404,
+                detail=f"未找到地区 '{area}' 的坐标数据,无法生成可视化"
+            )
+        
+        # Create the visualization
+        visualization_files = service.create_flux_visualization(
+            area=area,
+            level=level,
+            calculation_type="straw_removal",
+            results_with_coords=results_with_coords,
+            colormap=colormap,
+            resolution_factor=resolution_factor,
+            enable_interpolation=enable_interpolation,
+            cleanup_intermediate=cleanup_intermediate
+        )
+        
+        # Verify that the map file was generated
+        map_file = visualization_files.get("map")
+        if not map_file or not os.path.exists(map_file):
+            raise HTTPException(status_code=500, detail="地图文件生成失败")
+        
+        return FileResponse(
+            path=map_file,
+            filename=f"{area}_straw_removal_cd_flux_map.jpg",
+            media_type="image/jpeg"
+        )
+        
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"生成地区 '{area}' 的秸秆移除Cd通量可视化失败: {str(e)}")
+        raise HTTPException(
+            status_code=500, 
+            detail=f"可视化生成失败: {str(e)}"
+        )
+
+
+@router.get("/export-data",
+           summary="导出Cd通量移除计算数据",
+           description="导出籽粒移除或秸秆移除的计算结果为CSV文件",
+           response_model=CdFluxRemovalResponse)
+async def export_flux_data(
+    area: str = Query(..., description="地区名称,如:韶关"),
+    calculation_type: str = Query(..., description="计算类型:grain_removal 或 straw_removal")
+) -> Dict[str, Any]:
+    """
+    Export Cd flux removal calculation data
+    
+    @param area: area name
+    @param calculation_type: calculation type (grain_removal or straw_removal)
+    @returns: export result and file path
+    """
+    try:
+        if calculation_type not in ["grain_removal", "straw_removal"]:
+            raise HTTPException(
+                status_code=400,
+                detail="计算类型必须是 'grain_removal' 或 'straw_removal'"
+            )
+        
+        service = CdFluxRemovalService()
+        
+        # Compute the Cd flux for the requested type
+        if calculation_type == "grain_removal":
+            calc_result = service.calculate_grain_removal_by_area(area)
+        else:
+            calc_result = service.calculate_straw_removal_by_area(area)
+        
+        if not calc_result["success"]:
+            raise HTTPException(
+                status_code=404, 
+                detail=calc_result["message"]
+            )
+        
+        # Export the data
+        csv_path = service.export_results_to_csv(calc_result["data"])
+        
+        return {
+            "success": True,
+            "message": f"地区 '{area}' 的 {calculation_type} 数据导出成功",
+            "data": {
+                "area": area,
+                "calculation_type": calculation_type,
+                "exported_file": csv_path,
+                "total_records": len(calc_result["data"]["results"]),
+                "statistics": calc_result["data"]["statistics"]
+            }
+        }
+        
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"导出地区 '{area}' 的 {calculation_type} 数据失败: {str(e)}")
+        raise HTTPException(
+            status_code=500, 
+            detail=f"数据导出失败: {str(e)}"
+        )
+
+
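The two visualization endpoints return the JPEG directly, while `/export-data` returns JSON describing the server-side CSV. A minimal client-side sketch, assuming the app runs locally on port 8000 with the router mounted under an `/api/cd-flux` prefix (the base URL and the area value are illustrative assumptions, not part of this diff):

```python
# Client-side sketch; the base URL and area value are assumptions.
import requests

BASE = "http://localhost:8000/api/cd-flux"

# Render the grain removal flux map and save the returned JPEG.
resp = requests.get(f"{BASE}/grain-removal/visualize",
                    params={"area": "韶关", "level": "city"})
resp.raise_for_status()
with open("grain_removal_cd_flux_map.jpg", "wb") as f:
    f.write(resp.content)

# Export the same calculation; the response reports the server-side CSV path.
resp = requests.get(f"{BASE}/export-data",
                    params={"area": "韶关", "calculation_type": "grain_removal"})
print(resp.json()["data"]["exported_file"])
```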
  

+ 397 - 4
app/services/cd_flux_removal_service.py

@@ -7,12 +7,22 @@ Cd flux removal calculation service
 
 import logging
 import math
+import os
+import pandas as pd
+from datetime import datetime
 from typing import Dict, Any, List, Optional
 from sqlalchemy.orm import sessionmaker, Session
 from sqlalchemy import create_engine, and_
 from ..database import SessionLocal, engine
 from ..models.parameters import Parameters
 from ..models.CropCd_output import CropCdOutputData
+from ..models.farmland import FarmlandData
+from ..utils.mapping_utils import MappingUtils
+from .admin_boundary_service import get_boundary_geojson_by_name
+import geopandas as gpd
+from shapely.geometry import shape, Point
+import tempfile
+import json
 
 
 class CdFluxRemovalService:
@@ -27,8 +37,9 @@ class CdFluxRemovalService:
         Initialize the Cd flux removal service
         """
         self.logger = logging.getLogger(__name__)
+        # Strict matching policy: no name variants or suffix mapping
         
-    def calculate_grain_removal_by_area(self, area: str) -> Dict[str, Any]:
+    def calculate_grain_removal_by_area(self, area: str, level: Optional[str] = None) -> Dict[str, Any]:
         """
         Compute the grain removal Cd flux for a given area
         
@@ -39,7 +50,7 @@ class CdFluxRemovalService:
         """
         try:
             with SessionLocal() as db:
-                # Look up parameters for the given area
+                # Look up parameters for the given area (strict equality match, no mapping)
                 parameter = db.query(Parameters).filter(Parameters.area == area).first()
                 
                 if not parameter:
@@ -104,7 +115,7 @@ class CdFluxRemovalService:
                 "data": None
             }
     
-    def calculate_straw_removal_by_area(self, area: str) -> Dict[str, Any]:
+    def calculate_straw_removal_by_area(self, area: str, level: Optional[str] = None) -> Dict[str, Any]:
         """
         Compute the straw removal Cd flux for a given area
         
@@ -115,7 +126,7 @@ class CdFluxRemovalService:
         """
         try:
             with SessionLocal() as db:
-                # Look up parameters for the given area
+                # Look up parameters for the given area (strict equality match, no mapping)
                 parameter = db.query(Parameters).filter(Parameters.area == area).first()
                 
                 if not parameter:
@@ -198,4 +209,386 @@ class CdFluxRemovalService:
                 "data": None
             }
     
+    def export_results_to_csv(self, results_data: Dict[str, Any], output_dir: str = "app/static/cd_flux") -> str:
+        """
+        Export calculation results to a CSV file
+        
+        @param results_data: calculation result data
+        @param output_dir: output directory
+        @returns: path to the generated CSV file
+        """
+        try:
+            # Ensure the output directory exists
+            os.makedirs(output_dir, exist_ok=True)
+            
+            # Generate a timestamp
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            
+            # Build the filename
+            calculation_type = results_data.get("calculation_type", "flux_removal")
+            area = results_data.get("area", "unknown")
+            filename = f"{calculation_type}_{area}_{timestamp}.csv"
+            csv_path = os.path.join(output_dir, filename)
+            
+            # Convert to a DataFrame
+            results = results_data.get("results", [])
+            if not results:
+                raise ValueError("No result data to export")
+            
+            df = pd.DataFrame(results)
+            
+            # Write the CSV file
+            df.to_csv(csv_path, index=False, encoding='utf-8-sig')
+            
+            self.logger.info(f"✓ Exported results to: {csv_path}")
+            return csv_path
+            
+        except Exception as e:
+            self.logger.error(f"CSV export failed: {str(e)}")
+            raise
+    
+    def get_coordinates_for_results(self, results_data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Fetch the coordinate information for the result records
+        
+        @param results_data: calculation result data
+        @returns: list of results with coordinates attached
+        """
+        try:
+            results = results_data.get("results", [])
+            if not results:
+                return []
+
+            # Collect (farmland_id, sample_id) key pairs to avoid N separate database queries
+            farmland_sample_pairs = [(r["farmland_id"], r["sample_id"]) for r in results]
+
+            with SessionLocal() as db:
+                # Query in chunks by farmland_id to avoid the compatibility and parameter-count issues of composite IN clauses
+                wanted_pairs = set(farmland_sample_pairs)
+                unique_farmland_ids = sorted({fid for fid, _ in wanted_pairs})
+
+                def chunk_list(items: List[int], chunk_size: int = 500) -> List[List[int]]:
+                    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
+
+                rows: List[FarmlandData] = []
+                for id_chunk in chunk_list(unique_farmland_ids, 500):
+                    rows.extend(
+                        db.query(FarmlandData)
+                          .filter(FarmlandData.farmland_id.in_(id_chunk))
+                          .all()
+                    )
+
+                pair_to_farmland = {
+                    (row.farmland_id, row.sample_id): row for row in rows
+                }
+
+                coordinates_results: List[Dict[str, Any]] = []
+                for r in results:
+                    key = (r["farmland_id"], r["sample_id"])
+                    farmland = pair_to_farmland.get(key)
+                    if farmland is None:
+                        continue
+
+                    coord_result = {
+                        "farmland_id": r["farmland_id"],
+                        "sample_id": r["sample_id"],
+                        "longitude": farmland.lon,
+                        "latitude": farmland.lan,
+                        "flux_value": r.get("grain_removal_flux") or r.get("straw_removal_flux")
+                    }
+                    coord_result.update(r)
+                    coordinates_results.append(coord_result)
+
+                self.logger.info(f"✓ 成功获取 {len(coordinates_results)} 个样点的坐标信息(分片批量查询)")
+                return coordinates_results
+                
+        except Exception as e:
+            self.logger.error(f"获取坐标信息失败: {str(e)}")
+            raise
+    
+    def create_flux_visualization(self, area: str, calculation_type: str,
+                                results_with_coords: List[Dict[str, Any]],
+                                level: str = None,
+                                output_dir: str = "app/static/cd_flux",
+                                template_raster: str = "app/static/cd_flux/meanTemp.tif",
+                                boundary_shp: str = None,
+                                colormap: str = "green_yellow_red_purple",
+                                resolution_factor: float = 4.0,
+                                enable_interpolation: bool = False,
+                                cleanup_intermediate: bool = True) -> Dict[str, str]:
+        """
+        Create Cd flux removal visualization outputs
+        
+        @param area: area name
+        @param calculation_type: calculation type (grain_removal or straw_removal)
+        @param results_with_coords: result data with coordinates
+        @param level: administrative level (county | city | province)
+        @param output_dir: output directory
+        @param template_raster: path to the template raster file
+        @param boundary_shp: path to the boundary shapefile
+        @param colormap: color scheme
+        @param resolution_factor: resolution factor
+        @param enable_interpolation: whether to enable spatial interpolation
+        @param cleanup_intermediate: whether to clean up intermediate files
+        @returns: dict of generated image file paths
+        """
+        try:
+            if not results_with_coords:
+                raise ValueError("No result data with coordinates")
+            
+            # Ensure the output directory exists
+            os.makedirs(output_dir, exist_ok=True)
+            generated_files: Dict[str, str] = {}
+            
+            # Generate a timestamp
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            
+            # Create a CSV file used for plotting
+            csv_filename = f"{calculation_type}_{area}_temp_{timestamp}.csv"
+            csv_path = os.path.join(output_dir, csv_filename)
+            
+            # Prepare the plotting data
+            plot_data = []
+            for result in results_with_coords:
+                plot_data.append({
+                    "longitude": result["longitude"],
+                    "latitude": result["latitude"],
+                    "flux_value": result["flux_value"]
+                })
+            
+            # Save as CSV
+            df = pd.DataFrame(plot_data)
+            df.to_csv(csv_path, index=False, encoding='utf-8-sig')
+            
+            # Initialize the mapping utility
+            mapper = MappingUtils()
+            
+            # Build the output file paths
+            map_output = os.path.join(output_dir, f"{calculation_type}_{area}_map_{timestamp}")
+            histogram_output = os.path.join(output_dir, f"{calculation_type}_{area}_histogram_{timestamp}")
+            
+            # Check that the template raster exists
+            if not os.path.exists(template_raster):
+                self.logger.warning(f"Template raster not found: {template_raster}")
+                template_raster = None
+            
+            # Dynamically fetch the boundary file (strictly at the requested level)
+            if not level:
+                raise ValueError("An administrative level is required: county | city | province")
+            boundary_shp = self._get_boundary_file_for_area(area, level)
+            if not boundary_shp:
+                self.logger.warning(f"No boundary file found for area '{area}'; boundary clipping will be skipped")
+            else:
+                # Before plotting, check how many sample points fall inside the boundary
+                try:
+                    boundary_gdf = gpd.read_file(boundary_shp)
+                    if boundary_gdf is not None and len(boundary_gdf) > 0:
+                        boundary_union = boundary_gdf.unary_union
+                        total_points = len(results_with_coords)
+                        inside_count = 0
+                        outside_points: List[Dict[str, Any]] = []
+                        for r in results_with_coords:
+                            pt = Point(float(r["longitude"]), float(r["latitude"]))
+                            if boundary_union.contains(pt) or boundary_union.touches(pt):
+                                inside_count += 1
+                            else:
+                                outside_points.append({
+                                    "farmland_id": r.get("farmland_id"),
+                                    "sample_id": r.get("sample_id"),
+                                    "longitude": r.get("longitude"),
+                                    "latitude": r.get("latitude"),
+                                    "flux_value": r.get("flux_value")
+                                })
+
+                        outside_count = total_points - inside_count
+                        inside_pct = (inside_count / total_points * 100.0) if total_points > 0 else 0.0
+
+                        self.logger.info(
+                            f"样点边界检查 - 总数: {total_points}, 边界内: {inside_count} ({inside_pct:.2f}%), 边界外: {outside_count}")
+                        if outside_count > 0:
+                            sample_preview = outside_points[:10]
+                            self.logger.warning(
+                                f"存在 {outside_count} 个样点位于边界之外,绘图时将被掩膜隐藏。示例(最多10条): {sample_preview}")
+
+                        report = {
+                            "area": area,
+                            "level": level,
+                            "calculation_type": calculation_type,
+                            "total_points": total_points,
+                            "inside_points": inside_count,
+                            "outside_points": outside_count,
+                            "inside_percentage": round(inside_pct, 2),
+                            "outside_samples": outside_points
+                        }
+                        os.makedirs(output_dir, exist_ok=True)
+                        report_path = os.path.join(
+                            output_dir,
+                            f"{calculation_type}_{area}_points_boundary_check_{timestamp}.json"
+                        )
+                        with open(report_path, "w", encoding="utf-8") as f:
+                            json.dump(report, f, ensure_ascii=False, indent=2)
+                        generated_files["point_boundary_report"] = report_path
+                    else:
+                        generated_files = {}
+                except Exception as check_err:
+                    self.logger.warning(f"样点边界包含性检查失败: {str(check_err)}")
+                    # 保持已有 generated_files,不覆盖
+            
+            # Create the point shapefile
+            shapefile_path = csv_path.replace('.csv', '_points.shp')
+            mapper.csv_to_shapefile(csv_path, shapefile_path, 
+                                  lon_col='longitude', lat_col='latitude', value_col='flux_value')
+            
+            # Merge into the map of generated files
+            generated_files.update({"csv": csv_path, "shapefile": shapefile_path})
+            
+            # If a template raster is available, create the raster map
+            if template_raster:
+                try:
+                    # Create the raster
+                    raster_path = csv_path.replace('.csv', '_raster.tif')
+                    raster_path, stats = mapper.vector_to_raster(
+                        shapefile_path, template_raster, raster_path, 'flux_value',
+                        resolution_factor=resolution_factor, boundary_shp=boundary_shp,
+                        interpolation_method='nearest', enable_interpolation=enable_interpolation
+                    )
+                    generated_files["raster"] = raster_path
+                    
+                    # Create the raster map - use English titles to avoid garbled CJK text
+                    title_mapping = {
+                        "grain_removal": "Grain Removal Cd Flux",
+                        "straw_removal": "Straw Removal Cd Flux"
+                    }
+                    map_title = title_mapping.get(calculation_type, "Cd Flux Removal")
+                    
+                    map_file = mapper.create_raster_map(
+                        boundary_shp if boundary_shp else None,
+                        raster_path,
+                        map_output,
+                        colormap=colormap,
+                        title=map_title,
+                        output_size=12,
+                        dpi=300,
+                        resolution_factor=4.0,
+                        enable_interpolation=False,
+                        interpolation_method='nearest'
+                    )
+                    generated_files["map"] = map_file
+                    
+                    # Create the histogram - use English titles to avoid garbled CJK text
+                    histogram_title_mapping = {
+                        "grain_removal": "Grain Removal Cd Flux Distribution",
+                        "straw_removal": "Straw Removal Cd Flux Distribution"
+                    }
+                    histogram_title = histogram_title_mapping.get(calculation_type, "Cd Flux Distribution")
+                    
+                    histogram_file = mapper.create_histogram(
+                        raster_path,
+                        f"{histogram_output}.jpg",
+                        title=histogram_title,
+                        xlabel='Cd Flux (g/ha/a)',
+                        ylabel='Frequency Density'
+                    )
+                    generated_files["histogram"] = histogram_file
+                    
+                except Exception as viz_error:
+                    self.logger.warning(f"栅格可视化创建失败: {str(viz_error)}")
+                    # 即使栅格可视化失败,也返回已生成的文件
+            
+            # 清理中间文件(默认开启,仅保留最终可视化)
+            if cleanup_intermediate:
+                try:
+                    self._cleanup_intermediate_files(generated_files, boundary_shp)
+                except Exception as cleanup_err:
+                    self.logger.warning(f"中间文件清理失败: {str(cleanup_err)}")
+
+            self.logger.info(f"✓ 成功创建 {calculation_type} 可视化,生成文件: {list(generated_files.keys())}")
+            return generated_files
+            
+        except Exception as e:
+            self.logger.error(f"创建可视化失败: {str(e)}")
+            raise
+
+    def _cleanup_intermediate_files(self, generated_files: Dict[str, str], boundary_shp: Optional[str]) -> None:
+        """
+        Clean up intermediate files: the CSV, the shapefile with its sidecar files, and the raster TIFF; if the boundary came from a temp directory, delete that directory as well
+        """
+        import shutil
+        import tempfile
+
+        def _safe_remove(path: str) -> None:
+            try:
+                if path and os.path.exists(path) and os.path.isfile(path):
+                    os.remove(path)
+            except Exception:
+                pass
+
+        # Remove the CSV
+        _safe_remove(generated_files.get("csv"))
+
+        # Remove the raster
+        _safe_remove(generated_files.get("raster"))
+
+        # Remove the shapefile and all of its sidecar files
+        shp_path = generated_files.get("shapefile")
+        if shp_path:
+            base, _ = os.path.splitext(shp_path)
+            for ext in (".shp", ".shx", ".dbf", ".prj", ".cpg"):
+                _safe_remove(base + ext)
+
+        # If the boundary file came from the system temp directory, remove its directory
+        if boundary_shp:
+            temp_root = tempfile.gettempdir()
+            try:
+                if os.path.commonprefix([os.path.abspath(boundary_shp), temp_root]) == temp_root:
+                    temp_dir = os.path.dirname(os.path.abspath(boundary_shp))
+                    if os.path.isdir(temp_dir):
+                        shutil.rmtree(temp_dir, ignore_errors=True)
+            except Exception:
+                pass
+    
+    def _get_boundary_file_for_area(self, area: str, level: str) -> Optional[str]:
+        """
+        Get the boundary file for the given area
+        
+        @param area: area name
+        @param level: administrative level (county | city | province)
+        @returns: boundary file path, or None
+        """
+        try:
+            # Fetch the boundary strictly from the database (exact match at the requested level)
+            boundary_path = self._create_boundary_from_database(area, level)
+            if boundary_path:
+                return boundary_path
+            
+            # If nothing was found, log a warning but do not fall back to a default file
+            self.logger.warning(f"No dedicated boundary file found for area '{area}', and none could be fetched from the database")
+            return None
+            
+        except Exception as e:
+            self.logger.error(f"Failed to get the boundary file: {str(e)}")
+            return None
+    
+    def _create_boundary_from_database(self, area: str, level: str) -> Optional[str]:
+        """
+        Fetch boundary data from the database and create a temporary shapefile
+        
+        @param area: area name
+        @param level: administrative level (county | city | province)
+        @returns: temporary boundary file path, or None
+        """
+        try:
+            with SessionLocal() as db:
+                norm_area = area.strip()
+                boundary_geojson = get_boundary_geojson_by_name(db, norm_area, level=level)
+                if boundary_geojson:
+                    geometry_obj = shape(boundary_geojson["geometry"])
+                    gdf = gpd.GeoDataFrame([boundary_geojson["properties"]], geometry=[geometry_obj], crs="EPSG:4326")
+                    temp_dir = tempfile.mkdtemp()
+                    boundary_path = os.path.join(temp_dir, f"{norm_area}_boundary.shp")
+                    gdf.to_file(boundary_path, driver="ESRI Shapefile")
+                    self.logger.info(f"从数据库创建边界文件: {boundary_path}")
+                    return boundary_path
+                    
+        except Exception as e:
+            self.logger.warning(f"从数据库创建边界文件失败: {str(e)}")
+            
+        return None
+    
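Taken together, the new service methods form a compute → attach-coordinates → visualize pipeline. A minimal sketch of driving it directly, assuming a configured database and an area string that exactly equals a `Parameters.area` row (strict equality, per the comments above; the area/level values are illustrative):

```python
# Direct service usage; the area and level values are illustrative assumptions.
from app.services.cd_flux_removal_service import CdFluxRemovalService

service = CdFluxRemovalService()

# Strict lookup: the area string must equal Parameters.area exactly.
result = service.calculate_grain_removal_by_area("韶关", level="city")
if result["success"]:
    rows = service.get_coordinates_for_results(result["data"])
    files = service.create_flux_visualization(
        area="韶关",
        level="city",
        calculation_type="grain_removal",
        results_with_coords=rows,
    )
    print(files.get("map"), files.get("histogram"))
```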
  

+ 100 - 3
app/utils/mapping_utils.py

@@ -12,10 +12,12 @@ import pandas as pd
 import geopandas as gpd
 from shapely.geometry import Point
 import rasterio
-from rasterio.features import rasterize
+from rasterio.features import rasterize, geometry_mask
 from rasterio.transform import from_origin
 from rasterio.mask import mask
 from rasterio.plot import show
+from rasterio.warp import transform_bounds, reproject
+from rasterio.enums import Resampling
 import numpy as np
 import os
 import json
@@ -259,7 +261,8 @@ class MappingUtils:
                 dtype='float32'
             )
             
-            # Apply the boundary mask (if provided)
+            # Prepare the mask: prefer the administrative boundary; if none is provided, limit the drawing extent with the convex hull of the points
+            boundary_mask = None
             if boundary_shp and os.path.exists(boundary_shp):
                 self.logger.info(f"Applying boundary mask: {boundary_shp}")
                 boundary_gdf = gpd.read_file(boundary_shp)
@@ -267,11 +270,24 @@ class MappingUtils:
                     boundary_gdf = boundary_gdf.to_crs(crs)
                 boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
                 raster[~boundary_mask] = np.nan
+            else:
+                try:
+                    # Use the convex hull of the points as the default mask to avoid coloring outside the boundary
+                    hull = gdf.unary_union.convex_hull
+                    hull_gdf = gpd.GeoDataFrame(geometry=[hull], crs=crs)
+                    boundary_mask = self.create_boundary_mask(raster, transform, hull_gdf)
+                    raster[~boundary_mask] = np.nan
+                    self.logger.info("已使用点集凸包限制绘制范围")
+                except Exception as hull_err:
+                    self.logger.warning(f"生成点集凸包掩膜失败,可能会出现边界外着色: {str(hull_err)}")
             
             # Fill NaN values by interpolation (if enabled)
             if enable_interpolation and np.isnan(raster).any():
                 self.logger.info(f"Interpolating with the {interpolation_method} method...")
                 raster = self.interpolate_nan_values(raster, method=interpolation_method)
+                # Key fix: re-apply the mask after interpolation so cells outside the boundary are not filled in
+                if boundary_mask is not None:
+                    raster[~boundary_mask] = np.nan
             elif not enable_interpolation and np.isnan(raster).any():
                 self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")
             
@@ -373,7 +389,13 @@ class MappingUtils:
             # Resample by the resolution factor
             if resolution_factor != 1.0:
                 self.logger.info(f"Resampling by resolution factor: {resolution_factor}")
-                raster, out_transform = self._resample_raster(raster, out_transform, resolution_factor, original_crs)
+                raster, out_transform = self._resample_raster(
+                    raster=raster,
+                    transform=out_transform,
+                    resolution_factor=resolution_factor,
+                    crs=original_crs,
+                    resampling='nearest'
+                )
 
             # Apply spatial interpolation (if enabled)
             if enable_interpolation and np.isnan(raster).any():
@@ -408,6 +430,23 @@ class MappingUtils:
 
             # Plot
             fig, ax = plt.subplots(figsize=fig_size)
+            
+            # If a boundary is available, additionally mask the area outside it; otherwise keep the raster's valid extent
+            if gdf is not None:
+                try:
+                    height, width = raster.shape
+                    transform = out_transform
+                    geom_mask = geometry_mask(
+                        [json.loads(gdf.to_json())["features"][0]["geometry"]], 
+                        out_shape=(height, width), 
+                        transform=transform,
+                        invert=True
+                    )
+                    raster = np.where(geom_mask, raster, np.nan)
+                except Exception as mask_err:
+                    self.logger.warning(f"边界掩膜应用失败,将继续绘制已裁剪栅格: {str(mask_err)}")
+            
+            # 显示栅格数据
             show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
 
             # Add the vector boundary
@@ -449,6 +488,64 @@ class MappingUtils:
         except Exception as e:
             self.logger.error(f"栅格地图创建失败: {str(e)}")
             raise
+
+    def _resample_raster(self, raster, transform, resolution_factor: float, crs, resampling: str = 'nearest'):
+        """
+        Resample a 2D raster by a resolution factor and return the new raster with the updated affine transform
+        
+        @param raster: 2D numpy array
+        @param transform: affine transform of the input raster
+        @param resolution_factor: resolution factor (>1 increases pixel density)
+        @param crs: coordinate reference system
+        @param resampling: resampling method ('nearest' | 'bilinear' | 'cubic')
+        @return: (resampled_raster, new_transform)
+        """
+        try:
+            if resolution_factor == 1.0:
+                return raster, transform
+
+            rows, cols = raster.shape
+            new_rows = max(1, int(rows * resolution_factor))
+            new_cols = max(1, int(cols * resolution_factor))
+
+            # Update the transform (pixels become smaller)
+            new_transform = rasterio.Affine(
+                transform.a / resolution_factor,
+                transform.b,
+                transform.c,
+                transform.d,
+                transform.e / resolution_factor,
+                transform.f
+            )
+
+            # Choose the resampling algorithm
+            resampling_map = {
+                'nearest': Resampling.nearest,
+                'bilinear': Resampling.bilinear,
+                'cubic': Resampling.cubic,
+            }
+            resampling_enum = resampling_map.get(resampling, Resampling.nearest)
+
+            destination = np.full((new_rows, new_cols), np.nan, dtype='float32')
+
+            # Resample via reproject (the CRS is unchanged; only the resolution changes)
+            reproject(
+                source=raster,
+                destination=destination,
+                src_transform=transform,
+                src_crs=crs,
+                dst_transform=new_transform,
+                dst_crs=crs,
+                src_nodata=np.nan,
+                dst_nodata=np.nan,
+                resampling=resampling_enum
+            )
+
+            return destination, new_transform
+        except Exception as e:
+            self.logger.error(f"重采样失败: {str(e)}")
+            # 失败则返回原始数据,避免中断
+            return raster, transform
     
     def create_histogram(self, file_path, save_path=None, figsize=(10, 6),
                         xlabel='Pixel value', ylabel='Frequency density', title='Value distribution',
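The new `_resample_raster` keeps the raster's extent and CRS fixed and changes only the grid density by scaling the affine transform. A standalone sketch of the same idea on synthetic data (the grid, transform, and CRS below are illustrative, not taken from the project):

```python
# Standalone sketch of affine-scaled resampling via rasterio.warp.reproject;
# the grid, transform, and CRS below are synthetic examples.
import numpy as np
import rasterio
from rasterio.warp import reproject
from rasterio.enums import Resampling

src = np.arange(16, dtype='float32').reshape(4, 4)
src_transform = rasterio.Affine(0.25, 0.0, 113.0, 0.0, -0.25, 25.0)
factor = 2.0  # >1 shrinks the pixels, i.e. densifies the grid

dst = np.full((int(4 * factor), int(4 * factor)), np.nan, dtype='float32')
dst_transform = rasterio.Affine(
    src_transform.a / factor, 0.0, src_transform.c,
    0.0, src_transform.e / factor, src_transform.f,
)

reproject(
    source=src, destination=dst,
    src_transform=src_transform, src_crs="EPSG:4326",
    dst_transform=dst_transform, dst_crs="EPSG:4326",
    src_nodata=np.nan, dst_nodata=np.nan,
    resampling=Resampling.nearest,
)
print(dst.shape)  # (8, 8): same extent, four times as many cells
```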

+ 56 - 0
scripts/demos/db_backup_tool.bat

@@ -0,0 +1,56 @@
+@echo off
+setlocal enabledelayedexpansion
+
+:: Set database configuration
+set DB_NAME=soilgd
+set DB_USER=postgres
+set BACKUP_DIR=backups
+
+:: Generate timestamp (format: YYYYMMDD_HHMMSS)
+:: Use simple PowerShell command to get timestamp
+for /f %%i in ('powershell -command "Get-Date -Format 'yyyyMMdd_HHmmss'"') do set TIMESTAMP=%%i
+
+echo =====================================
+echo   PostgreSQL Database Backup Tool
+echo =====================================
+echo.
+
+:: Create backup directory if not exists
+if not exist "%BACKUP_DIR%" (
+    echo Creating backup directory: %BACKUP_DIR%
+    mkdir "%BACKUP_DIR%"
+)
+
+:: Set backup filename
+set BACKUP_FILE=%BACKUP_DIR%\%DB_NAME%_backup_%TIMESTAMP%.dump
+
+echo Starting database backup: %DB_NAME%
+echo Backup file: %BACKUP_FILE%
+echo.
+
+:: Execute backup command
+echo Executing backup...
+pg_dump -U %DB_USER% -Fc %DB_NAME% > "%BACKUP_FILE%"
+
+:: Check if backup was successful
+if %ERRORLEVEL% EQU 0 (
+    echo.
+    echo Backup completed successfully!
+    echo Backup file saved to: %BACKUP_FILE%
+    
+    :: Display file size
+    for %%A in ("%BACKUP_FILE%") do (
+        echo Backup file size: %%~zA bytes
+    )
+) else (
+    echo.
+    echo Backup failed! Please check:
+    echo   1. Is PostgreSQL service running?
+    echo   2. Does database %DB_NAME% exist?
+    echo   3. Does user %DB_USER% have permissions?
+    echo   4. Is pg_dump command in system PATH?
+)
+
+echo.
+echo Backup operation completed.
+pause

+ 107 - 0
scripts/demos/db_restore_tool.bat

@@ -0,0 +1,107 @@
+@echo off
+setlocal enabledelayedexpansion
+
+:: Set database configuration
+set DB_NAME=soilgd
+set DB_USER=postgres
+
+:: Get project root directory
+set SCRIPT_DIR=%~dp0
+for %%i in ("%SCRIPT_DIR%..\..\") do set PROJECT_ROOT=%%~fi
+set BACKUP_FILE=%PROJECT_ROOT%soilgd.dump
+
+echo =====================================
+echo   PostgreSQL Database Restore Tool
+echo =====================================
+echo.
+
+:: Check if backup file exists
+if not exist "%BACKUP_FILE%" (
+    echo Error: Backup file not found: %BACKUP_FILE%
+    echo Please ensure soilgd.dump exists in project root directory.
+    echo Project root: %PROJECT_ROOT%
+    pause
+    exit /b 1
+)
+
+echo Will restore database from backup file:
+echo Backup file: %BACKUP_FILE%
+echo.
+
+set FULL_BACKUP_PATH=%BACKUP_FILE%
+
+echo =====================================
+echo WARNING: This will delete existing database and recreate it!
+echo Database: %DB_NAME%
+echo Backup file: %FULL_BACKUP_PATH%
+echo =====================================
+echo.
+echo Continue? (Y/N)
+set /p CONFIRM=
+if /i not "%CONFIRM%"=="Y" (
+    echo Operation cancelled.
+    pause
+    exit /b 0
+)
+
+echo.
+echo Starting restore process...
+
+:: Step 1: Terminate active connections and drop existing database
+echo 1. Terminating active connections to database %DB_NAME%...
+psql -U %DB_USER% -d postgres -c "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '%DB_NAME%' AND pid <> pg_backend_pid();" 2>nul
+
+echo 2. Dropping existing database %DB_NAME%...
+dropdb -U %DB_USER% %DB_NAME% 2>nul
+if %ERRORLEVEL% EQU 0 (
+    echo   Database dropped successfully
+) else (
+    echo   Database may not exist or cannot be dropped (this is usually normal)
+)
+
+:: Step 2: Create new empty database
+echo 3. Creating new database %DB_NAME%...
+createdb -U %DB_USER% %DB_NAME% 2>nul
+if %ERRORLEVEL% NEQ 0 (
+    echo   Database creation failed, attempting to force drop and recreate...
+    
+    :: Force drop database with more aggressive approach
+    psql -U %DB_USER% -d postgres -c "DROP DATABASE IF EXISTS %DB_NAME%;" 2>nul
+    
+    :: Try creating again
+    createdb -U %DB_USER% %DB_NAME%
+    if %ERRORLEVEL% NEQ 0 (
+        echo   Database creation still failed!
+        echo   Possible causes:
+        echo     1. PostgreSQL service not running
+        echo     2. Insufficient user permissions
+        echo     3. Database is in use by other processes
+        echo   Please check and retry.
+        pause
+        exit /b 1
+    )
+)
+echo   Database created successfully
+
+:: Step 3: Restore data
+echo 4. Restoring data from backup file...
+pg_restore -U %DB_USER% -d %DB_NAME% "%FULL_BACKUP_PATH%"
+if %ERRORLEVEL% EQU 0 (
+    echo   Data restored successfully!
+) else (
+    echo   Data restore failed! Please check:
+    echo     1. Is backup file corrupted?
+    echo     2. Is PostgreSQL service running?
+    echo     3. Are user permissions correct?
+    pause
+    exit /b 1
+)
+
+echo.
+echo =====================================
+echo Database restore completed successfully!
+echo Database: %DB_NAME%
+echo Backup source: %FULL_BACKUP_PATH%
+echo =====================================
+echo.
+pause

BIN
soilgd.dump