|
- """
- 通用绘图工具模块
- Universal Mapping and Visualization Utils
- 整合了CSV转GeoTIFF、栅格地图绘制和直方图生成功能
- 基于01_Transfer_csv_to_geotif.py和02_Figure_raster_mapping.py的更新代码
- Author: Integrated from Wanxue Zhu's code
- """
- import pandas as pd
- import geopandas as gpd
- from shapely.geometry import Point
- import rasterio
- from rasterio.features import rasterize
- from rasterio.transform import from_origin
- from rasterio.mask import mask
- from rasterio.plot import show
- import numpy as np
- import os
- import json
- import logging
- from scipy.interpolate import griddata
- from scipy.ndimage import distance_transform_edt
- import matplotlib.pyplot as plt
- from matplotlib.colors import ListedColormap, BoundaryNorm
- import seaborn as sns
- import warnings
# Suppress every warning category for the whole process.
warnings.filterwarnings('ignore')
# Module-level logger.
logger = logging.getLogger(__name__)
# Matplotlib font and style settings (the original intent is to support Chinese labels).
# NOTE(review): 'font.family' is set to the concrete font 'Arial' while the
# Chinese-capable fonts are only listed under 'font.sans-serif'; presumably the
# latter list is not consulted unless the family is 'sans-serif' - confirm that
# Chinese axis labels actually render.
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['axes.unicode_minus'] = False # Avoid broken rendering of the minus sign
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] # Several Chinese fonts as candidates
# Predefined colour schemes: each entry is an ordered list of six hex colours.
COLORMAPS = {
    'yellow_orange_brown': ['#FFFECE', '#FFF085', '#FEBA17', '#BE3D2A', '#74512D', '#4E1F00'], # yellow - orange - brown
    'blue_series': ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60', '#2A3335'], # blue tones
    'yellow_green': ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'], # pale yellow - grass green
    'green_brown': ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'], # green - brown
    'yellow_pink_purple': ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'], # yellow - pink - purple
    'green_yellow_red_purple': ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F'], # green - yellow - red - purple
}
class MappingUtils:
    """
    General-purpose mapping and visualization helper.

    Provides CSV-to-Shapefile conversion, point rasterization with optional
    NaN interpolation, classified raster map rendering and histogram plots.
    """

    def __init__(self, log_level=logging.INFO):
        """
        Set up the logger used by all methods.

        @param log_level: logging level applied to this class's logger
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(log_level)

        # The logger is looked up by class name and therefore shared between
        # instances; attach the stream handler only once to avoid duplicates.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of *path* when the path has one."""
        parent = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so bare file names
        # (output into the current directory) must be skipped.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _select_column(df, column):
        """Return a DataFrame column addressed by position (int) or by name."""
        if isinstance(column, int):
            return df.iloc[:, column]
        return df[column]

    @staticmethod
    def _resample_array(raster, resolution_factor):
        """
        Nearest-neighbour resample of a 2-D array.

        @param raster: 2-D numpy array
        @param resolution_factor: >1 produces more cells, <1 fewer cells
        @return: resampled 2-D array (NaN cells stay NaN)
        @raise ValueError: if resolution_factor is not positive
        """
        if resolution_factor <= 0:
            raise ValueError("resolution_factor must be positive")

        rows, cols = raster.shape
        new_rows = max(1, int(round(rows * resolution_factor)))
        new_cols = max(1, int(round(cols * resolution_factor)))

        # Sample every output cell at its centre, mapped back into the source
        # index space; clamp to guard against floating-point overshoot.
        row_idx = np.minimum(((np.arange(new_rows) + 0.5) * rows / new_rows).astype(int), rows - 1)
        col_idx = np.minimum(((np.arange(new_cols) + 0.5) * cols / new_cols).astype(int), cols - 1)
        return raster[np.ix_(row_idx, col_idx)]

    def _resample_raster(self, raster, transform, resolution_factor, crs=None):
        """
        Resample a raster and rescale its affine transform accordingly.

        @param raster: 2-D numpy array
        @param transform: affine transform of the input raster
        @param resolution_factor: >1 raises the resolution, <1 lowers it
        @param crs: accepted for call compatibility; not needed because the
                    resampling happens purely in pixel space
        @return: tuple (resampled raster, transform of the resampled raster)
        """
        rows, cols = raster.shape
        resampled = self._resample_array(raster, resolution_factor)
        new_rows, new_cols = resampled.shape

        # Keep the origin and the covered extent; only the pixel size changes.
        scale_x = cols / new_cols
        scale_y = rows / new_rows
        new_transform = rasterio.Affine(
            transform.a * scale_x,
            transform.b * scale_y,
            transform.c,
            transform.d * scale_x,
            transform.e * scale_y,
            transform.f
        )
        return resampled, new_transform

    def csv_to_shapefile(self, csv_file, shapefile_output, lon_col=0, lat_col=1, value_col=2):
        """
        Convert a CSV file of points into a Shapefile (EPSG:4326).

        All CSV columns are written as attributes of the point features.

        @param csv_file: path of the input CSV file
        @param shapefile_output: path of the Shapefile to write
        @param lon_col: longitude column, as index or name (default: column 0)
        @param lat_col: latitude column, as index or name (default: column 1)
        @param value_col: value column, as index or name (default: column 2)
        @return: path of the written Shapefile
        @raise Exception: any failure is logged and re-raised unchanged
        """
        try:
            self.logger.info(f"开始转换CSV到Shapefile: {csv_file}")

            df = pd.read_csv(csv_file)

            lon = self._select_column(df, lon_col)
            lat = self._select_column(df, lat_col)
            # Looked up only so that a missing value column fails early here
            # rather than later during rasterization.
            self._select_column(df, value_col)

            geometry = [Point(xy) for xy in zip(lon, lat)]
            gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326")

            self._ensure_parent_dir(shapefile_output)

            gdf.to_file(shapefile_output, driver="ESRI Shapefile")
            self.logger.info(f"✓ 成功转换CSV到Shapefile: {shapefile_output}")

            return shapefile_output

        except Exception as e:
            self.logger.error(f"CSV转Shapefile失败: {str(e)}")
            raise

    def create_boundary_mask(self, raster, transform, gdf):
        """
        Build a boolean mask that is True inside the boundary geometries.

        @param raster: raster array whose shape defines the mask shape
        @param transform: affine transform of the raster
        @param gdf: GeoDataFrame holding the boundary geometries
        @return: boolean array, True for cells covered by a geometry
        @raise Exception: any failure is logged and re-raised unchanged
        """
        try:
            inside = rasterize(
                gdf.geometry,
                out_shape=raster.shape,
                transform=transform,
                fill=0,
                default_value=1,
                dtype=np.uint8
            )
            return inside.astype(bool)
        except Exception as e:
            self.logger.error(f"创建边界掩膜失败: {str(e)}")
            raise

    def interpolate_nan_values(self, raster, method='nearest'):
        """
        Fill NaN cells of a raster by interpolating from the valid cells.

        Cells that the chosen method cannot reach (e.g. outside the convex
        hull for 'linear'/'cubic') receive the value of the nearest cell that
        was filled. This is best effort: on any error the input raster is
        returned unchanged.

        @param raster: 2-D array that may contain NaN
        @param method: interpolation method ('nearest', 'linear', 'cubic')
        @return: array with NaN cells filled, or the input raster when there
                 is nothing to fill, no valid data, or interpolation fails
        """
        try:
            if not np.isnan(raster).any():
                return raster

            valid_mask = ~np.isnan(raster)
            valid_values = raster[valid_mask]

            if len(valid_values) == 0:
                self.logger.warning("没有有效值用于插值")
                return raster

            rows, cols = raster.shape
            grid_x, grid_y = np.mgrid[0:rows, 0:cols]

            # (row, col) coordinates of the known cells.
            points = np.argwhere(valid_mask)

            interpolated = griddata(points, valid_values, (grid_x, grid_y),
                                    method=method, fill_value=np.nan)

            remaining_nan = np.isnan(interpolated)
            if remaining_nan.any():
                self.logger.info(f"使用 {method} 插值后仍有NaN值,使用最近邻方法填充剩余值")

                if not remaining_nan.all():
                    # For every NaN cell, find the indices of the closest
                    # non-NaN cell of the *interpolated* grid.
                    _, nearest = distance_transform_edt(remaining_nan,
                                                        return_distances=True,
                                                        return_indices=True)
                    # Read from `interpolated`, not from the input raster: the
                    # nearest filled cell may itself have been NaN in the input.
                    interpolated[remaining_nan] = interpolated[
                        nearest[0][remaining_nan], nearest[1][remaining_nan]
                    ]

            return interpolated

        except Exception as e:
            self.logger.error(f"插值失败: {str(e)}")
            return raster

    def vector_to_raster(self, input_shapefile, template_tif, output_tif, field,
                         resolution_factor=16.0, boundary_shp=None, interpolation_method='nearest', enable_interpolation=True):
        """
        Rasterize a point Shapefile onto the grid of a template GeoTIFF.

        @param input_shapefile: path of the input point Shapefile
        @param template_tif: GeoTIFF supplying extent, CRS and base resolution
        @param output_tif: path of the GeoTIFF to write
        @param field: attribute field whose values are burned into the raster
        @param resolution_factor: multiplier for the template resolution
        @param boundary_shp: optional boundary Shapefile; cells outside it are
                             NaN in the output
        @param interpolation_method: 'nearest', 'linear' or 'cubic'
        @param enable_interpolation: fill NaN cells by interpolation (default True)
        @return: tuple (output path, statistics dict or None when no valid data)
        @raise Exception: any failure is logged and re-raised unchanged
        """
        try:
            self.logger.info(f"开始处理: {input_shapefile}")
            self.logger.info(f"分辨率因子: {resolution_factor}, 插值方法: {interpolation_method}")

            gdf = gpd.read_file(input_shapefile)

            with rasterio.open(template_tif) as src:
                template_meta = src.meta.copy()

                if resolution_factor != 1.0:
                    width = int(src.width * resolution_factor)
                    height = int(src.height * resolution_factor)

                    # Same origin, pixel size divided by the factor.
                    transform = rasterio.Affine(
                        src.transform.a / resolution_factor,
                        src.transform.b,
                        src.transform.c,
                        src.transform.d,
                        src.transform.e / resolution_factor,
                        src.transform.f
                    )
                    self.logger.info(f"分辨率调整: {src.width}x{src.height} -> {width}x{height}")
                else:
                    width = src.width
                    height = src.height
                    transform = src.transform
                    self.logger.info(f"保持原始分辨率: {width}x{height}")

                crs = src.crs

            # Bring the points into the template CRS.
            if gdf.crs != crs:
                gdf = gdf.to_crs(crs)

            shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[field]))
            raster = rasterize(
                shapes=shapes,
                out_shape=(height, width),
                transform=transform,
                fill=np.nan,
                dtype='float32'
            )

            boundary_mask = None
            if boundary_shp and os.path.exists(boundary_shp):
                self.logger.info(f"应用边界掩膜: {boundary_shp}")
                boundary_gdf = gpd.read_file(boundary_shp)
                if boundary_gdf.crs != crs:
                    boundary_gdf = boundary_gdf.to_crs(crs)
                boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
                raster[~boundary_mask] = np.nan

            has_nan = np.isnan(raster).any()
            if enable_interpolation and has_nan:
                self.logger.info(f"使用 {interpolation_method} 方法进行插值...")
                raster = self.interpolate_nan_values(raster, method=interpolation_method)
                # Interpolation fills every NaN cell, including those outside
                # the boundary; mask again so the boundary stays effective.
                if boundary_mask is not None:
                    raster[~boundary_mask] = np.nan
            elif not enable_interpolation and has_nan:
                self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")

            self._ensure_parent_dir(output_tif)

            template_meta.update({
                "count": 1,
                "dtype": 'float32',
                "nodata": np.nan,
                "width": width,
                "height": height,
                "transform": transform
            })

            with rasterio.open(output_tif, 'w', **template_meta) as dst:
                # Interpolation may hand back float64; write what the
                # metadata declares.
                dst.write(raster.astype('float32', copy=False), 1)

            valid_data = raster[~np.isnan(raster)]
            stats = None
            if len(valid_data) > 0:
                stats = {
                    'min': float(np.min(valid_data)),
                    'max': float(np.max(valid_data)),
                    'mean': float(np.mean(valid_data)),
                    'std': float(np.std(valid_data)),
                    'valid_pixels': int(len(valid_data)),
                    'total_pixels': int(raster.size)
                }
                self.logger.info(f"统计信息: 有效像素 {stats['valid_pixels']}/{stats['total_pixels']}")
                self.logger.info(f"数值范围: {stats['min']:.4f} - {stats['max']:.4f}")
            else:
                self.logger.warning("没有有效数据")

            self.logger.info(f"✓ 成功保存: {output_tif}")
            return output_tif, stats

        except Exception as e:
            self.logger.error(f"矢量转栅格失败: {str(e)}")
            raise

    def create_raster_map(self, shp_path, tif_path, output_path,
                          colormap='green_yellow_red_purple', title="Prediction Map",
                          output_size=12, figsize=None, dpi=300,
                          resolution_factor=1.0, enable_interpolation=False,
                          interpolation_method='nearest'):
        """
        Render a GeoTIFF as a classified map and save it as a JPEG.

        Values are binned at the 0/20/40/60/80/90/100 percentiles.

        @param shp_path: boundary vector file used to clip the raster and to
                         draw the outline; falsy to use the whole raster
        @param tif_path: path of the input GeoTIFF
        @param output_path: output image path WITHOUT extension ('.jpg' is appended)
        @param colormap: name of a scheme in COLORMAPS, or a list of colours
        @param title: figure title
        @param output_size: side length of a square figure; ignored when
                            figsize is given
        @param figsize: (width, height) tuple, takes precedence over output_size
        @param dpi: output resolution
        @param resolution_factor: >1 raises, <1 lowers the raster resolution
                                  before plotting
        @param enable_interpolation: fill NaN cells by interpolation (default False)
        @param interpolation_method: 'nearest', 'linear' or 'cubic'
        @return: path of the written image
        @raise ValueError: if the raster holds no valid value
        @raise Exception: any failure is logged and re-raised unchanged
        """
        try:
            self.logger.info(f"开始创建栅格地图: {tif_path}")
            self.logger.info(f"分辨率因子: {resolution_factor}, 启用插值: {enable_interpolation}")

            gdf = gpd.read_file(shp_path) if shp_path else None

            with rasterio.open(tif_path) as src:
                original_crs = src.crs

                if gdf is not None:
                    if gdf.crs != src.crs:
                        gdf = gdf.to_crs(src.crs)

                    # Clip with every boundary feature, not just the first
                    # one, so the clip matches the outline drawn below.
                    features = json.loads(gdf.to_json())["features"]
                    geoms = [feature["geometry"] for feature in features]
                    out_image, out_transform = mask(src, geoms, crop=True)
                else:
                    out_image = src.read()
                    out_transform = src.transform
                out_meta = src.meta.copy()

            raster = out_image[0].astype('float32')
            nodata = out_meta.get("nodata", None)
            if nodata is not None:
                raster[raster == nodata] = np.nan

            if resolution_factor != 1.0:
                self.logger.info(f"应用分辨率因子重采样: {resolution_factor}")
                raster, out_transform = self._resample_raster(raster, out_transform, resolution_factor, original_crs)

            if enable_interpolation and np.isnan(raster).any():
                self.logger.info(f"使用 {interpolation_method} 方法进行空间插值")
                raster = self.interpolate_nan_values(raster, method=interpolation_method)

            if np.all(np.isnan(raster)):
                raise ValueError("栅格数据中没有有效值")

            # Six classes delimited by percentiles.
            bounds = np.nanpercentile(raster, [0, 20, 40, 60, 80, 90, 100])
            norm = BoundaryNorm(bounds, ncolors=len(bounds) - 1)

            if isinstance(colormap, str):
                if colormap in COLORMAPS:
                    color_list = COLORMAPS[colormap]
                else:
                    self.logger.warning(f"未知色彩方案: {colormap},使用默认方案")
                    color_list = COLORMAPS['green_yellow_red_purple']
            else:
                color_list = colormap

            cmap = ListedColormap(color_list)

            fig_size = figsize if figsize is not None else (output_size, output_size)

            fig, ax = plt.subplots(figsize=fig_size)
            try:
                show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)

                if gdf is not None:
                    gdf.boundary.plot(ax=ax, color='black', linewidth=1)

                ax.set_title(title, fontsize=20)
                ax.set_xlabel("Longitude", fontsize=18)
                ax.set_ylabel("Latitude", fontsize=18)
                ax.grid(True, linestyle='--', color='gray', alpha=0.5)
                ax.tick_params(axis='y', labelrotation=90)

                # One tick per class, placed at the class midpoint and
                # labelled with the class's lower bound.
                n_classes = len(bounds) - 1
                tick_labels = [f"{bounds[i]:.1f}" for i in range(n_classes)]
                cbar = fig.colorbar(
                    plt.cm.ScalarMappable(norm=norm, cmap=cmap),
                    ax=ax,
                    ticks=[(bounds[i] + bounds[i + 1]) / 2 for i in range(n_classes)],
                    shrink=0.6,
                    aspect=15
                )
                cbar.ax.set_yticklabels(tick_labels)
                cbar.set_label("Values")

                fig.tight_layout()

                self._ensure_parent_dir(output_path)

                output_file = f"{output_path}.jpg"
                fig.savefig(output_file, dpi=dpi, format='jpg', bbox_inches='tight')
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)

            self.logger.info(f"✓ 栅格地图创建成功: {output_file}")
            return output_file

        except Exception as e:
            self.logger.error(f"栅格地图创建失败: {str(e)}")
            raise

    def create_histogram(self, file_path, save_path=None, figsize=(10, 6),
                         xlabel='像元值', ylabel='频率密度', title='数值分布图',
                         bins=100, dpi=300):
        """
        Plot a density histogram with KDE curve for band 1 of a GeoTIFF.

        @param file_path: path of the input GeoTIFF
        @param save_path: output image path; when None it is derived from
                          file_path as '<name without extension>_histogram.jpg'
        @param figsize: figure size
        @param xlabel: x-axis label
        @param ylabel: y-axis label
        @param title: figure title
        @param bins: number of histogram bins
        @param dpi: output resolution
        @return: path of the written image
        @raise ValueError: if the raster holds no valid value
        @raise Exception: any failure is logged and re-raised unchanged
        """
        try:
            self.logger.info(f"开始创建直方图: {file_path}")

            sns.set(style='ticks')

            with rasterio.open(file_path) as src:
                band = src.read(1)
                nodata = src.nodata

            if nodata is not None:
                band = np.where(band == nodata, np.nan, band)

            band_flat = band.flatten()
            band_flat = band_flat[~np.isnan(band_flat)]

            if len(band_flat) == 0:
                raise ValueError("栅格数据中没有有效值")

            fig = plt.figure(figsize=figsize)
            try:
                sns.histplot(band_flat, bins=bins, color='steelblue', alpha=0.7,
                             edgecolor='black', stat='density')
                sns.kdeplot(band_flat, color='red', linewidth=2)

                plt.xlabel(xlabel, fontsize=14)
                plt.ylabel(ylabel, fontsize=14)
                plt.title(title, fontsize=16)
                plt.grid(True, linestyle='--', alpha=0.5)
                plt.tight_layout()

                if save_path is None:
                    # Swap only the real extension; a plain str.replace('.tif')
                    # would leave e.g. 'x.TIF' untouched and overwrite the input.
                    save_path = os.path.splitext(file_path)[0] + '_histogram.jpg'

                self._ensure_parent_dir(save_path)

                fig.savefig(save_path, dpi=dpi, format='jpg', bbox_inches='tight')
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)

            self.logger.info(f"✓ 直方图创建成功: {save_path}")
            return save_path

        except Exception as e:
            self.logger.error(f"直方图创建失败: {str(e)}")
            raise
