Просмотр исходного кода

优化日志配置,避免重复设置;在边界服务中添加获取GeoDataFrame的方法,减少临时文件创建;更新相关服务逻辑以直接使用GeoDataFrame,提升性能和可维护性。

drggboy 6 дней назад
Родитель
Коммит
18b3bc3a39

+ 3 - 2
app/database.py

@@ -5,8 +5,9 @@ import os
 from dotenv import load_dotenv # type: ignore
 import logging
 from sqlalchemy.exc import SQLAlchemyError
-# 配置日志系统
-logging.basicConfig(level=logging.INFO)
+# 配置日志系统 - 检查是否已经配置过,避免重复配置
+if not logging.getLogger().handlers:
+    logging.basicConfig(level=logging.INFO)
 # 关闭SQLAlchemy的详细SQL执行日志,只保留错误日志
 logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
 logger = logging.getLogger(__name__)

+ 3 - 2
app/main.py

@@ -6,8 +6,9 @@ import logging
 import sys
 import os
 
-# 设置日志
-logging.basicConfig(level=logging.INFO)
+# 设置日志 - 检查是否已经配置过,避免重复配置
+if not logging.getLogger().handlers:
+    logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 def safe_create_tables():

+ 44 - 0
app/services/admin_boundary_service.py

@@ -1,6 +1,9 @@
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import text
 import json
+import geopandas as gpd
+from shapely.geometry import shape
+from typing import Optional
 
 
 def get_boundary_geojson_by_name(db: Session, name: str, level: str = "auto") -> dict:
@@ -112,3 +115,44 @@ def get_boundary_geojson_by_name(db: Session, name: str, level: str = "auto") ->
     raise ValueError(f"未找到名称: {name}")
 
 
+def get_boundary_gdf_by_name(db: Session, name: str, level: str = "auto") -> Optional[gpd.GeoDataFrame]:
+    """Return the boundary for ``name`` as a single-row GeoDataFrame, without temporary files.
+    
+    Wraps get_boundary_geojson_by_name and converts its GeoJSON feature in memory, so callers
+    that need the boundary for spatial operations do not have to write a temporary Shapefile.
+
+    Args:
+        db (Session): Database session.
+        name (str): Region name (county/city/province).
+        level (str): Administrative level, one of "county"|"city"|"province"|"auto".
+
+    Returns:
+        Optional[gpd.GeoDataFrame]: Boundary GeoDataFrame tagged EPSG:4326, or None if lookup or conversion fails.
+        
+    Example:
+        ```python
+        with SessionLocal() as db:
+            boundary_gdf = get_boundary_gdf_by_name(db, "乐昌市", "county")
+            if boundary_gdf is not None:
+                # Use the GeoDataFrame directly for spatial operations
+                boundary_union = boundary_gdf.unary_union
+        ```
+    """
+    try:
+        boundary_geojson = get_boundary_geojson_by_name(db, name, level)
+        if boundary_geojson:
+            geometry_obj = shape(boundary_geojson["geometry"])
+            gdf = gpd.GeoDataFrame([boundary_geojson["properties"]], 
+                                 geometry=[geometry_obj], 
+                                 crs="EPSG:4326")
+            return gdf
+    except ValueError:
+        # No boundary matched the name (get_boundary_geojson_by_name raises ValueError for that)
+        pass
+    except Exception:
+        # NOTE(review): every other error is also swallowed without logging and reported as None
+        pass
+    
+    return None
+
+

+ 37 - 9
app/services/cd_flux_removal_service.py

@@ -18,7 +18,7 @@ from ..models.parameters import Parameters
 from ..models.CropCd_output import CropCdOutputData
 from ..models.farmland import FarmlandData
 from ..utils.mapping_utils import MappingUtils
-from .admin_boundary_service import get_boundary_geojson_by_name
+from .admin_boundary_service import get_boundary_geojson_by_name, get_boundary_gdf_by_name
 import geopandas as gpd
 from shapely.geometry import shape, Point
 import tempfile
@@ -370,16 +370,20 @@ class CdFluxRemovalService:
                 self.logger.warning(f"模板栅格文件不存在: {template_raster}")
                 template_raster = None
             
-            # 动态获取边界文件(严格使用指定层级)
+            # 动态获取边界数据(严格使用指定层级)
             if not level:
                 raise ValueError("必须提供行政层级 level:county | city | province")
