visualization.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. """
  2. 可视化模块
  3. Visualization Module
  4. 基于原始01_Figure_raster_mapping.py改进,用于生成栅格地图和直方图
  5. """
  6. import os
  7. import sys
  8. import logging
  9. import geopandas as gpd
  10. import rasterio
  11. from rasterio.mask import mask
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. import json
  15. from matplotlib.colors import ListedColormap, BoundaryNorm
  16. from rasterio.plot import show
  17. import seaborn as sns
  18. # 添加项目根目录到路径
  19. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  20. import config
  21. class Visualizer:
  22. """
  23. 可视化器
  24. 负责创建栅格地图和直方图
  25. """
  26. def __init__(self):
  27. """
  28. 初始化可视化器
  29. """
  30. self.logger = logging.getLogger(__name__)
  31. self._setup_matplotlib()
  32. def _setup_matplotlib(self):
  33. """
  34. 设置matplotlib的字体和样式
  35. """
  36. try:
  37. # 设置字体,优先尝试常用的中文字体
  38. import matplotlib.font_manager as fm
  39. # 清理matplotlib字体缓存(解决Windows系统字体问题)
  40. try:
  41. import matplotlib
  42. fm._rebuild()
  43. self.logger.info("matplotlib字体缓存已重建")
  44. except Exception as cache_error:
  45. self.logger.warning(f"字体缓存重建失败: {cache_error}")
  46. # 可用的中文字体列表(Windows系统优先)
  47. chinese_fonts = [
  48. 'Microsoft YaHei', # 微软雅黑 (Windows)
  49. 'Microsoft YaHei UI', # 微软雅黑UI (Windows)
  50. 'SimHei', # 黑体 (Windows)
  51. 'SimSun', # 宋体 (Windows)
  52. 'KaiTi', # 楷体 (Windows)
  53. 'FangSong', # 仿宋 (Windows)
  54. 'Microsoft JhengHei', # 微软正黑体 (Windows)
  55. 'PingFang SC', # 苹方(macOS)
  56. 'Hiragino Sans GB', # 冬青黑体(macOS)
  57. 'WenQuanYi Micro Hei', # 文泉驿微米黑(Linux)
  58. 'Noto Sans CJK SC', # 思源黑体(Linux)
  59. 'Arial Unicode MS', # Unicode字体
  60. 'DejaVu Sans' # 备用字体
  61. ]
  62. # 查找可用的字体
  63. available_fonts = [f.name for f in fm.fontManager.ttflist]
  64. selected_font = None
  65. self.logger.info(f"系统中可用字体数量: {len(available_fonts)}")
  66. for font in chinese_fonts:
  67. if font in available_fonts:
  68. selected_font = font
  69. self.logger.info(f"选择字体: {font}")
  70. break
  71. if selected_font:
  72. plt.rcParams['font.sans-serif'] = [selected_font] + chinese_fonts
  73. plt.rcParams['font.family'] = 'sans-serif'
  74. else:
  75. self.logger.warning("未找到合适的中文字体,将使用系统默认字体")
  76. # 使用更安全的字体配置
  77. plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
  78. plt.rcParams['font.family'] = 'sans-serif'
  79. # 解决负号显示问题
  80. plt.rcParams['axes.unicode_minus'] = False
  81. # 设置图形样式
  82. plt.rcParams['figure.figsize'] = (10, 8)
  83. plt.rcParams['axes.labelsize'] = 12
  84. plt.rcParams['axes.titlesize'] = 14
  85. plt.rcParams['xtick.labelsize'] = 10
  86. plt.rcParams['ytick.labelsize'] = 10
  87. plt.rcParams['legend.fontsize'] = 10
  88. # 设置DPI以提高图像质量
  89. plt.rcParams['figure.dpi'] = 100
  90. plt.rcParams['savefig.dpi'] = 300
  91. plt.rcParams['savefig.bbox'] = 'tight'
  92. plt.rcParams['savefig.pad_inches'] = 0.1
  93. self.logger.info("matplotlib字体和样式设置完成")
  94. except Exception as e:
  95. self.logger.warning(f"设置matplotlib字体失败: {str(e)},将使用默认配置")
  96. # 最基本的安全配置
  97. plt.rcParams['font.family'] = 'sans-serif'
  98. plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial']
  99. plt.rcParams['axes.unicode_minus'] = False
  100. def create_raster_map(self,
  101. shp_path=None,
  102. tif_path=None,
  103. color_map_name=None,
  104. title_name="Prediction Cd",
  105. output_path=None,
  106. output_size=None,
  107. high_res=False):
  108. """
  109. 创建栅格地图
  110. @param shp_path: 输入的矢量数据的路径
  111. @param tif_path: 输入的栅格数据的路径
  112. @param color_map_name: 使用的色彩方案
  113. @param title_name: 输出数据的图的名称
  114. @param output_path: 输出保存的图片的路径
  115. @param output_size: 图片尺寸
  116. @param high_res: 是否使用高分辨率输出(默认False,DPI=300)
  117. @return: 输出图片路径
  118. """
  119. try:
  120. # 使用默认值
  121. if shp_path is None:
  122. shp_path = config.ANALYSIS_CONFIG["boundary_shp"]
  123. if tif_path is None:
  124. tif_path = config.ANALYSIS_CONFIG["output_raster"]
  125. if color_map_name is None:
  126. color_map_name = config.VISUALIZATION_CONFIG["color_maps"][config.VISUALIZATION_CONFIG["default_colormap"]]
  127. if output_path is None:
  128. output_path = os.path.join(config.OUTPUT_PATHS["figures_dir"], "Prediction_results")
  129. if output_size is None:
  130. output_size = config.VISUALIZATION_CONFIG["figure_size"]
  131. self.logger.info(f"开始创建栅格地图: {tif_path}")
  132. # 检查文件是否存在
  133. if not os.path.exists(tif_path):
  134. raise FileNotFoundError(f"栅格文件不存在: {tif_path}")
  135. # 如果边界文件不存在,创建一个简单的边界
  136. if not os.path.exists(shp_path):
  137. self.logger.warning(f"边界文件不存在: {shp_path},将跳过边界绘制")
  138. gdf = None
  139. else:
  140. gdf = gpd.read_file(shp_path)
  141. # 读取并处理栅格数据
  142. with rasterio.open(tif_path) as src:
  143. if gdf is not None:
  144. # 使用边界裁剪栅格数据
  145. geoms = [json.loads(gdf.to_json())["features"][0]["geometry"]]
  146. out_image, out_transform = mask(src, geoms, crop=True)
  147. out_meta = src.meta.copy()
  148. else:
  149. # 直接读取整个栅格
  150. out_image = src.read()
  151. out_transform = src.transform
  152. out_meta = src.meta.copy()
  153. # 提取数据并处理无效值
  154. raster = out_image[0].astype('float32')
  155. nodata = out_meta.get("nodata", None)
  156. if nodata is not None:
  157. raster[raster == nodata] = np.nan
  158. # 根据分位数分为6个等级
  159. valid_data = raster[~np.isnan(raster)]
  160. if len(valid_data) == 0:
  161. raise ValueError("栅格数据中没有有效值")
  162. bounds = np.nanpercentile(raster, [0, 20, 40, 60, 80, 90, 100])
  163. norm = BoundaryNorm(bounds, ncolors=len(bounds) - 1)
  164. cmap = ListedColormap(color_map_name)
  165. # 绘图
  166. fig, ax = plt.subplots(figsize=(output_size, output_size))
  167. show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
  168. # 添加矢量边界
  169. if gdf is not None:
  170. gdf.boundary.plot(ax=ax, color='black', linewidth=1)
  171. # 设置标题和标签
  172. ax.set_title(title_name, fontsize=20)
  173. ax.set_xlabel("Longitude", fontsize=18)
  174. ax.set_ylabel("Latitude", fontsize=18)
  175. ax.grid(True, linestyle='--', color='gray', alpha=0.5)
  176. ax.tick_params(axis='y', labelrotation=90)
  177. # 添加色带
  178. tick_labels = [f"{bounds[i]:.1f}" for i in range(len(bounds) - 1)]
  179. cbar = plt.colorbar(
  180. plt.cm.ScalarMappable(norm=norm, cmap=cmap),
  181. ax=ax,
  182. ticks=[(bounds[i] + bounds[i+1]) / 2 for i in range(len(bounds) - 1)],
  183. shrink=0.6, # 缩小色带高度
  184. aspect=15 # 细长效果
  185. )
  186. cbar.ax.set_yticklabels(tick_labels)
  187. cbar.set_label("Values")
  188. plt.tight_layout()
  189. # 确保输出目录存在
  190. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  191. # 保存图片
  192. output_file = f"{output_path}.jpg"
  193. # 根据high_res参数决定使用的DPI
  194. output_dpi = 600 if high_res else config.VISUALIZATION_CONFIG["dpi"]
  195. plt.savefig(output_file, dpi=output_dpi, format='jpg', bbox_inches='tight')
  196. plt.close()
  197. self.logger.info(f"栅格地图创建成功: {output_file}")
  198. return output_file
  199. except Exception as e:
  200. self.logger.error(f"栅格地图创建失败: {str(e)}")
  201. raise
  202. def create_histogram(self,
  203. file_path=None,
  204. figsize=None,
  205. xlabel='Cd content',
  206. ylabel='Frequency',
  207. title='County level Cd Frequency',
  208. save_path=None,
  209. high_res=False):
  210. """
  211. 绘制GeoTIFF文件的直方图
  212. @param file_path: GeoTIFF 文件路径
  213. @param figsize: 图像尺寸,如 (10, 6)
  214. @param xlabel: 横坐标标签
  215. @param ylabel: 纵坐标标签
  216. @param title: 图标题
  217. @param save_path: 可选,保存图片路径(含文件名和扩展名)
  218. @param high_res: 是否使用高分辨率输出(默认False,DPI=300)
  219. @return: 输出图片路径
  220. """
  221. try:
  222. # 使用默认值
  223. if file_path is None:
  224. file_path = config.ANALYSIS_CONFIG["output_raster"]
  225. if figsize is None:
  226. figsize = (6, 6)
  227. if save_path is None:
  228. save_path = os.path.join(config.OUTPUT_PATHS["figures_dir"], "Prediction_frequency.jpg")
  229. self.logger.info(f"开始创建直方图: {file_path}")
  230. # 检查文件是否存在
  231. if not os.path.exists(file_path):
  232. raise FileNotFoundError(f"栅格文件不存在: {file_path}")
  233. # 设置seaborn样式
  234. sns.set(style='ticks')
  235. # 读取栅格数据
  236. with rasterio.open(file_path) as src:
  237. band = src.read(1)
  238. nodata = src.nodata
  239. # 处理无效值
  240. if nodata is not None:
  241. band = np.where(band == nodata, np.nan, band)
  242. # 展平数据并移除NaN值
  243. band_flat = band.flatten()
  244. band_flat = band_flat[~np.isnan(band_flat)]
  245. if len(band_flat) == 0:
  246. raise ValueError("栅格数据中没有有效值")
  247. # 创建图形
  248. plt.figure(figsize=figsize)
  249. # 绘制直方图和密度曲线
  250. sns.histplot(band_flat, bins=100, color='steelblue', alpha=0.7,
  251. edgecolor='black', stat='density')
  252. sns.kdeplot(band_flat, color='red', linewidth=2)
  253. # 设置标签和标题
  254. plt.xlabel(xlabel, fontsize=14)
  255. plt.ylabel(ylabel, fontsize=14)
  256. plt.title(title, fontsize=16)
  257. plt.grid(True, linestyle='--', alpha=0.5)
  258. plt.tight_layout()
  259. # 确保输出目录存在
  260. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  261. # 保存图片
  262. # 根据high_res参数决定使用的DPI
  263. output_dpi = 600 if high_res else config.VISUALIZATION_CONFIG["dpi"]
  264. plt.savefig(save_path, dpi=output_dpi,
  265. format='jpg', bbox_inches='tight')
  266. plt.close()
  267. self.logger.info(f"直方图创建成功: {save_path}")
  268. return save_path
  269. except Exception as e:
  270. self.logger.error(f"直方图创建失败: {str(e)}")
  271. raise
  272. if __name__ == "__main__":
  273. # 测试代码
  274. visualizer = Visualizer()
  275. # 测试栅格地图创建
  276. try:
  277. map_output = visualizer.create_raster_map()
  278. print(f"栅格地图创建完成: {map_output}")
  279. except Exception as e:
  280. print(f"栅格地图创建失败: {e}")
  281. # 测试直方图创建
  282. try:
  283. histogram_output = visualizer.create_histogram()
  284. print(f"直方图创建完成: {histogram_output}")
  285. except Exception as e:
  286. print(f"直方图创建失败: {e}")