Browse Source

优化土地数据处理逻辑,新增临时文件清理功能,支持高效的最近邻搜索算法,提升数据处理性能和资源管理。同时更新日志记录,确保输出文件存在性验证和错误处理更加完善。

drggboy 1 week ago
parent
commit
2fd21973ab
2 changed files with 172 additions and 23 deletions
  1. 4 2
      app/api/water.py
  2. 168 21
      app/services/water_service.py

+ 4 - 2
app/api/water.py

@@ -127,7 +127,8 @@ async def recalculate_land_data(
         enable_interpolation: Optional[bool] = Form(False, description="是否启用空间插值,默认启用"),
         enable_interpolation: Optional[bool] = Form(False, description="是否启用空间插值,默认启用"),
         interpolation_method: Optional[str] = Form("linear", description="插值方法: nearest | linear | cubic"),
         interpolation_method: Optional[str] = Form("linear", description="插值方法: nearest | linear | cubic"),
         resolution_factor: Optional[float] = Form(4.0, description="分辨率因子,默认4.0,越大分辨率越高"),
         resolution_factor: Optional[float] = Form(4.0, description="分辨率因子,默认4.0,越大分辨率越高"),
-        save_csv: Optional[bool] = Form(True, description="是否生成CSV文件,默认生成")
+        save_csv: Optional[bool] = Form(True, description="是否生成CSV文件,默认生成"),
+        cleanup_temp_files: Optional[bool] = Form(True, description="是否清理临时文件,默认清理")
 ) -> Dict[str, Any]:
 ) -> Dict[str, Any]:
     """重新计算土地数据并返回结果路径,支持动态边界控制和插值控制"""
     """重新计算土地数据并返回结果路径,支持动态边界控制和插值控制"""
     try:
     try:
@@ -161,7 +162,8 @@ async def recalculate_land_data(
             enable_interpolation=enable_interpolation,
             enable_interpolation=enable_interpolation,
             interpolation_method=interpolation_method,
             interpolation_method=interpolation_method,
             resolution_factor=resolution_factor,
             resolution_factor=resolution_factor,
-            save_csv=save_csv  # 将CSV生成选项传递给处理函数
+            save_csv=save_csv,  # 将CSV生成选项传递给处理函数
+            cleanup_temp_files=cleanup_temp_files  # 将清理选项传递给处理函数
         )
         )
 
 
         if not results:
         if not results:

+ 168 - 21
app/services/water_service.py

@@ -4,12 +4,14 @@ import pandas as pd
 from pyproj import Transformer
 from pyproj import Transformer
 from shapely.geometry import Point
 from shapely.geometry import Point
 import rasterio
 import rasterio
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
 from datetime import datetime
 from datetime import datetime
 import numpy as np
 import numpy as np
 import logging
 import logging
 import shutil
 import shutil
 import sys
 import sys
+from sklearn.neighbors import BallTree
+from time import time
 
 
 # 导入MappingUtils
 # 导入MappingUtils
 from ..utils.mapping_utils import MappingUtils, csv_to_raster_workflow, dataframe_to_raster_workflow
 from ..utils.mapping_utils import MappingUtils, csv_to_raster_workflow, dataframe_to_raster_workflow