-            boundary_shp = self._get_boundary_file_for_area(area, level)
-            if not boundary_shp:
-                self.logger.warning(f"未找到地区 '{area}' 的边界文件,将不使用边界裁剪")
+            
+            # 优化:直接从数据库获取边界GeoDataFrame,避免创建临时shapefile文件
+            # 这样可以减少磁盘I/O操作和临时文件管理的开销
+            boundary_gdf = self._get_boundary_gdf_from_database(area, level)
+            boundary_shp = None  # 不再需要临时边界文件
+            
+            if boundary_gdf is None:
+                self.logger.warning(f"未找到地区 '{area}' 的边界数据,将不使用边界裁剪")
             else:
                 # 在绘图前进行样点边界包含性统计
                 try:
-                    boundary_gdf = gpd.read_file(boundary_shp)
                     if boundary_gdf is not None and len(boundary_gdf) > 0:
                         boundary_union = boundary_gdf.unary_union
                         total_points = len(results_with_coords)
@@ -438,7 +442,7 @@ class CdFluxRemovalService:
                     raster_path = csv_path.replace('.csv', '_raster.tif')
                     raster_path, stats = mapper.vector_to_raster(
                         shapefile_path, template_raster, raster_path, 'flux_value',
-                        resolution_factor=resolution_factor, boundary_shp=boundary_shp,
+                        resolution_factor=resolution_factor, boundary_shp=boundary_shp, boundary_gdf=boundary_gdf,
                         interpolation_method='nearest', enable_interpolation=enable_interpolation
                     )
                     generated_files["raster"] = raster_path
@@ -460,7 +464,8 @@ class CdFluxRemovalService:
                         dpi=300,
                         resolution_factor=4.0,
                         enable_interpolation=False,
-                        interpolation_method='nearest'
+                        interpolation_method='nearest',
+                        boundary_gdf=boundary_gdf
                     )
                     generated_files["map"] = map_file
                     
@@ -487,7 +492,8 @@ class CdFluxRemovalService:
             # 清理中间文件(默认开启,仅保留最终可视化)
             if cleanup_intermediate:
                 try:
-                    self._cleanup_intermediate_files(generated_files, boundary_shp)
+                    # 由于不再创建临时边界文件,所以传递None
+                    self._cleanup_intermediate_files(generated_files, None)
                 except Exception as cleanup_err:
                     self.logger.warning(f"中间文件清理失败: {str(cleanup_err)}")
 
@@ -557,12 +563,34 @@ class CdFluxRemovalService:
             self.logger.error(f"获取边界文件失败: {str(e)}")
             return None
     
+    def _get_boundary_gdf_from_database(self, area: str, level: str) -> Optional[gpd.GeoDataFrame]:
+        """
+        Fetch boundary data straight from the database as a GeoDataFrame.
+        
+        @param area: Region name (surrounding whitespace is stripped before the lookup)
+        @param level: Administrative level
+        @returns: Boundary GeoDataFrame, or None if nothing was found or the lookup raised
+        """
+        try:
+            with SessionLocal() as db:
+                norm_area = area.strip()
+                boundary_gdf = get_boundary_gdf_by_name(db, norm_area, level=level)
+                if boundary_gdf is not None:
+                    self.logger.info(f"从数据库获取边界数据: {norm_area} ({level})")
+                return boundary_gdf
+                    
+        except Exception as e:
+            self.logger.warning(f"从数据库获取边界数据失败: {str(e)}")
+            
+        return None
+
     def _create_boundary_from_database(self, area: str, level: str) -> Optional[str]:
         """
         从数据库获取边界数据并创建临时shapefile
         
         @param area: 地区名称
         @returns: 临时边界文件路径或None
+        @deprecated: 建议使用 _get_boundary_gdf_from_database 方法直接获取 GeoDataFrame
         """
         try:
             with SessionLocal() as db:

+ 51 - 18
app/utils/mapping_utils.py

@@ -65,11 +65,14 @@ class MappingUtils:
         self.logger = logging.getLogger(self.__class__.__name__)
         self.logger.setLevel(log_level)
         
+        # 避免重复添加处理器,并防止日志传播到父级处理器导致重复输出
         if not self.logger.handlers:
             handler = logging.StreamHandler()
             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
             handler.setFormatter(formatter)
             self.logger.addHandler(handler)
