mapping_utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. """
  2. 通用绘图工具模块
  3. Universal Mapping and Visualization Utils
  4. 整合了CSV转GeoTIFF、栅格地图绘制和直方图生成功能
  5. 基于01_Transfer_csv_to_geotif.py和02_Figure_raster_mapping.py的更新代码
  6. Author: Integrated from Wanxue Zhu's code
  7. """
  8. import pandas as pd
  9. import geopandas as gpd
  10. from shapely.geometry import Point
  11. import rasterio
  12. from rasterio.features import rasterize
  13. from rasterio.transform import from_origin
  14. from rasterio.mask import mask
  15. from rasterio.plot import show
  16. import numpy as np
  17. import os
  18. import json
  19. import logging
  20. from scipy.interpolate import griddata
  21. from scipy.ndimage import distance_transform_edt
  22. import matplotlib.pyplot as plt
  23. from matplotlib.colors import ListedColormap, BoundaryNorm
  24. import seaborn as sns
  25. import warnings
  26. warnings.filterwarnings('ignore')
  27. # 配置日志
  28. logger = logging.getLogger(__name__)
  29. # 设置matplotlib的中文字体和样式
  30. plt.rcParams['font.family'] = 'Arial'
  31. plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
  32. plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] # 添加多个中文字体
  33. # 预定义的色彩方案
  34. COLORMAPS = {
  35. 'yellow_orange_brown': ['#FFFECE', '#FFF085', '#FEBA17', '#BE3D2A', '#74512D', '#4E1F00'], # 黄-橙-棕
  36. 'blue_series': ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60', '#2A3335'], # 蓝色系
  37. 'yellow_green': ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'], # 淡黄-草绿
  38. 'green_brown': ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'], # 绿色-棕色
  39. 'yellow_pink_purple': ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'], # 黄-粉-紫
  40. 'green_yellow_red_purple': ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F'], # 绿-黄-红-紫
  41. }
  42. class MappingUtils:
  43. """
  44. 通用绘图工具类
  45. 提供CSV转换、栅格处理、地图绘制和直方图生成功能
  46. """
  47. def __init__(self, log_level=logging.INFO):
  48. """
  49. 初始化绘图工具
  50. @param log_level: 日志级别
  51. """
  52. self.logger = logging.getLogger(self.__class__.__name__)
  53. self.logger.setLevel(log_level)
  54. if not self.logger.handlers:
  55. handler = logging.StreamHandler()
  56. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  57. handler.setFormatter(formatter)
  58. self.logger.addHandler(handler)
  59. def csv_to_shapefile(self, csv_file, shapefile_output, lon_col=0, lat_col=1, value_col=2):
  60. """
  61. 将CSV文件转换为Shapefile文件
  62. @param csv_file: CSV文件路径
  63. @param shapefile_output: 输出Shapefile文件路径
  64. @param lon_col: 经度列索引或列名,默认第0列
  65. @param lat_col: 纬度列索引或列名,默认第1列
  66. @param value_col: 数值列索引或列名,默认第2列
  67. @return: 输出的shapefile路径
  68. """
  69. try:
  70. self.logger.info(f"开始转换CSV到Shapefile: {csv_file}")
  71. # 读取CSV数据
  72. df = pd.read_csv(csv_file)
  73. # 支持列索引或列名
  74. if isinstance(lon_col, int):
  75. lon = df.iloc[:, lon_col]
  76. else:
  77. lon = df[lon_col]
  78. if isinstance(lat_col, int):
  79. lat = df.iloc[:, lat_col]
  80. else:
  81. lat = df[lat_col]
  82. if isinstance(value_col, int):
  83. val = df.iloc[:, value_col]
  84. else:
  85. val = df[value_col]
  86. # 创建几何对象
  87. geometry = [Point(xy) for xy in zip(lon, lat)]
  88. gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326")
  89. # 确保输出目录存在
  90. os.makedirs(os.path.dirname(shapefile_output), exist_ok=True)
  91. # 保存Shapefile
  92. gdf.to_file(shapefile_output, driver="ESRI Shapefile")
  93. self.logger.info(f"✓ 成功转换CSV到Shapefile: {shapefile_output}")
  94. return shapefile_output
  95. except Exception as e:
  96. self.logger.error(f"CSV转Shapefile失败: {str(e)}")
  97. raise
  98. def create_boundary_mask(self, raster, transform, gdf):
  99. """
  100. 创建边界掩膜,只保留边界内的区域
  101. @param raster: 栅格数据
  102. @param transform: 栅格变换参数
  103. @param gdf: 矢量边界数据
  104. @return: 边界掩膜
  105. """
  106. try:
  107. mask = rasterize(
  108. gdf.geometry,
  109. out_shape=raster.shape,
  110. transform=transform,
  111. fill=0,
  112. default_value=1,
  113. dtype=np.uint8
  114. )
  115. return mask.astype(bool)
  116. except Exception as e:
  117. self.logger.error(f"创建边界掩膜失败: {str(e)}")
  118. raise
  119. def interpolate_nan_values(self, raster, method='nearest'):
  120. """
  121. 使用插值方法填充NaN值
  122. @param raster: 包含NaN值的栅格数据
  123. @param method: 插值方法 ('nearest', 'linear', 'cubic')
  124. @return: 插值后的栅格数据
  125. """
  126. try:
  127. if not np.isnan(raster).any():
  128. return raster
  129. # 获取有效值的坐标
  130. valid_mask = ~np.isnan(raster)
  131. valid_coords = np.where(valid_mask)
  132. valid_values = raster[valid_mask]
  133. if len(valid_values) == 0:
  134. self.logger.warning("没有有效值用于插值")
  135. return raster
  136. # 创建网格坐标
  137. rows, cols = raster.shape
  138. grid_x, grid_y = np.mgrid[0:rows, 0:cols]
  139. # 准备插值坐标
  140. points = np.column_stack((valid_coords[0], valid_coords[1]))
  141. # 执行插值
  142. interpolated = griddata(points, valid_values, (grid_x, grid_y),
  143. method=method, fill_value=np.nan)
  144. # 如果插值后仍有NaN值,使用最近邻方法填充
  145. if np.isnan(interpolated).any():
  146. self.logger.info(f"使用 {method} 插值后仍有NaN值,使用最近邻方法填充剩余值")
  147. remaining_nan = np.isnan(interpolated)
  148. remaining_coords = np.where(remaining_nan)
  149. if len(remaining_coords[0]) > 0:
  150. # 使用距离变换找到最近的已知值
  151. dist, indices = distance_transform_edt(remaining_nan,
  152. return_distances=True,
  153. return_indices=True)
  154. # 填充剩余的NaN值
  155. for i, j in zip(remaining_coords[0], remaining_coords[1]):
  156. if indices[0, i, j] < rows and indices[1, i, j] < cols:
  157. interpolated[i, j] = raster[indices[0, i, j], indices[1, i, j]]
  158. return interpolated
  159. except Exception as e:
  160. self.logger.error(f"插值失败: {str(e)}")
  161. return raster
  162. def vector_to_raster(self, input_shapefile, template_tif, output_tif, field,
  163. resolution_factor=16.0, boundary_shp=None, interpolation_method='nearest', enable_interpolation=True):
  164. """
  165. 将点矢量数据转换为栅格数据
  166. @param input_shapefile: 输入点矢量数据的Shapefile文件路径
  167. @param template_tif: 用作模板的GeoTIFF文件路径
  168. @param output_tif: 输出栅格化后的GeoTIFF文件路径
  169. @param field: 用于栅格化的属性字段名
  170. @param resolution_factor: 分辨率倍数因子
  171. @param boundary_shp: 边界Shapefile文件路径,用于创建掩膜
  172. @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic')
  173. @param enable_interpolation: 是否启用空间插值,默认True
  174. @return: 输出的GeoTIFF文件路径和统计信息
  175. """
  176. try:
  177. self.logger.info(f"开始处理: {input_shapefile}")
  178. self.logger.info(f"分辨率因子: {resolution_factor}, 插值方法: {interpolation_method}")
  179. # 读取矢量数据
  180. gdf = gpd.read_file(input_shapefile)
  181. # 读取模板栅格
  182. with rasterio.open(template_tif) as src:
  183. template_meta = src.meta.copy()
  184. # 根据分辨率因子计算新的尺寸和变换参数
  185. if resolution_factor != 1.0:
  186. width = int(src.width * resolution_factor)
  187. height = int(src.height * resolution_factor)
  188. transform = rasterio.Affine(
  189. src.transform.a / resolution_factor,
  190. src.transform.b,
  191. src.transform.c,
  192. src.transform.d,
  193. src.transform.e / resolution_factor,
  194. src.transform.f
  195. )
  196. self.logger.info(f"分辨率调整: {src.width}x{src.height} -> {width}x{height}")
  197. else:
  198. width = src.width
  199. height = src.height
  200. transform = src.transform
  201. self.logger.info(f"保持原始分辨率: {width}x{height}")
  202. crs = src.crs
  203. # 投影矢量数据
  204. if gdf.crs != crs:
  205. gdf = gdf.to_crs(crs)
  206. # 栅格化
  207. shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[field]))
  208. raster = rasterize(
  209. shapes=shapes,
  210. out_shape=(height, width),
  211. transform=transform,
  212. fill=np.nan,
  213. dtype='float32'
  214. )
  215. # 应用边界掩膜(如果提供)
  216. if boundary_shp and os.path.exists(boundary_shp):
  217. self.logger.info(f"应用边界掩膜: {boundary_shp}")
  218. boundary_gdf = gpd.read_file(boundary_shp)
  219. if boundary_gdf.crs != crs:
  220. boundary_gdf = boundary_gdf.to_crs(crs)
  221. boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
  222. raster[~boundary_mask] = np.nan
  223. # 使用插值方法填充NaN值(如果启用)
  224. if enable_interpolation and np.isnan(raster).any():
  225. self.logger.info(f"使用 {interpolation_method} 方法进行插值...")
  226. raster = self.interpolate_nan_values(raster, method=interpolation_method)
  227. elif not enable_interpolation and np.isnan(raster).any():
  228. self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")
  229. # 创建输出目录
  230. os.makedirs(os.path.dirname(output_tif), exist_ok=True)
  231. # 更新元数据
  232. template_meta.update({
  233. "count": 1,
  234. "dtype": 'float32',
  235. "nodata": np.nan,
  236. "width": width,
  237. "height": height,
  238. "transform": transform
  239. })
  240. # 保存栅格文件
  241. with rasterio.open(output_tif, 'w', **template_meta) as dst:
  242. dst.write(raster, 1)
  243. # 计算统计信息
  244. valid_data = raster[~np.isnan(raster)]
  245. stats = None
  246. if len(valid_data) > 0:
  247. stats = {
  248. 'min': float(np.min(valid_data)),
  249. 'max': float(np.max(valid_data)),
  250. 'mean': float(np.mean(valid_data)),
  251. 'std': float(np.std(valid_data)),
  252. 'valid_pixels': int(len(valid_data)),
  253. 'total_pixels': int(raster.size)
  254. }
  255. self.logger.info(f"统计信息: 有效像素 {stats['valid_pixels']}/{stats['total_pixels']}")
  256. self.logger.info(f"数值范围: {stats['min']:.4f} - {stats['max']:.4f}")
  257. else:
  258. self.logger.warning("没有有效数据")
  259. self.logger.info(f"✓ 成功保存: {output_tif}")
  260. return output_tif, stats
  261. except Exception as e:
  262. self.logger.error(f"矢量转栅格失败: {str(e)}")
  263. raise
  264. def create_raster_map(self, shp_path, tif_path, output_path,
  265. colormap='green_yellow_red_purple', title="Prediction Map",
  266. output_size=12, figsize=None, dpi=300,
  267. resolution_factor=1.0, enable_interpolation=False,
  268. interpolation_method='nearest'):
  269. """
  270. 创建栅格地图
  271. @param shp_path: 输入的矢量数据路径
  272. @param tif_path: 输入的栅格数据路径
  273. @param output_path: 输出图片路径(不包含扩展名)
  274. @param colormap: 色彩方案名称或颜色列表
  275. @param title: 图片标题
  276. @param output_size: 图片尺寸(正方形),如果指定了figsize则忽略此参数
  277. @param figsize: 图片尺寸元组 (width, height),优先级高于output_size
  278. @param dpi: 图片分辨率
  279. @param resolution_factor: 分辨率因子,>1提高分辨率,<1降低分辨率
  280. @param enable_interpolation: 是否启用空间插值,用于处理NaN值或提高分辨率
  281. @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic')
  282. @param enable_interpolation: 是否启用空间插值,默认True
  283. @return: 输出图片文件路径
  284. """
  285. try:
  286. self.logger.info(f"开始创建栅格地图: {tif_path}")
  287. self.logger.info(f"分辨率因子: {resolution_factor}, 启用插值: {enable_interpolation}")
  288. # 读取矢量边界
  289. gdf = gpd.read_file(shp_path) if shp_path else None
  290. # 读取并裁剪栅格数据
  291. with rasterio.open(tif_path) as src:
  292. original_transform = src.transform
  293. original_crs = src.crs
  294. if gdf is not None:
  295. # 确保坐标系一致
  296. if gdf.crs != src.crs:
  297. gdf = gdf.to_crs(src.crs)
  298. # 裁剪栅格
  299. geoms = [json.loads(gdf.to_json())["features"][0]["geometry"]]
  300. out_image, out_transform = mask(src, geoms, crop=True)
  301. out_meta = src.meta.copy()
  302. else:
  303. # 如果没有边界文件,使用整个栅格
  304. out_image = src.read()
  305. out_transform = src.transform
  306. out_meta = src.meta.copy()
  307. # 提取数据并处理无效值
  308. raster = out_image[0].astype('float32')
  309. nodata = out_meta.get("nodata", None)
  310. if nodata is not None:
  311. raster[raster == nodata] = np.nan
  312. # 应用分辨率因子重采样
  313. if resolution_factor != 1.0:
  314. self.logger.info(f"应用分辨率因子重采样: {resolution_factor}")
  315. raster, out_transform = self._resample_raster(raster, out_transform, resolution_factor, original_crs)
  316. # 应用空间插值(如果启用)
  317. if enable_interpolation and np.isnan(raster).any():
  318. self.logger.info(f"使用 {interpolation_method} 方法进行空间插值")
  319. raster = self.interpolate_nan_values(raster, method=interpolation_method)
  320. # 检查是否有有效数据
  321. if np.all(np.isnan(raster)):
  322. raise ValueError("栅格数据中没有有效值")
  323. # 根据分位数分为6个等级
  324. bounds = np.nanpercentile(raster, [0, 20, 40, 60, 80, 90, 100])
  325. norm = BoundaryNorm(bounds, ncolors=len(bounds) - 1)
  326. # 获取色彩方案
  327. if isinstance(colormap, str):
  328. if colormap in COLORMAPS:
  329. color_list = COLORMAPS[colormap]
  330. else:
  331. self.logger.warning(f"未知色彩方案: {colormap},使用默认方案")
  332. color_list = COLORMAPS['green_yellow_red_purple']
  333. else:
  334. color_list = colormap
  335. cmap = ListedColormap(color_list)
  336. # 设置图片尺寸
  337. if figsize is not None:
  338. fig_size = figsize
  339. else:
  340. fig_size = (output_size, output_size)
  341. # 绘图
  342. fig, ax = plt.subplots(figsize=fig_size)
  343. show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
  344. # 添加矢量边界
  345. if gdf is not None:
  346. gdf.boundary.plot(ax=ax, color='black', linewidth=1)
  347. # 设置标题和标签
  348. ax.set_title(title, fontsize=20)
  349. ax.set_xlabel("Longitude", fontsize=18)
  350. ax.set_ylabel("Latitude", fontsize=18)
  351. ax.grid(True, linestyle='--', color='gray', alpha=0.5)
  352. ax.tick_params(axis='y', labelrotation=90)
  353. # 添加色带
  354. tick_labels = [f"{bounds[i]:.1f}" for i in range(len(bounds) - 1)]
  355. cbar = plt.colorbar(
  356. plt.cm.ScalarMappable(norm=norm, cmap=cmap),
  357. ax=ax,
  358. ticks=[(bounds[i] + bounds[i+1]) / 2 for i in range(len(bounds) - 1)],
  359. shrink=0.6,
  360. aspect=15
  361. )
  362. cbar.ax.set_yticklabels(tick_labels)
  363. cbar.set_label("Values")
  364. plt.tight_layout()
  365. # 确保输出目录存在
  366. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  367. # 保存图片
  368. output_file = f"{output_path}.jpg"
  369. plt.savefig(output_file, dpi=dpi, format='jpg', bbox_inches='tight')
  370. plt.close()
  371. self.logger.info(f"✓ 栅格地图创建成功: {output_file}")
  372. return output_file
  373. except Exception as e:
  374. self.logger.error(f"栅格地图创建失败: {str(e)}")
  375. raise
  376. def create_histogram(self, file_path, save_path=None, figsize=(10, 6),
  377. xlabel='像元值', ylabel='频率密度', title='数值分布图',
  378. bins=100, dpi=300):
  379. """
  380. 绘制GeoTIFF文件的直方图
  381. @param file_path: GeoTIFF文件路径
  382. @param save_path: 保存路径,如果为None则自动生成
  383. @param figsize: 图像尺寸
  384. @param xlabel: 横坐标标签
  385. @param ylabel: 纵坐标标签
  386. @param title: 图标题
  387. @param bins: 直方图箱数
  388. @param dpi: 图片分辨率
  389. @return: 输出图片文件路径
  390. """
  391. try:
  392. self.logger.info(f"开始创建直方图: {file_path}")
  393. # 设置seaborn样式
  394. sns.set(style='ticks')
  395. # 读取栅格数据
  396. with rasterio.open(file_path) as src:
  397. band = src.read(1)
  398. nodata = src.nodata
  399. # 处理无效值
  400. if nodata is not None:
  401. band = np.where(band == nodata, np.nan, band)
  402. # 展平数据并移除NaN值
  403. band_flat = band.flatten()
  404. band_flat = band_flat[~np.isnan(band_flat)]
  405. if len(band_flat) == 0:
  406. raise ValueError("栅格数据中没有有效值")
  407. # 创建图形
  408. plt.figure(figsize=figsize)
  409. # 绘制直方图和密度曲线
  410. sns.histplot(band_flat, bins=bins, color='steelblue', alpha=0.7,
  411. edgecolor='black', stat='density')
  412. sns.kdeplot(band_flat, color='red', linewidth=2)
  413. # 设置标签和标题
  414. plt.xlabel(xlabel, fontsize=14)
  415. plt.ylabel(ylabel, fontsize=14)
  416. plt.title(title, fontsize=16)
  417. plt.grid(True, linestyle='--', alpha=0.5)
  418. plt.tight_layout()
  419. # 保存图片
  420. if save_path is None:
  421. save_path = file_path.replace('.tif', '_histogram.jpg')
  422. # 确保输出目录存在
  423. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  424. plt.savefig(save_path, dpi=dpi, format='jpg', bbox_inches='tight')
  425. plt.close()
  426. self.logger.info(f"✓ 直方图创建成功: {save_path}")
  427. return save_path
  428. except Exception as e:
  429. self.logger.error(f"直方图创建失败: {str(e)}")
  430. raise
  431. def get_available_colormaps():
  432. """
  433. 获取可用的色彩方案列表
  434. @return: 色彩方案字典
  435. """
  436. return COLORMAPS.copy()
  437. def csv_to_raster_workflow(csv_file, template_tif, output_dir,
  438. boundary_shp=None, resolution_factor=16.0,
  439. interpolation_method='nearest', field_name='Prediction',
  440. lon_col=0, lat_col=1, value_col=2, enable_interpolation=False):
  441. """
  442. 完整的CSV到栅格转换工作流
  443. @param csv_file: CSV文件路径
  444. @param template_tif: 模板GeoTIFF文件路径
  445. @param output_dir: 输出目录
  446. @param boundary_shp: 边界Shapefile文件路径(可选)
  447. @param resolution_factor: 分辨率因子
  448. @param interpolation_method: 插值方法
  449. @param field_name: 字段名称
  450. @param lon_col: 经度列
  451. @param lat_col: 纬度列
  452. @param value_col: 数值列
  453. @param enable_interpolation: 是否启用空间插值,默认False
  454. @return: 输出文件路径字典
  455. """
  456. mapper = MappingUtils()
  457. # 确保输出目录存在
  458. os.makedirs(output_dir, exist_ok=True)
  459. # 生成文件名
  460. base_name = os.path.splitext(os.path.basename(csv_file))[0]
  461. shapefile_path = os.path.join(output_dir, f"{base_name}_points.shp")
  462. raster_path = os.path.join(output_dir, f"{base_name}_raster.tif")
  463. # 1. CSV转Shapefile
  464. mapper.csv_to_shapefile(csv_file, shapefile_path, lon_col, lat_col, value_col)
  465. # 2. Shapefile转栅格
  466. raster_path, stats = mapper.vector_to_raster(
  467. shapefile_path, template_tif, raster_path, field_name,
  468. resolution_factor, boundary_shp, interpolation_method, enable_interpolation
  469. )
  470. return {
  471. 'shapefile': shapefile_path,
  472. 'raster': raster_path,
  473. 'statistics': stats
  474. }