@@ -86,7 +88,116 @@ def get_boundary_gdf_from_database(area: str, level: str) -> Optional[gpd.GeoDat
     return None
     return None
 
 
 
 
+def find_nearest_sampling_points_optimized(land_centers_df: pd.DataFrame, 
+                                          sampling_points_df: pd.DataFrame) -> np.ndarray:
+    """
+    使用BallTree高效计算每个土地中心点的最近采样点
+    
+    @description: 使用空间索引优化最近邻搜索,将O(n×m)复杂度降低到O(n×log(m))
+    
+    @param land_centers_df: 土地中心点数据,包含center_lon和center_lat列
+    @param sampling_points_df: 采样点数据,包含经度和纬度列
+    @returns: 每个土地中心点对应的最近采样点索引数组
+    """
+    logger.info("开始构建空间索引优化最近邻搜索...")
+    
+    start_time = time()
+    
+    # 1. 准备采样点坐标数据(转换为弧度用于BallTree)
+    sampling_coords = np.radians(sampling_points_df[['经度', '纬度']].values)
+    
+    # 2. 构建BallTree空间索引
+    logger.info(f"构建BallTree索引,采样点数量: {len(sampling_coords)}")
+    tree = BallTree(sampling_coords, metric='haversine')
+    
+    # 3. 准备土地中心点坐标数据
+    land_coords = np.radians(land_centers_df[['center_lon', 'center_lat']].values)
+    
+    # 4. 批量查询最近邻(k=1表示只找最近的一个点)
+    logger.info(f"批量查询最近邻,土地中心点数量: {len(land_coords)}")
+    distances, indices = tree.query(land_coords, k=1)
+    
+    # 5. 提取索引(indices是二维数组,我们只需要第一列)
+    nearest_indices = indices.flatten()
+    
+    elapsed_time = time() - start_time
+    logger.info(f"空间索引搜索完成,耗时: {elapsed_time:.2f}秒")
+    logger.info(f"平均每个点查询时间: {elapsed_time/len(land_coords)*1000:.2f}毫秒")
+    
+    return nearest_indices
 
 
+def cleanup_temporary_files(*file_paths):
+    """
+    清理临时文件
+    
+    @description: 安全地删除指定的临时文件,支持多种文件类型
+    @param file_paths: 要删除的文件路径(可变参数)
+    """
+    import tempfile
+    
+    for file_path in file_paths:
+        if not file_path:
+            continue
+            
+        try:
+            if os.path.exists(file_path) and os.path.isfile(file_path):
+                os.remove(file_path)
+                logger.info(f"已清理临时文件: {os.path.basename(file_path)}")
+                
+                # 如果是shapefile,也删除相关的配套文件
+                if file_path.endswith('.shp'):
+                    base_path = os.path.splitext(file_path)[0]
+                    for ext in ['.shx', '.dbf', '.prj', '.cpg']:
+                        related_file = base_path + ext
+                        if os.path.exists(related_file):
+                            os.remove(related_file)
+                            logger.info(f"已清理相关文件: {os.path.basename(related_file)}")
+                            
+        except Exception as e:
+            logger.warning(f"清理文件失败 {file_path}: {str(e)}")
+
+
+def cleanup_temp_files_in_directory(directory: str, patterns: List[str] = None) -> int:
+    """
+    清理指定目录下的临时文件
+    
+    @description: 根据文件名模式清理目录中的临时文件
+    @param directory: 要清理的目录路径
+    @param patterns: 文件名模式列表,默认为['memory_raster_', 'temp_', 'tmp_']
+    @returns: 清理的文件数量
+    """
+    if patterns is None:
+        patterns = ['memory_raster_', 'temp_', 'tmp_']
+    
+    if not os.path.exists(directory) or not os.path.isdir(directory):
+        logger.warning(f"目录不存在或不是有效目录: {directory}")
+        return 0
+    
+    cleaned_count = 0
+    
+    try:
+        for filename in os.listdir(directory):
+            file_path = os.path.join(directory, filename)
+            
+            # 检查是否是文件
+            if not os.path.isfile(file_path):
+                continue
+                
+            # 检查文件名是否匹配任何模式
+            should_clean = any(pattern in filename for pattern in patterns)
+            
+            if should_clean:
+                try:
+                    os.remove(file_path)
+                    logger.info(f"已清理临时文件: {filename}")
+                    cleaned_count += 1
+                except Exception as e:
+                    logger.warning(f"清理文件失败 {filename}: {str(e)}")
+                    
+    except Exception as e:
+        logger.error(f"清理目录失败 {directory}: {str(e)}")
+        
+    return cleaned_count
 
 
 
 
 # 土地数据处理函数
 # 土地数据处理函数
@@ -107,6 +218,7 @@ def process_land_data(land_type, coefficient_params=None, save_csv=True):
         return None, None, None
         return None, None, None
 
 
     logger.info(f"从数据库获取到 {len(land_centers_df)} 个 '{land_type}' 类型的土地数据")
     logger.info(f"从数据库获取到 {len(land_centers_df)} 个 '{land_type}' 类型的土地数据")
+    logger.info(f"预计需要进行 {len(land_centers_df)} 次最近邻搜索,使用高性能算法处理...")
 
 
     # 读取Excel采样点数据
     # 读取Excel采样点数据
     if not os.path.exists(xls_file):
     if not os.path.exists(xls_file):
@@ -128,26 +240,20 @@ def process_land_data(land_type, coefficient_params=None, save_csv=True):
     Num = param1 * param2
     Num = param1 * param2
     logger.info(f"系数: {param1} * {param2} = {Num}")
     logger.info(f"系数: {param1} * {param2} = {Num}")
 
 
-    # 处理每个面要素,使用数据库中的中心点坐标
-    cd_values = []
-    centers = []
+    # 高效处理:使用空间索引查找最近采样点
+    logger.info("开始高效距离计算和Cd值计算...")
+    start_time = time()
     
     
-    for index, row in land_centers_df.iterrows():
-        center_lon = row['center_lon']
-        center_lat = row['center_lat']
-        centers.append((center_lon, center_lat))
-
-        # 计算到所有采样点的距离
-        distances = df_xls.apply(
-            lambda x: Point(center_lon, center_lat).distance(Point(x['经度'], x['纬度'])),
-            axis=1
-        )
-        min_idx = distances.idxmin()
-        nearest = df_xls.loc[min_idx]
-
-        # 计算Cd含量值
-        cd_value = nearest['Cd (ug/L)'] * Num
-        cd_values.append(cd_value)
+    # 使用优化的空间索引方法查找最近采样点
+    nearest_indices = find_nearest_sampling_points_optimized(land_centers_df, df_xls)
+    
+    # 批量计算Cd含量值
+    centers = list(zip(land_centers_df['center_lon'], land_centers_df['center_lat']))
+    cd_values = df_xls.iloc[nearest_indices]['Cd (ug/L)'].values * Num
+    
+    calculation_time = time() - start_time
+    logger.info(f"Cd值计算完成,耗时: {calculation_time:.2f}秒")
+    logger.info(f"处理了 {len(centers)} 个土地中心点")
 
 
     # 创建简化数据DataFrame
     # 创建简化数据DataFrame
     simplified_data = pd.DataFrame({
     simplified_data = pd.DataFrame({
@@ -297,7 +403,8 @@ def process_land_to_visualization(land_type, coefficient_params=None,
                                   enable_interpolation: Optional[bool] = True,
                                   enable_interpolation: Optional[bool] = True,
                                   interpolation_method: Optional[str] = "linear",
                                   interpolation_method: Optional[str] = "linear",
                                   resolution_factor: Optional[float] = 4.0,
                                   resolution_factor: Optional[float] = 4.0,
-                                  save_csv: Optional[bool] = True):
+                                  save_csv: Optional[bool] = True,
+                                  cleanup_temp_files: Optional[bool] = True):
     """
     """
     完整的土地数据处理可视化流程(使用统一的MappingUtils接口,支持动态边界和插值控制)
     完整的土地数据处理可视化流程(使用统一的MappingUtils接口,支持动态边界和插值控制)
     
     
@@ -321,6 +428,7 @@ def process_land_to_visualization(land_type, coefficient_params=None,
     @param interpolation_method: 插值方法,nearest | linear | cubic,默认linear
     @param interpolation_method: 插值方法,nearest | linear | cubic,默认linear
     @param resolution_factor: 分辨率因子,默认4.0,越大分辨率越高
     @param resolution_factor: 分辨率因子,默认4.0,越大分辨率越高
     @param save_csv: 是否生成CSV文件,默认True
     @param save_csv: 是否生成CSV文件,默认True
+    @param cleanup_temp_files: 是否清理临时文件,默认True
     @returns: 包含所有生成文件路径的元组
     @returns: 包含所有生成文件路径的元组
     """
     """
     base_dir = get_base_dir()
     base_dir = get_base_dir()
@@ -441,6 +549,38 @@ def process_land_to_visualization(land_type, coefficient_params=None,
         data_dir = os.path.join(base_dir, "..", "static", "water", "Data")
         data_dir = os.path.join(base_dir, "..", "static", "water", "Data")
         cleaned_csv = os.path.join(data_dir, f"中心点经纬度与预测值&{land_type}_清洗.csv")
         cleaned_csv = os.path.join(data_dir, f"中心点经纬度与预测值&{land_type}_清洗.csv")
     
     
+    # 清理临时文件(如果启用)
+    if cleanup_temp_files:
+        logger.info("开始清理临时文件...")
+        
+        # 要清理的临时文件列表
+        temp_files_to_cleanup = []
+        
+        # 添加临时栅格文件(如果是memory_raster_开头的)
+        if output_tif and 'memory_raster_' in os.path.basename(output_tif):
+            temp_files_to_cleanup.append(output_tif)
+            
+        # 添加临时shapefile(如果存在且是临时文件)
+        temp_shapefile = workflow_result.get('shapefile')
+        if temp_shapefile and ('temp' in temp_shapefile.lower() or 'memory' in temp_shapefile.lower()):
+            temp_files_to_cleanup.append(temp_shapefile)
+        
+        # 如果不保存CSV,也清理CSV文件
+        if not save_csv and cleaned_csv_path and os.path.exists(cleaned_csv_path):
+            temp_files_to_cleanup.append(cleaned_csv_path)
+            
+        # 执行清理
+        if temp_files_to_cleanup:
+            cleanup_temporary_files(*temp_files_to_cleanup)
+            logger.info(f"已清理 {len(temp_files_to_cleanup)} 个临时文件")
+            
+            # 如果清理了栅格文件,将返回路径设为None以避免引用已删除的文件
+            if output_tif in temp_files_to_cleanup:
+                output_tif = None
+                logger.info("注意:临时栅格文件已被清理,返回的栅格路径为None")
+        else:
+            logger.info("没有临时文件需要清理")
+    
     return cleaned_csv, workflow_result['shapefile'], output_tif, map_output, hist_output, used_coeff
     return cleaned_csv, workflow_result['shapefile'], output_tif, map_output, hist_output, used_coeff
 
 
 
 
@@ -544,6 +684,13 @@ def main():
     except Exception as e:
     except Exception as e:
         logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
         logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
     finally:
     finally:
+        # 清理临时文件
+        base_dir = get_base_dir()
+        raster_dir = os.path.join(base_dir, "..", "static", "water", "Raster")
+        cleaned_count = cleanup_temp_files_in_directory(raster_dir)
+        if cleaned_count > 0:
+            logger.info(f"已清理 {cleaned_count} 个临时文件")
+        
         logger.info("处理完成")
         logger.info("处理完成")