+            # 关闭日志传播,避免与全局basicConfig冲突
+            self.logger.propagate = False
     
     def csv_to_shapefile(self, csv_file, shapefile_output, lon_col=0, lat_col=1, value_col=2):
         """
@@ -200,7 +203,7 @@ class MappingUtils:
             return raster
     
     def vector_to_raster(self, input_shapefile, template_tif, output_tif, field, 
-                        resolution_factor=16.0, boundary_shp=None, interpolation_method='nearest', enable_interpolation=True):
+                        resolution_factor=16.0, boundary_shp=None, boundary_gdf=None, interpolation_method='nearest', enable_interpolation=True):
         """
         将点矢量数据转换为栅格数据
         
@@ -209,14 +212,16 @@ class MappingUtils:
         @param output_tif: 输出栅格化后的GeoTIFF文件路径
         @param field: 用于栅格化的属性字段名
         @param resolution_factor: 分辨率倍数因子
-        @param boundary_shp: 边界Shapefile文件路径,用于创建掩膜
+        @param boundary_shp: 边界Shapefile文件路径,用于创建掩膜(兼容性保留)
+        @param boundary_gdf: 边界GeoDataFrame,优先使用此参数而非boundary_shp
         @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic')
         @param enable_interpolation: 是否启用空间插值,默认True
         @return: 输出的GeoTIFF文件路径和统计信息
         """
         try:
             self.logger.info(f"开始处理: {input_shapefile}")
-            self.logger.info(f"分辨率因子: {resolution_factor}, 插值方法: {interpolation_method}")
+            interpolation_status = "启用" if enable_interpolation else "禁用"
+            self.logger.info(f"分辨率因子: {resolution_factor}, 插值设置: {interpolation_status} (方法: {interpolation_method})")
             
             # 读取矢量数据
             gdf = gpd.read_file(input_shapefile)
@@ -263,13 +268,20 @@ class MappingUtils:
             
             # 预备掩膜:优先使用行政区边界;若未提供边界,则使用点集凸包限制绘制范围
             boundary_mask = None
-            if boundary_shp and os.path.exists(boundary_shp):
-                self.logger.info(f"应用边界掩膜: {boundary_shp}")
-                boundary_gdf = gpd.read_file(boundary_shp)
+            if boundary_gdf is not None:
+                self.logger.info("应用边界掩膜: 使用直接提供的GeoDataFrame")
+                # 确保边界GeoDataFrame的CRS与栅格一致
                 if boundary_gdf.crs != crs:
                     boundary_gdf = boundary_gdf.to_crs(crs)
                 boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
                 raster[~boundary_mask] = np.nan
+            elif boundary_shp and os.path.exists(boundary_shp):
+                self.logger.info(f"应用边界掩膜: {boundary_shp}")
+                boundary_gdf_from_file = gpd.read_file(boundary_shp)
+                if boundary_gdf_from_file.crs != crs:
+                    boundary_gdf_from_file = boundary_gdf_from_file.to_crs(crs)
+                boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf_from_file)
+                raster[~boundary_mask] = np.nan
             else:
                 try:
                     # 使用点集凸包作为默认掩膜,避免边界外着色
@@ -281,15 +293,26 @@ class MappingUtils:
                 except Exception as hull_err:
                     self.logger.warning(f"生成点集凸包掩膜失败,可能会出现边界外着色: {str(hull_err)}")
             
+            # 检查栅格数据状态并决定是否插值
+            nan_count = np.isnan(raster).sum()
+            total_pixels = raster.size
+            self.logger.info(f"栅格数据状态: 总像素数 {total_pixels}, NaN像素数 {nan_count} ({nan_count/total_pixels*100:.1f}%)")
+            
             # 使用插值方法填充NaN值(如果启用)
-            if enable_interpolation and np.isnan(raster).any():
-                self.logger.info(f"使用 {interpolation_method} 方法进行插值...")
+            if enable_interpolation and nan_count > 0:
+                self.logger.info(f"✓ 启用插值: 使用 {interpolation_method} 方法填充 {nan_count} 个NaN像素...")
                 raster = self.interpolate_nan_values(raster, method=interpolation_method)
                 # 关键修正:插值后再次应用掩膜,确保边界外不被填充
                 if boundary_mask is not None:
                     raster[~boundary_mask] = np.nan
-            elif not enable_interpolation and np.isnan(raster).any():
-                self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")
+                final_nan_count = np.isnan(raster).sum()
+                self.logger.info(f"插值完成: 剩余NaN像素数 {final_nan_count}")
+            elif enable_interpolation and nan_count == 0:
+                self.logger.info("✓ 插值已启用,但栅格数据无NaN值,无需插值")
+            elif not enable_interpolation and nan_count > 0:
+                self.logger.info(f"✗ 插值已禁用,保留 {nan_count} 个NaN像素")
+            else:
+                self.logger.info("✓ 栅格数据完整,无需插值")
             
             # 创建输出目录
             os.makedirs(os.path.dirname(output_tif), exist_ok=True)
@@ -336,11 +359,11 @@ class MappingUtils:
                          colormap='green_yellow_red_purple', title="Prediction Map",
                          output_size=12, figsize=None, dpi=300,
                          resolution_factor=1.0, enable_interpolation=True,
-                         interpolation_method='nearest'):
+                         interpolation_method='nearest', boundary_gdf=None):
         """
         创建栅格地图
         