def get_available_colormaps():
    """
    Return the predefined colour schemes.

    The result is an independent copy: both the dict and each colour list are
    new objects, so callers may modify them without altering COLORMAPS.

    @return: dict mapping scheme name to a list of hex colour strings
    """
    return {name: list(colors) for name, colors in COLORMAPS.items()}
def csv_to_raster_workflow(csv_file, template_tif, output_dir,
                           boundary_shp=None, resolution_factor=16.0,
                           interpolation_method='nearest', field_name='Prediction',
                           lon_col=0, lat_col=1, value_col=2, enable_interpolation=False):
    """
    Run the complete CSV -> Shapefile -> GeoTIFF conversion.

    Output files are named after the CSV file: '<name>_points.shp' and
    '<name>_raster.tif', both placed in output_dir.

    @param csv_file: path of the input CSV file
    @param template_tif: path of the template GeoTIFF
    @param output_dir: directory receiving the outputs (created if missing)
    @param boundary_shp: optional boundary Shapefile path
    @param resolution_factor: resolution multiplier passed to rasterization
    @param interpolation_method: interpolation method passed to rasterization
    @param field_name: attribute field burned into the raster
    @param lon_col: longitude column (index or name)
    @param lat_col: latitude column (index or name)
    @param value_col: value column (index or name)
    @param enable_interpolation: whether to fill NaN cells (default False)
    @return: dict with keys 'shapefile', 'raster' and 'statistics'
    """
    utils = MappingUtils()

    os.makedirs(output_dir, exist_ok=True)

    # Derive both output names from the CSV file name without its extension.
    stem, _ = os.path.splitext(os.path.basename(csv_file))
    points_shp = os.path.join(output_dir, f"{stem}_points.shp")
    raster_tif = os.path.join(output_dir, f"{stem}_raster.tif")

    # Step 1: CSV -> point Shapefile.
    utils.csv_to_shapefile(
        csv_file,
        points_shp,
        lon_col=lon_col,
        lat_col=lat_col,
        value_col=value_col,
    )

    # Step 2: point Shapefile -> GeoTIFF.
    written_raster, statistics = utils.vector_to_raster(
        input_shapefile=points_shp,
        template_tif=template_tif,
        output_tif=raster_tif,
        field=field_name,
        resolution_factor=resolution_factor,
        boundary_shp=boundary_shp,
        interpolation_method=interpolation_method,
        enable_interpolation=enable_interpolation,
    )

    result = {}
    result['shapefile'] = points_shp
    result['raster'] = written_raster
    result['statistics'] = statistics
    return result
|