""" 通用绘图工具模块 Universal Mapping and Visualization Utils 整合了CSV转GeoTIFF、栅格地图绘制和直方图生成功能 基于01_Transfer_csv_to_geotif.py和02_Figure_raster_mapping.py的更新代码 Author: Integrated from Wanxue Zhu's code """ import pandas as pd import geopandas as gpd from shapely.geometry import Point import rasterio from rasterio.features import rasterize from rasterio.transform import from_origin from rasterio.mask import mask from rasterio.plot import show import numpy as np import os import json import logging from scipy.interpolate import griddata from scipy.ndimage import distance_transform_edt import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap, BoundaryNorm import seaborn as sns import warnings warnings.filterwarnings('ignore') # 配置日志 logger = logging.getLogger(__name__) # 设置matplotlib的中文字体和样式 plt.rcParams['font.family'] = 'Arial' plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] # 添加多个中文字体 # 预定义的色彩方案 COLORMAPS = { 'yellow_orange_brown': ['#FFFECE', '#FFF085', '#FEBA17', '#BE3D2A', '#74512D', '#4E1F00'], # 黄-橙-棕 'blue_series': ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60', '#2A3335'], # 蓝色系 'yellow_green': ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'], # 淡黄-草绿 'green_brown': ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'], # 绿色-棕色 'yellow_pink_purple': ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'], # 黄-粉-紫 'green_yellow_red_purple': ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F'], # 绿-黄-红-紫 } class MappingUtils: """ 通用绘图工具类 提供CSV转换、栅格处理、地图绘制和直方图生成功能 """ def __init__(self, log_level=logging.INFO): """ 初始化绘图工具 @param log_level: 日志级别 """ self.logger = logging.getLogger(self.__class__.__name__) self.logger.setLevel(log_level) if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) self.logger.addHandler(handler) def csv_to_shapefile(self, csv_file, shapefile_output, lon_col=0, lat_col=1, value_col=2): """ 将CSV文件转换为Shapefile文件 @param csv_file: CSV文件路径 @param shapefile_output: 输出Shapefile文件路径 @param lon_col: 经度列索引或列名,默认第0列 @param lat_col: 纬度列索引或列名,默认第1列 @param value_col: 数值列索引或列名,默认第2列 @return: 输出的shapefile路径 """ try: self.logger.info(f"开始转换CSV到Shapefile: {csv_file}") # 读取CSV数据 df = pd.read_csv(csv_file) # 支持列索引或列名 if isinstance(lon_col, int): lon = df.iloc[:, lon_col] else: lon = df[lon_col] if isinstance(lat_col, int): lat = df.iloc[:, lat_col] else: lat = df[lat_col] if isinstance(value_col, int): val = df.iloc[:, value_col] else: val = df[value_col] # 创建几何对象 geometry = [Point(xy) for xy in zip(lon, lat)] gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326") # 确保输出目录存在 os.makedirs(os.path.dirname(shapefile_output), exist_ok=True) # 保存Shapefile gdf.to_file(shapefile_output, driver="ESRI Shapefile") self.logger.info(f"✓ 成功转换CSV到Shapefile: {shapefile_output}") return shapefile_output except Exception as e: self.logger.error(f"CSV转Shapefile失败: {str(e)}") raise def create_boundary_mask(self, raster, transform, gdf): """ 创建边界掩膜,只保留边界内的区域 @param raster: 栅格数据 @param transform: 栅格变换参数 @param gdf: 矢量边界数据 @return: 边界掩膜 """ try: mask = rasterize( gdf.geometry, out_shape=raster.shape, transform=transform, fill=0, default_value=1, dtype=np.uint8 ) return mask.astype(bool) except Exception as e: self.logger.error(f"创建边界掩膜失败: {str(e)}") raise def interpolate_nan_values(self, raster, method='nearest'): """ 使用插值方法填充NaN值 @param raster: 包含NaN值的栅格数据 @param method: 插值方法 ('nearest', 'linear', 'cubic') @return: 插值后的栅格数据 """ try: if not np.isnan(raster).any(): return raster # 获取有效值的坐标 valid_mask = ~np.isnan(raster) valid_coords = np.where(valid_mask) valid_values = raster[valid_mask] if len(valid_values) == 0: self.logger.warning("没有有效值用于插值") return raster # 创建网格坐标 rows, cols = raster.shape grid_x, grid_y = np.mgrid[0:rows, 0:cols] # 准备插值坐标 points = np.column_stack((valid_coords[0], valid_coords[1])) # 执行插值 interpolated = griddata(points, valid_values, (grid_x, grid_y), method=method, fill_value=np.nan) # 如果插值后仍有NaN值,使用最近邻方法填充 if np.isnan(interpolated).any(): self.logger.info(f"使用 {method} 插值后仍有NaN值,使用最近邻方法填充剩余值") remaining_nan = np.isnan(interpolated) remaining_coords = np.where(remaining_nan) if len(remaining_coords[0]) > 0: # 使用距离变换找到最近的已知值 dist, indices = distance_transform_edt(remaining_nan, return_distances=True, return_indices=True) # 填充剩余的NaN值 for i, j in zip(remaining_coords[0], remaining_coords[1]): if indices[0, i, j] < rows and indices[1, i, j] < cols: interpolated[i, j] = raster[indices[0, i, j], indices[1, i, j]] return interpolated except Exception as e: self.logger.error(f"插值失败: {str(e)}") return raster def vector_to_raster(self, input_shapefile, template_tif, output_tif, field, resolution_factor=16.0, boundary_shp=None, interpolation_method='nearest', enable_interpolation=True): """ 将点矢量数据转换为栅格数据 @param input_shapefile: 输入点矢量数据的Shapefile文件路径 @param template_tif: 用作模板的GeoTIFF文件路径 @param output_tif: 输出栅格化后的GeoTIFF文件路径 @param field: 用于栅格化的属性字段名 @param resolution_factor: 分辨率倍数因子 @param boundary_shp: 边界Shapefile文件路径,用于创建掩膜 @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic') @param enable_interpolation: 是否启用空间插值,默认True @return: 输出的GeoTIFF文件路径和统计信息 """ try: self.logger.info(f"开始处理: {input_shapefile}") self.logger.info(f"分辨率因子: {resolution_factor}, 插值方法: {interpolation_method}") # 读取矢量数据 gdf = gpd.read_file(input_shapefile) # 读取模板栅格 with rasterio.open(template_tif) as src: template_meta = src.meta.copy() # 根据分辨率因子计算新的尺寸和变换参数 if resolution_factor != 1.0: width = int(src.width * resolution_factor) height = int(src.height * resolution_factor) transform = rasterio.Affine( src.transform.a / resolution_factor, src.transform.b, src.transform.c, src.transform.d, src.transform.e / resolution_factor, src.transform.f ) self.logger.info(f"分辨率调整: {src.width}x{src.height} -> {width}x{height}") else: width = src.width height = src.height transform = src.transform self.logger.info(f"保持原始分辨率: {width}x{height}") crs = src.crs # 投影矢量数据 if gdf.crs != crs: gdf = gdf.to_crs(crs) # 栅格化 shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[field])) raster = rasterize( shapes=shapes, out_shape=(height, width), transform=transform, fill=np.nan, dtype='float32' ) # 应用边界掩膜(如果提供) if boundary_shp and os.path.exists(boundary_shp): self.logger.info(f"应用边界掩膜: {boundary_shp}") boundary_gdf = gpd.read_file(boundary_shp) if boundary_gdf.crs != crs: boundary_gdf = boundary_gdf.to_crs(crs) boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf) raster[~boundary_mask] = np.nan # 使用插值方法填充NaN值(如果启用) if enable_interpolation and np.isnan(raster).any(): self.logger.info(f"使用 {interpolation_method} 方法进行插值...") raster = self.interpolate_nan_values(raster, method=interpolation_method) elif not enable_interpolation and np.isnan(raster).any(): self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)") # 创建输出目录 os.makedirs(os.path.dirname(output_tif), exist_ok=True) # 更新元数据 template_meta.update({ "count": 1, "dtype": 'float32', "nodata": np.nan, "width": width, "height": height, "transform": transform }) # 保存栅格文件 with rasterio.open(output_tif, 'w', **template_meta) as dst: dst.write(raster, 1) # 计算统计信息 valid_data = raster[~np.isnan(raster)] stats = None if len(valid_data) > 0: stats = { 'min': float(np.min(valid_data)), 'max': float(np.max(valid_data)), 'mean': float(np.mean(valid_data)), 'std': float(np.std(valid_data)), 'valid_pixels': int(len(valid_data)), 'total_pixels': int(raster.size) } self.logger.info(f"统计信息: 有效像素 {stats['valid_pixels']}/{stats['total_pixels']}") self.logger.info(f"数值范围: {stats['min']:.4f} - {stats['max']:.4f}") else: self.logger.warning("没有有效数据") self.logger.info(f"✓ 成功保存: {output_tif}") return output_tif, stats except Exception as e: self.logger.error(f"矢量转栅格失败: {str(e)}") raise def create_raster_map(self, shp_path, tif_path, output_path, colormap='green_yellow_red_purple', title="Prediction Map", output_size=12, figsize=None, dpi=300, resolution_factor=1.0, enable_interpolation=False, interpolation_method='nearest'): """ 创建栅格地图 @param shp_path: 输入的矢量数据路径 @param tif_path: 输入的栅格数据路径 @param output_path: 输出图片路径(不包含扩展名) @param colormap: 色彩方案名称或颜色列表 @param title: 图片标题 @param output_size: 图片尺寸(正方形),如果指定了figsize则忽略此参数 @param figsize: 图片尺寸元组 (width, height),优先级高于output_size @param dpi: 图片分辨率 @param resolution_factor: 分辨率因子,>1提高分辨率,<1降低分辨率 @param enable_interpolation: 是否启用空间插值,用于处理NaN值或提高分辨率 @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic') @param enable_interpolation: 是否启用空间插值,默认True @return: 输出图片文件路径 """ try: self.logger.info(f"开始创建栅格地图: {tif_path}") self.logger.info(f"分辨率因子: {resolution_factor}, 启用插值: {enable_interpolation}") # 读取矢量边界 gdf = gpd.read_file(shp_path) if shp_path else None # 读取并裁剪栅格数据 with rasterio.open(tif_path) as src: original_transform = src.transform original_crs = src.crs if gdf is not None: # 确保坐标系一致 if gdf.crs != src.crs: gdf = gdf.to_crs(src.crs) # 裁剪栅格 geoms = [json.loads(gdf.to_json())["features"][0]["geometry"]] out_image, out_transform = mask(src, geoms, crop=True) out_meta = src.meta.copy() else: # 如果没有边界文件,使用整个栅格 out_image = src.read() out_transform = src.transform out_meta = src.meta.copy() # 提取数据并处理无效值 raster = out_image[0].astype('float32') nodata = out_meta.get("nodata", None) if nodata is not None: raster[raster == nodata] = np.nan # 应用分辨率因子重采样 if resolution_factor != 1.0: self.logger.info(f"应用分辨率因子重采样: {resolution_factor}") raster, out_transform = self._resample_raster(raster, out_transform, resolution_factor, original_crs) # 应用空间插值(如果启用) if enable_interpolation and np.isnan(raster).any(): self.logger.info(f"使用 {interpolation_method} 方法进行空间插值") raster = self.interpolate_nan_values(raster, method=interpolation_method) # 检查是否有有效数据 if np.all(np.isnan(raster)): raise ValueError("栅格数据中没有有效值") # 根据分位数分为6个等级 bounds = np.nanpercentile(raster, [0, 20, 40, 60, 80, 90, 100]) norm = BoundaryNorm(bounds, ncolors=len(bounds) - 1) # 获取色彩方案 if isinstance(colormap, str): if colormap in COLORMAPS: color_list = COLORMAPS[colormap] else: self.logger.warning(f"未知色彩方案: {colormap},使用默认方案") color_list = COLORMAPS['green_yellow_red_purple'] else: color_list = colormap cmap = ListedColormap(color_list) # 设置图片尺寸 if figsize is not None: fig_size = figsize else: fig_size = (output_size, output_size) # 绘图 fig, ax = plt.subplots(figsize=fig_size) show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm) # 添加矢量边界 if gdf is not None: gdf.boundary.plot(ax=ax, color='black', linewidth=1) # 设置标题和标签 ax.set_title(title, fontsize=20) ax.set_xlabel("Longitude", fontsize=18) ax.set_ylabel("Latitude", fontsize=18) ax.grid(True, linestyle='--', color='gray', alpha=0.5) ax.tick_params(axis='y', labelrotation=90) # 添加色带 tick_labels = [f"{bounds[i]:.1f}" for i in range(len(bounds) - 1)] cbar = plt.colorbar( plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=[(bounds[i] + bounds[i+1]) / 2 for i in range(len(bounds) - 1)], shrink=0.6, aspect=15 ) cbar.ax.set_yticklabels(tick_labels) cbar.set_label("Values") plt.tight_layout() # 确保输出目录存在 os.makedirs(os.path.dirname(output_path), exist_ok=True) # 保存图片 output_file = f"{output_path}.jpg" plt.savefig(output_file, dpi=dpi, format='jpg', bbox_inches='tight') plt.close() self.logger.info(f"✓ 栅格地图创建成功: {output_file}") return output_file except Exception as e: self.logger.error(f"栅格地图创建失败: {str(e)}") raise def create_histogram(self, file_path, save_path=None, figsize=(10, 6), xlabel='像元值', ylabel='频率密度', title='数值分布图', bins=100, dpi=300): """ 绘制GeoTIFF文件的直方图 @param file_path: GeoTIFF文件路径 @param save_path: 保存路径,如果为None则自动生成 @param figsize: 图像尺寸 @param xlabel: 横坐标标签 @param ylabel: 纵坐标标签 @param title: 图标题 @param bins: 直方图箱数 @param dpi: 图片分辨率 @return: 输出图片文件路径 """ try: self.logger.info(f"开始创建直方图: {file_path}") # 设置seaborn样式 sns.set(style='ticks') # 读取栅格数据 with rasterio.open(file_path) as src: band = src.read(1) nodata = src.nodata # 处理无效值 if nodata is not None: band = np.where(band == nodata, np.nan, band) # 展平数据并移除NaN值 band_flat = band.flatten() band_flat = band_flat[~np.isnan(band_flat)] if len(band_flat) == 0: raise ValueError("栅格数据中没有有效值") # 创建图形 plt.figure(figsize=figsize) # 绘制直方图和密度曲线 sns.histplot(band_flat, bins=bins, color='steelblue', alpha=0.7, edgecolor='black', stat='density') sns.kdeplot(band_flat, color='red', linewidth=2) # 设置标签和标题 plt.xlabel(xlabel, fontsize=14) plt.ylabel(ylabel, fontsize=14) plt.title(title, fontsize=16) plt.grid(True, linestyle='--', alpha=0.5) plt.tight_layout() # 保存图片 if save_path is None: save_path = file_path.replace('.tif', '_histogram.jpg') # 确保输出目录存在 os.makedirs(os.path.dirname(save_path), exist_ok=True) plt.savefig(save_path, dpi=dpi, format='jpg', bbox_inches='tight') plt.close() self.logger.info(f"✓ 直方图创建成功: {save_path}") return save_path except Exception as e: self.logger.error(f"直方图创建失败: {str(e)}") raise def get_available_colormaps(): """ 获取可用的色彩方案列表 @return: 色彩方案字典 """ return COLORMAPS.copy() def csv_to_raster_workflow(csv_file, template_tif, output_dir, boundary_shp=None, resolution_factor=16.0, interpolation_method='nearest', field_name='Prediction', lon_col=0, lat_col=1, value_col=2, enable_interpolation=False): """ 完整的CSV到栅格转换工作流 @param csv_file: CSV文件路径 @param template_tif: 模板GeoTIFF文件路径 @param output_dir: 输出目录 @param boundary_shp: 边界Shapefile文件路径(可选) @param resolution_factor: 分辨率因子 @param interpolation_method: 插值方法 @param field_name: 字段名称 @param lon_col: 经度列 @param lat_col: 纬度列 @param value_col: 数值列 @param enable_interpolation: 是否启用空间插值,默认False @return: 输出文件路径字典 """ mapper = MappingUtils() # 确保输出目录存在 os.makedirs(output_dir, exist_ok=True) # 生成文件名 base_name = os.path.splitext(os.path.basename(csv_file))[0] shapefile_path = os.path.join(output_dir, f"{base_name}_points.shp") raster_path = os.path.join(output_dir, f"{base_name}_raster.tif") # 1. CSV转Shapefile mapper.csv_to_shapefile(csv_file, shapefile_path, lon_col, lat_col, value_col) # 2. Shapefile转栅格 raster_path, stats = mapper.vector_to_raster( shapefile_path, template_tif, raster_path, field_name, resolution_factor, boundary_shp, interpolation_method, enable_interpolation ) return { 'shapefile': shapefile_path, 'raster': raster_path, 'statistics': stats }