Jelajahi Sumber

优化土地数据处理逻辑,新增临时文件清理功能,支持高效的最近邻搜索算法,提升数据处理性能和资源管理。同时更新日志记录,确保输出文件存在性验证和错误处理更加完善。

drggboy 1 Minggu lalu
induk
melakukan
2fd21973ab
2 mengubah file dengan 172 tambahan dan 23 penghapusan
  1. 4 2
      app/api/water.py
  2. 168 21
      app/services/water_service.py

+ 4 - 2
app/api/water.py

@@ -127,7 +127,8 @@ async def recalculate_land_data(
         enable_interpolation: Optional[bool] = Form(False, description="是否启用空间插值,默认启用"),
         interpolation_method: Optional[str] = Form("linear", description="插值方法: nearest | linear | cubic"),
         resolution_factor: Optional[float] = Form(4.0, description="分辨率因子,默认4.0,越大分辨率越高"),
-        save_csv: Optional[bool] = Form(True, description="是否生成CSV文件,默认生成")
+        save_csv: Optional[bool] = Form(True, description="是否生成CSV文件,默认生成"),
+        cleanup_temp_files: Optional[bool] = Form(True, description="是否清理临时文件,默认清理")
 ) -> Dict[str, Any]:
     """重新计算土地数据并返回结果路径,支持动态边界控制和插值控制"""
     try:
@@ -161,7 +162,8 @@ async def recalculate_land_data(
             enable_interpolation=enable_interpolation,
             interpolation_method=interpolation_method,
             resolution_factor=resolution_factor,
-            save_csv=save_csv  # 将CSV生成选项传递给处理函数
+            save_csv=save_csv,  # 将CSV生成选项传递给处理函数
+            cleanup_temp_files=cleanup_temp_files  # 将清理选项传递给处理函数
         )
 
         if not results:

+ 168 - 21
app/services/water_service.py

@@ -4,12 +4,14 @@ import pandas as pd
 from pyproj import Transformer
 from shapely.geometry import Point
 import rasterio
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
 from datetime import datetime
 import numpy as np
 import logging
 import shutil
 import sys
+from sklearn.neighbors import BallTree
+from time import time
 
 # 导入MappingUtils
 from ..utils.mapping_utils import MappingUtils, csv_to_raster_workflow, dataframe_to_raster_workflow
