123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341 |
- """
- 可视化模块
- Visualization Module
- 基于原始01_Figure_raster_mapping.py改进,用于生成栅格地图和直方图
- """
- import os
- import sys
- import logging
- import geopandas as gpd
- import rasterio
- from rasterio.mask import mask
- import matplotlib.pyplot as plt
- import numpy as np
- import json
- from matplotlib.colors import ListedColormap, BoundaryNorm
- from rasterio.plot import show
- import seaborn as sns
- # 添加项目根目录到路径
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- import config
- class Visualizer:
- """
- 可视化器
- 负责创建栅格地图和直方图
- """
-
- def __init__(self):
- """
- 初始化可视化器
- """
- self.logger = logging.getLogger(__name__)
- self._setup_matplotlib()
-
- def _setup_matplotlib(self):
- """
- 设置matplotlib的字体和样式
- """
- try:
- # 设置字体,优先尝试常用的中文字体
- import matplotlib.font_manager as fm
-
- # 清理matplotlib字体缓存(解决Windows系统字体问题)
- try:
- import matplotlib
- fm._rebuild()
- self.logger.info("matplotlib字体缓存已重建")
- except Exception as cache_error:
- self.logger.warning(f"字体缓存重建失败: {cache_error}")
-
- # 可用的中文字体列表(Windows系统优先)
- chinese_fonts = [
- 'Microsoft YaHei', # 微软雅黑 (Windows)
- 'Microsoft YaHei UI', # 微软雅黑UI (Windows)
- 'SimHei', # 黑体 (Windows)
- 'SimSun', # 宋体 (Windows)
- 'KaiTi', # 楷体 (Windows)
- 'FangSong', # 仿宋 (Windows)
- 'Microsoft JhengHei', # 微软正黑体 (Windows)
- 'PingFang SC', # 苹方(macOS)
- 'Hiragino Sans GB', # 冬青黑体(macOS)
- 'WenQuanYi Micro Hei', # 文泉驿微米黑(Linux)
- 'Noto Sans CJK SC', # 思源黑体(Linux)
- 'Arial Unicode MS', # Unicode字体
- 'DejaVu Sans' # 备用字体
- ]
-
- # 查找可用的字体
- available_fonts = [f.name for f in fm.fontManager.ttflist]
- selected_font = None
-
- self.logger.info(f"系统中可用字体数量: {len(available_fonts)}")
-
- for font in chinese_fonts:
- if font in available_fonts:
- selected_font = font
- self.logger.info(f"选择字体: {font}")
- break
-
- if selected_font:
- plt.rcParams['font.sans-serif'] = [selected_font] + chinese_fonts
- plt.rcParams['font.family'] = 'sans-serif'
- else:
- self.logger.warning("未找到合适的中文字体,将使用系统默认字体")
- # 使用更安全的字体配置
- plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
- plt.rcParams['font.family'] = 'sans-serif'
-
- # 解决负号显示问题
- plt.rcParams['axes.unicode_minus'] = False
-
- # 设置图形样式
- plt.rcParams['figure.figsize'] = (10, 8)
- plt.rcParams['axes.labelsize'] = 12
- plt.rcParams['axes.titlesize'] = 14
- plt.rcParams['xtick.labelsize'] = 10
- plt.rcParams['ytick.labelsize'] = 10
- plt.rcParams['legend.fontsize'] = 10
-
- # 设置DPI以提高图像质量
- plt.rcParams['figure.dpi'] = 100
- plt.rcParams['savefig.dpi'] = 300
- plt.rcParams['savefig.bbox'] = 'tight'
- plt.rcParams['savefig.pad_inches'] = 0.1
-
- self.logger.info("matplotlib字体和样式设置完成")
-
- except Exception as e:
- self.logger.warning(f"设置matplotlib字体失败: {str(e)},将使用默认配置")
- # 最基本的安全配置
- plt.rcParams['font.family'] = 'sans-serif'
- plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial']
- plt.rcParams['axes.unicode_minus'] = False
-
- def create_raster_map(self,
- shp_path=None,
- tif_path=None,
- color_map_name=None,
- title_name="Prediction Cd",
- output_path=None,
- output_size=None,
- high_res=False):
- """
- 创建栅格地图
-
- @param shp_path: 输入的矢量数据的路径
- @param tif_path: 输入的栅格数据的路径
- @param color_map_name: 使用的色彩方案
- @param title_name: 输出数据的图的名称
- @param output_path: 输出保存的图片的路径
- @param output_size: 图片尺寸
- @param high_res: 是否使用高分辨率输出(默认False,DPI=300)
- @return: 输出图片路径
- """
- try:
- # 使用默认值
- if shp_path is None:
- shp_path = config.ANALYSIS_CONFIG["boundary_shp"]
- if tif_path is None:
- tif_path = config.ANALYSIS_CONFIG["output_raster"]
- if color_map_name is None:
- color_map_name = config.VISUALIZATION_CONFIG["color_maps"][config.VISUALIZATION_CONFIG["default_colormap"]]
- if output_path is None:
- output_path = os.path.join(config.OUTPUT_PATHS["figures_dir"], "Prediction_results")
- if output_size is None:
- output_size = config.VISUALIZATION_CONFIG["figure_size"]
-
- self.logger.info(f"开始创建栅格地图: {tif_path}")
-
- # 检查文件是否存在
- if not os.path.exists(tif_path):
- raise FileNotFoundError(f"栅格文件不存在: {tif_path}")
-
- # 如果边界文件不存在,创建一个简单的边界
- if not os.path.exists(shp_path):
- self.logger.warning(f"边界文件不存在: {shp_path},将跳过边界绘制")
- gdf = None
- else:
- gdf = gpd.read_file(shp_path)
-
- # 读取并处理栅格数据
- with rasterio.open(tif_path) as src:
- if gdf is not None:
- # 使用边界裁剪栅格数据
- geoms = [json.loads(gdf.to_json())["features"][0]["geometry"]]
- out_image, out_transform = mask(src, geoms, crop=True)
- out_meta = src.meta.copy()
- else:
- # 直接读取整个栅格
- out_image = src.read()
- out_transform = src.transform
- out_meta = src.meta.copy()
-
- # 提取数据并处理无效值
- raster = out_image[0].astype('float32')
- nodata = out_meta.get("nodata", None)
- if nodata is not None:
- raster[raster == nodata] = np.nan
-
- # 根据分位数分为6个等级
- valid_data = raster[~np.isnan(raster)]
- if len(valid_data) == 0:
- raise ValueError("栅格数据中没有有效值")
-
- bounds = np.nanpercentile(raster, [0, 20, 40, 60, 80, 90, 100])
- norm = BoundaryNorm(bounds, ncolors=len(bounds) - 1)
- cmap = ListedColormap(color_map_name)
-
- # 绘图
- fig, ax = plt.subplots(figsize=(output_size, output_size))
- show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
-
- # 添加矢量边界
- if gdf is not None:
- gdf.boundary.plot(ax=ax, color='black', linewidth=1)
-
- # 设置标题和标签
- ax.set_title(title_name, fontsize=20)
- ax.set_xlabel("Longitude", fontsize=18)
- ax.set_ylabel("Latitude", fontsize=18)
- ax.grid(True, linestyle='--', color='gray', alpha=0.5)
- ax.tick_params(axis='y', labelrotation=90)
-
- # 添加色带
- tick_labels = [f"{bounds[i]:.1f}" for i in range(len(bounds) - 1)]
- cbar = plt.colorbar(
- plt.cm.ScalarMappable(norm=norm, cmap=cmap),
- ax=ax,
- ticks=[(bounds[i] + bounds[i+1]) / 2 for i in range(len(bounds) - 1)],
- shrink=0.6, # 缩小色带高度
- aspect=15 # 细长效果
- )
- cbar.ax.set_yticklabels(tick_labels)
- cbar.set_label("Values")
-
- plt.tight_layout()
-
- # 确保输出目录存在
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
-
- # 保存图片
- output_file = f"{output_path}.jpg"
- # 根据high_res参数决定使用的DPI
- output_dpi = 600 if high_res else config.VISUALIZATION_CONFIG["dpi"]
- plt.savefig(output_file, dpi=output_dpi, format='jpg', bbox_inches='tight')
- plt.close()
-
- self.logger.info(f"栅格地图创建成功: {output_file}")
- return output_file
-
- except Exception as e:
- self.logger.error(f"栅格地图创建失败: {str(e)}")
- raise
-
- def create_histogram(self,
- file_path=None,
- figsize=None,
- xlabel='Cd content',
- ylabel='Frequency',
- title='County level Cd Frequency',
- save_path=None,
- high_res=False):
- """
- 绘制GeoTIFF文件的直方图
-
- @param file_path: GeoTIFF 文件路径
- @param figsize: 图像尺寸,如 (10, 6)
- @param xlabel: 横坐标标签
- @param ylabel: 纵坐标标签
- @param title: 图标题
- @param save_path: 可选,保存图片路径(含文件名和扩展名)
- @param high_res: 是否使用高分辨率输出(默认False,DPI=300)
- @return: 输出图片路径
- """
- try:
- # 使用默认值
- if file_path is None:
- file_path = config.ANALYSIS_CONFIG["output_raster"]
- if figsize is None:
- figsize = (6, 6)
- if save_path is None:
- save_path = os.path.join(config.OUTPUT_PATHS["figures_dir"], "Prediction_frequency.jpg")
-
- self.logger.info(f"开始创建直方图: {file_path}")
-
- # 检查文件是否存在
- if not os.path.exists(file_path):
- raise FileNotFoundError(f"栅格文件不存在: {file_path}")
-
- # 设置seaborn样式
- sns.set(style='ticks')
-
- # 读取栅格数据
- with rasterio.open(file_path) as src:
- band = src.read(1)
- nodata = src.nodata
-
- # 处理无效值
- if nodata is not None:
- band = np.where(band == nodata, np.nan, band)
-
- # 展平数据并移除NaN值
- band_flat = band.flatten()
- band_flat = band_flat[~np.isnan(band_flat)]
-
- if len(band_flat) == 0:
- raise ValueError("栅格数据中没有有效值")
-
- # 创建图形
- plt.figure(figsize=figsize)
-
- # 绘制直方图和密度曲线
- sns.histplot(band_flat, bins=100, color='steelblue', alpha=0.7,
- edgecolor='black', stat='density')
- sns.kdeplot(band_flat, color='red', linewidth=2)
-
- # 设置标签和标题
- plt.xlabel(xlabel, fontsize=14)
- plt.ylabel(ylabel, fontsize=14)
- plt.title(title, fontsize=16)
- plt.grid(True, linestyle='--', alpha=0.5)
- plt.tight_layout()
-
- # 确保输出目录存在
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
-
- # 保存图片
- # 根据high_res参数决定使用的DPI
- output_dpi = 600 if high_res else config.VISUALIZATION_CONFIG["dpi"]
- plt.savefig(save_path, dpi=output_dpi,
- format='jpg', bbox_inches='tight')
- plt.close()
-
- self.logger.info(f"直方图创建成功: {save_path}")
- return save_path
-
- except Exception as e:
- self.logger.error(f"直方图创建失败: {str(e)}")
- raise
- if __name__ == "__main__":
- # 测试代码
- visualizer = Visualizer()
-
- # 测试栅格地图创建
- try:
- map_output = visualizer.create_raster_map()
- print(f"栅格地图创建完成: {map_output}")
- except Exception as e:
- print(f"栅格地图创建失败: {e}")
-
- # 测试直方图创建
- try:
- histogram_output = visualizer.create_histogram()
- print(f"直方图创建完成: {histogram_output}")
- except Exception as e:
- print(f"直方图创建失败: {e}")
|