-        @param shp_path: 输入的矢量数据路径
+        @param shp_path: 输入的矢量数据路径(兼容性保留)
         @param tif_path: 输入的栅格数据路径
         @param output_path: 输出图片路径(不包含扩展名)
         @param colormap: 色彩方案名称或颜色列表
@@ -351,14 +374,23 @@ class MappingUtils:
         @param resolution_factor: 分辨率因子,>1提高分辨率,<1降低分辨率
         @param enable_interpolation: 是否启用空间插值,用于处理NaN值或提高分辨率,默认True
         @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic')
-                @return: 输出图片文件路径
+        @param boundary_gdf: 边界GeoDataFrame(可选,优先使用)
+        @return: 输出图片文件路径
         """
         try:
             self.logger.info(f"开始创建栅格地图: {tif_path}")
             self.logger.info(f"分辨率因子: {resolution_factor}, 启用插值: {enable_interpolation}")
 
-            # 读取矢量边界
-            gdf = gpd.read_file(shp_path) if shp_path else None
+            # 读取矢量边界:优先使用boundary_gdf,否则从shp_path读取
+            if boundary_gdf is not None:
+                gdf = boundary_gdf
+                self.logger.info("使用直接提供的边界GeoDataFrame")
+            elif shp_path:
+                gdf = gpd.read_file(shp_path)
+                self.logger.info(f"从文件读取边界数据: {shp_path}")
+            else:
+                gdf = None
+                self.logger.info("未提供边界数据,将使用整个栅格范围")
 
             # 读取并裁剪栅格数据
             with rasterio.open(tif_path) as src:
@@ -628,7 +660,7 @@ def get_available_colormaps():
 
 
 def csv_to_raster_workflow(csv_file, template_tif, output_dir, 
-                          boundary_shp=None, resolution_factor=16.0,
+                          boundary_shp=None, boundary_gdf=None, resolution_factor=16.0,
                           interpolation_method='nearest', field_name='Prediction',
                           lon_col=0, lat_col=1, value_col=2, enable_interpolation=False):
     """
@@ -637,7 +669,8 @@ def csv_to_raster_workflow(csv_file, template_tif, output_dir,
     @param csv_file: CSV文件路径
     @param template_tif: 模板GeoTIFF文件路径
     @param output_dir: 输出目录
-    @param boundary_shp: 边界Shapefile文件路径(可选)
+    @param boundary_shp: 边界Shapefile文件路径(可选,兼容性保留)
+    @param boundary_gdf: 边界GeoDataFrame(可选,优先使用)
     @param resolution_factor: 分辨率因子
     @param interpolation_method: 插值方法
     @param field_name: 字段名称
@@ -663,7 +696,7 @@ def csv_to_raster_workflow(csv_file, template_tif, output_dir,
     # 2. Shapefile转栅格
     raster_path, stats = mapper.vector_to_raster(
         shapefile_path, template_tif, raster_path, field_name,
-        resolution_factor, boundary_shp, interpolation_method, enable_interpolation
+        resolution_factor, boundary_shp, boundary_gdf, interpolation_method, enable_interpolation
     )
     
     return {