@@ -86,7 +88,116 @@ def get_boundary_gdf_from_database(area: str, level: str) -> Optional[gpd.GeoDat
     return None
 
 
+def find_nearest_sampling_points_optimized(land_centers_df: pd.DataFrame, 
+                                          sampling_points_df: pd.DataFrame) -> np.ndarray:
+    """
+    使用BallTree高效计算每个土地中心点的最近采样点
+    
+    @description: 使用空间索引优化最近邻搜索,将O(n×m)复杂度降低到O(n×log(m))
+    
+    @param land_centers_df: 土地中心点数据,包含center_lon和center_lat列
+    @param sampling_points_df: 采样点数据,包含经度和纬度列
+    @returns: 每个土地中心点对应的最近采样点索引数组
+    """
+    logger.info("开始构建空间索引优化最近邻搜索...")
+    
+    start_time = time()
+    
+    # 1. 准备采样点坐标数据(转换为弧度用于BallTree)
+    sampling_coords = np.radians(sampling_points_df[['经度', '纬度']].values)
+    
+    # 2. 构建BallTree空间索引
+    logger.info(f"构建BallTree索引,采样点数量: {len(sampling_coords)}")
+    tree = BallTree(sampling_coords, metric='haversine')
+    
+    # 3. 准备土地中心点坐标数据
+    land_coords = np.radians(land_centers_df[['center_lon', 'center_lat']].values)
+    
+    # 4. 批量查询最近邻(k=1表示只找最近的一个点)
+    logger.info(f"批量查询最近邻,土地中心点数量: {len(land_coords)}")
+    distances, indices = tree.query(land_coords, k=1)
+    
+    # 5. 提取索引(indices是二维数组,我们只需要第一列)
+    nearest_indices = indices.flatten()
+    
+    elapsed_time = time() - start_time
+    logger.info(f"空间索引搜索完成,耗时: {elapsed_time:.2f}秒")
+    logger.info(f"平均每个点查询时间: {elapsed_time/len(land_coords)*1000:.2f}毫秒")
+    
+    return nearest_indices
 
+def cleanup_temporary_files(*file_paths):
+    """
+    清理临时文件
+    
+    @description: 安全地删除指定的临时文件,支持多种文件类型
+    @param file_paths: 要删除的文件路径(可变参数)
+    """
+    import tempfile
+    
+    for file_path in file_paths:
+        if not file_path:
+            continue
+            
+        try:
+            if os.path.exists(file_path) and os.path.isfile(file_path):
+                os.remove(file_path)
+                logger.info(f"已清理临时文件: {os.path.basename(file_path)}")
+                
+                # 如果是shapefile,也删除相关的配套文件
+                if file_path.endswith('.shp'):
+                    base_path = os.path.splitext(file_path)[0]
+                    for ext in ['.shx', '.dbf', '.prj', '.cpg']:
+                        related_file = base_path + ext
+                        if os.path.exists(related_file):
+                            os.remove(related_file)
+                            logger.info(f"已清理相关文件: {os.path.basename(related_file)}")
+                            
+        except Exception as e:
+            logger.warning(f"清理文件失败 {file_path}: {str(e)}")
+
+
+def cleanup_temp_files_in_directory(directory: str, patterns: List[str] = None) -> int:
+    """
+    清理指定目录下的临时文件
+    
+    @description: 根据文件名模式清理目录中的临时文件
+    @param directory: 要清理的目录路径
+    @param patterns: 文件名模式列表,默认为['memory_raster_', 'temp_', 'tmp_']
+    @returns: 清理的文件数量
+    """
+    if patterns is None:
+        patterns = ['memory_raster_', 'temp_', 'tmp_']
+    
+    if not os.path.exists(directory) or not os.path.isdir(directory):
+        logger.warning(f"目录不存在或不是有效目录: {directory}")
+        return 0
+    
+    cleaned_count = 0
+    
+    try:
+        for filename in os.listdir(directory):
+            file_path = os.path.join(directory, filename)
+            
+            # 检查是否是文件
+            if not os.path.isfile(file_path):
+                continue
+                
+            # 检查文件名是否匹配任何模式
+            should_clean = any(pattern in filename for pattern in patterns)
+            
+            if should_clean:
+                try:
+                    os.remove(file_path)
+                    logger.info(f"已清理临时文件: {filename}")
+                    cleaned_count += 1
+                except Exception as e:
+                    logger.warning(f"清理文件失败 {filename}: {str(e)}")
+                    
+    except Exception as e:
+        logger.error(f"清理目录失败 {directory}: {str(e)}")
+        
+    return cleaned_count
 
 
 # 土地数据处理函数
@@ -107,6 +218,7 @@ def process_land_data(land_type, coefficient_params=None, save_csv=True):
         return None, None, None
 
     logger.info(f"从数据库获取到 {len(land_centers_df)} 个 '{land_type}' 类型的土地数据")
+    logger.info(f"预计需要进行 {len(land_centers_df)} 次最近邻搜索,使用高性能算法处理...")
 
     # 读取Excel采样点数据
     if not os.path.exists(xls_file):
@@ -128,26 +240,20 @@ def process_land_data(land_type, coefficient_params=None, save_csv=True):
     Num = param1 * param2
     logger.info(f"系数: {param1} * {param2} = {Num}")
 
-    # 处理每个面要素,使用数据库中的中心点坐标
-    cd_values = []
-    centers = []
+    # 高效处理:使用空间索引查找最近采样点
+    logger.info("开始高效距离计算和Cd值计算...")
+    start_time = time()
     
-    for index, row in land_centers_df.iterrows():
-        center_lon = row['center_lon']
-        center_lat = row['center_lat']
-        centers.append((center_lon, center_lat))
-
-        # 计算到所有采样点的距离
-        distances = df_xls.apply(
-            lambda x: Point(center_lon, center_lat).distance(Point(x['经度'], x['纬度'])),
-            axis=1
-        )
-        min_idx = distances.idxmin()
-        nearest = df_xls.loc[min_idx]
-
-        # 计算Cd含量值
-        cd_value = nearest['Cd (ug/L)'] * Num
-        cd_values.append(cd_value)
+    # 使用优化的空间索引方法查找最近采样点
+    nearest_indices = find_nearest_sampling_points_optimized(land_centers_df, df_xls)
+    
+    # 批量计算Cd含量值
+    centers = list(zip(land_centers_df['center_lon'], land_centers_df['center_lat']))
+    cd_values = df_xls.iloc[nearest_indices]['Cd (ug/L)'].values * Num
+    
+    calculation_time = time() - start_time
+    logger.info(f"Cd值计算完成,耗时: {calculation_time:.2f}秒")
+    logger.info(f"处理了 {len(centers)} 个土地中心点")
 
     # 创建简化数据DataFrame
     simplified_data = pd.DataFrame({
@@ -297,7 +403,8 @@ def process_land_to_visualization(land_type, coefficient_params=None,
                                   enable_interpolation: Optional[bool] = True,
                                   interpolation_method: Optional[str] = "linear",
                                   resolution_factor: Optional[float] = 4.0,
-                                  save_csv: Optional[bool] = True):
+                                  save_csv: Optional[bool] = True,
+                                  cleanup_temp_files: Optional[bool] = True):
     """
     完整的土地数据处理可视化流程(使用统一的MappingUtils接口,支持动态边界和插值控制)
     
@@ -321,6 +428,7 @@ def process_land_to_visualization(land_type, coefficient_params=None,
     @param interpolation_method: 插值方法,nearest | linear | cubic,默认linear
     @param resolution_factor: 分辨率因子,默认4.0,越大分辨率越高
     @param save_csv: 是否生成CSV文件,默认True
+    @param cleanup_temp_files: 是否清理临时文件,默认True
     @returns: 包含所有生成文件路径的元组
     """
     base_dir = get_base_dir()
@@ -441,6 +549,38 @@ def process_land_to_visualization(land_type, coefficient_params=None,
         data_dir = os.path.join(base_dir, "..", "static", "water", "Data")
         cleaned_csv = os.path.join(data_dir, f"中心点经纬度与预测值&{land_type}_清洗.csv")
     
+    # 清理临时文件(如果启用)
+    if cleanup_temp_files:
+        logger.info("开始清理临时文件...")
+        
+        # 要清理的临时文件列表
+        temp_files_to_cleanup = []
+        
+        # 添加临时栅格文件(如果是memory_raster_开头的)
+        if output_tif and 'memory_raster_' in os.path.basename(output_tif):
+            temp_files_to_cleanup.append(output_tif)
+            
+        # 添加临时shapefile(如果存在且是临时文件)
+        temp_shapefile = workflow_result.get('shapefile')
+        if temp_shapefile and ('temp' in temp_shapefile.lower() or 'memory' in temp_shapefile.lower()):
+            temp_files_to_cleanup.append(temp_shapefile)
+        
+        # 如果不保存CSV,也清理CSV文件
+        if not save_csv and cleaned_csv_path and os.path.exists(cleaned_csv_path):
+            temp_files_to_cleanup.append(cleaned_csv_path)
+            
+        # 执行清理
+        if temp_files_to_cleanup:
+            cleanup_temporary_files(*temp_files_to_cleanup)
+            logger.info(f"已清理 {len(temp_files_to_cleanup)} 个临时文件")
+            
+            # 如果清理了栅格文件,将返回路径设为None以避免引用已删除的文件
+            if output_tif in temp_files_to_cleanup:
+                output_tif = None
+                logger.info("注意:临时栅格文件已被清理,返回的栅格路径为None")
+        else:
+            logger.info("没有临时文件需要清理")
+    
     return cleaned_csv, workflow_result['shapefile'], output_tif, map_output, hist_output, used_coeff
 
 
@@ -544,6 +684,13 @@ def main():
     except Exception as e:
         logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
     finally:
+        # 清理临时文件
+        base_dir = get_base_dir()
+        raster_dir = os.path.join(base_dir, "..", "static", "water", "Raster")
+        cleaned_count = cleanup_temp_files_in_directory(raster_dir)
+        if cleaned_count > 0:
+            logger.info(f"已清理 {cleaned_count} 个临时文件")
+        
         logger.info("处理完成")