mapping_utils.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. """
  2. 通用绘图工具模块
  3. Universal Mapping and Visualization Utils
  4. 整合了CSV转GeoTIFF、栅格地图绘制和直方图生成功能
  5. 基于01_Transfer_csv_to_geotif.py和02_Figure_raster_mapping.py的更新代码
  6. Author: Integrated from Wanxue Zhu's code
  7. """
  8. import pandas as pd
  9. import geopandas as gpd
  10. from shapely.geometry import Point
  11. import rasterio
  12. from rasterio.features import rasterize, geometry_mask
  13. from rasterio.transform import from_origin
  14. from rasterio.mask import mask
  15. from rasterio.plot import show
  16. from rasterio.warp import transform_bounds, reproject
  17. from rasterio.enums import Resampling
  18. import numpy as np
  19. import os
  20. import json
  21. import logging
  22. from scipy.interpolate import griddata
  23. from scipy.ndimage import distance_transform_edt
  24. import matplotlib.pyplot as plt
  25. from matplotlib.colors import ListedColormap, BoundaryNorm
  26. import seaborn as sns
  27. import warnings
  28. warnings.filterwarnings('ignore')
  29. # 配置日志
  30. logger = logging.getLogger(__name__)
  31. # 设置matplotlib的中文字体和样式
  32. plt.rcParams['font.family'] = 'Arial'
  33. plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
  34. plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] # 添加多个中文字体
  35. # 预定义的色彩方案
  36. COLORMAPS = {
  37. 'yellow_orange_brown': ['#FFFECE', '#FFF085', '#FEBA17', '#BE3D2A', '#74512D', '#4E1F00'], # 黄-橙-棕
  38. 'blue_series': ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60', '#2A3335'], # 蓝色系
  39. 'yellow_green': ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'], # 淡黄-草绿
  40. 'green_brown': ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'], # 绿色-棕色
  41. 'yellow_pink_purple': ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'], # 黄-粉-紫
  42. 'green_yellow_red_purple': ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F'], # 绿-黄-红-紫
  43. }
  44. class MappingUtils:
  45. """
  46. 通用绘图工具类
  47. 提供CSV转换、栅格处理、地图绘制和直方图生成功能
  48. """
  49. def __init__(self, log_level=logging.INFO):
  50. """
  51. 初始化绘图工具
  52. @param log_level: 日志级别
  53. """
  54. self.logger = logging.getLogger(self.__class__.__name__)
  55. self.logger.setLevel(log_level)
  56. if not self.logger.handlers:
  57. handler = logging.StreamHandler()
  58. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  59. handler.setFormatter(formatter)
  60. self.logger.addHandler(handler)
  61. def csv_to_shapefile(self, csv_file, shapefile_output, lon_col=0, lat_col=1, value_col=2):
  62. """
  63. 将CSV文件转换为Shapefile文件
  64. @param csv_file: CSV文件路径
  65. @param shapefile_output: 输出Shapefile文件路径
  66. @param lon_col: 经度列索引或列名,默认第0列
  67. @param lat_col: 纬度列索引或列名,默认第1列
  68. @param value_col: 数值列索引或列名,默认第2列
  69. @return: 输出的shapefile路径
  70. """
  71. try:
  72. self.logger.info(f"开始转换CSV到Shapefile: {csv_file}")
  73. # 读取CSV数据
  74. df = pd.read_csv(csv_file)
  75. # 支持列索引或列名
  76. if isinstance(lon_col, int):
  77. lon = df.iloc[:, lon_col]
  78. else:
  79. lon = df[lon_col]
  80. if isinstance(lat_col, int):
  81. lat = df.iloc[:, lat_col]
  82. else:
  83. lat = df[lat_col]
  84. if isinstance(value_col, int):
  85. val = df.iloc[:, value_col]
  86. else:
  87. val = df[value_col]
  88. # 创建几何对象
  89. geometry = [Point(xy) for xy in zip(lon, lat)]
  90. gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326")
  91. # 确保输出目录存在
  92. os.makedirs(os.path.dirname(shapefile_output), exist_ok=True)
  93. # 保存Shapefile
  94. gdf.to_file(shapefile_output, driver="ESRI Shapefile")
  95. self.logger.info(f"✓ 成功转换CSV到Shapefile: {shapefile_output}")
  96. return shapefile_output
  97. except Exception as e:
  98. self.logger.error(f"CSV转Shapefile失败: {str(e)}")
  99. raise
  100. def create_boundary_mask(self, raster, transform, gdf):
  101. """
  102. 创建边界掩膜,只保留边界内的区域
  103. @param raster: 栅格数据
  104. @param transform: 栅格变换参数
  105. @param gdf: 矢量边界数据
  106. @return: 边界掩膜
  107. """
  108. try:
  109. mask = rasterize(
  110. gdf.geometry,
  111. out_shape=raster.shape,
  112. transform=transform,
  113. fill=0,
  114. default_value=1,
  115. dtype=np.uint8
  116. )
  117. return mask.astype(bool)
  118. except Exception as e:
  119. self.logger.error(f"创建边界掩膜失败: {str(e)}")
  120. raise
  121. def interpolate_nan_values(self, raster, method='nearest'):
  122. """
  123. 使用插值方法填充NaN值
  124. @param raster: 包含NaN值的栅格数据
  125. @param method: 插值方法 ('nearest', 'linear', 'cubic')
  126. @return: 插值后的栅格数据
  127. """
  128. try:
  129. if not np.isnan(raster).any():
  130. return raster
  131. # 获取有效值的坐标
  132. valid_mask = ~np.isnan(raster)
  133. valid_coords = np.where(valid_mask)
  134. valid_values = raster[valid_mask]
  135. if len(valid_values) == 0:
  136. self.logger.warning("没有有效值用于插值")
  137. return raster
  138. # 创建网格坐标
  139. rows, cols = raster.shape
  140. grid_x, grid_y = np.mgrid[0:rows, 0:cols]
  141. # 准备插值坐标
  142. points = np.column_stack((valid_coords[0], valid_coords[1]))
  143. # 执行插值
  144. interpolated = griddata(points, valid_values, (grid_x, grid_y),
  145. method=method, fill_value=np.nan)
  146. # 如果插值后仍有NaN值,使用最近邻方法填充
  147. if np.isnan(interpolated).any():
  148. self.logger.info(f"使用 {method} 插值后仍有NaN值,使用最近邻方法填充剩余值")
  149. remaining_nan = np.isnan(interpolated)
  150. remaining_coords = np.where(remaining_nan)
  151. if len(remaining_coords[0]) > 0:
  152. # 使用距离变换找到最近的已知值
  153. dist, indices = distance_transform_edt(remaining_nan,
  154. return_distances=True,
  155. return_indices=True)
  156. # 填充剩余的NaN值
  157. for i, j in zip(remaining_coords[0], remaining_coords[1]):
  158. if indices[0, i, j] < rows and indices[1, i, j] < cols:
  159. interpolated[i, j] = raster[indices[0, i, j], indices[1, i, j]]
  160. return interpolated
  161. except Exception as e:
  162. self.logger.error(f"插值失败: {str(e)}")
  163. return raster
  164. def vector_to_raster(self, input_shapefile, template_tif, output_tif, field,
  165. resolution_factor=16.0, boundary_shp=None, interpolation_method='nearest', enable_interpolation=True):
  166. """
  167. 将点矢量数据转换为栅格数据
  168. @param input_shapefile: 输入点矢量数据的Shapefile文件路径
  169. @param template_tif: 用作模板的GeoTIFF文件路径
  170. @param output_tif: 输出栅格化后的GeoTIFF文件路径
  171. @param field: 用于栅格化的属性字段名
  172. @param resolution_factor: 分辨率倍数因子
  173. @param boundary_shp: 边界Shapefile文件路径,用于创建掩膜
  174. @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic')
  175. @param enable_interpolation: 是否启用空间插值,默认True
  176. @return: 输出的GeoTIFF文件路径和统计信息
  177. """
  178. try:
  179. self.logger.info(f"开始处理: {input_shapefile}")
  180. self.logger.info(f"分辨率因子: {resolution_factor}, 插值方法: {interpolation_method}")
  181. # 读取矢量数据
  182. gdf = gpd.read_file(input_shapefile)
  183. # 读取模板栅格
  184. with rasterio.open(template_tif) as src:
  185. template_meta = src.meta.copy()
  186. # 根据分辨率因子计算新的尺寸和变换参数
  187. if resolution_factor != 1.0:
  188. width = int(src.width * resolution_factor)
  189. height = int(src.height * resolution_factor)
  190. transform = rasterio.Affine(
  191. src.transform.a / resolution_factor,
  192. src.transform.b,
  193. src.transform.c,
  194. src.transform.d,
  195. src.transform.e / resolution_factor,
  196. src.transform.f
  197. )
  198. self.logger.info(f"分辨率调整: {src.width}x{src.height} -> {width}x{height}")
  199. else:
  200. width = src.width
  201. height = src.height
  202. transform = src.transform
  203. self.logger.info(f"保持原始分辨率: {width}x{height}")
  204. crs = src.crs
  205. # 投影矢量数据
  206. if gdf.crs != crs:
  207. gdf = gdf.to_crs(crs)
  208. # 栅格化
  209. shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[field]))
  210. raster = rasterize(
  211. shapes=shapes,
  212. out_shape=(height, width),
  213. transform=transform,
  214. fill=np.nan,
  215. dtype='float32'
  216. )
  217. # 预备掩膜:优先使用行政区边界;若未提供边界,则使用点集凸包限制绘制范围
  218. boundary_mask = None
  219. if boundary_shp and os.path.exists(boundary_shp):
  220. self.logger.info(f"应用边界掩膜: {boundary_shp}")
  221. boundary_gdf = gpd.read_file(boundary_shp)
  222. if boundary_gdf.crs != crs:
  223. boundary_gdf = boundary_gdf.to_crs(crs)
  224. boundary_mask = self.create_boundary_mask(raster, transform, boundary_gdf)
  225. raster[~boundary_mask] = np.nan
  226. else:
  227. try:
  228. # 使用点集凸包作为默认掩膜,避免边界外着色
  229. hull = gdf.unary_union.convex_hull
  230. hull_gdf = gpd.GeoDataFrame(geometry=[hull], crs=crs)
  231. boundary_mask = self.create_boundary_mask(raster, transform, hull_gdf)
  232. raster[~boundary_mask] = np.nan
  233. self.logger.info("已使用点集凸包限制绘制范围")
  234. except Exception as hull_err:
  235. self.logger.warning(f"生成点集凸包掩膜失败,可能会出现边界外着色: {str(hull_err)}")
  236. # 使用插值方法填充NaN值(如果启用)
  237. if enable_interpolation and np.isnan(raster).any():
  238. self.logger.info(f"使用 {interpolation_method} 方法进行插值...")
  239. raster = self.interpolate_nan_values(raster, method=interpolation_method)
  240. # 关键修正:插值后再次应用掩膜,确保边界外不被填充
  241. if boundary_mask is not None:
  242. raster[~boundary_mask] = np.nan
  243. elif not enable_interpolation and np.isnan(raster).any():
  244. self.logger.info("插值已禁用,保留原始栅格数据(包含NaN值)")
  245. # 创建输出目录
  246. os.makedirs(os.path.dirname(output_tif), exist_ok=True)
  247. # 更新元数据
  248. template_meta.update({
  249. "count": 1,
  250. "dtype": 'float32',
  251. "nodata": np.nan,
  252. "width": width,
  253. "height": height,
  254. "transform": transform
  255. })
  256. # 保存栅格文件
  257. with rasterio.open(output_tif, 'w', **template_meta) as dst:
  258. dst.write(raster, 1)
  259. # 计算统计信息
  260. valid_data = raster[~np.isnan(raster)]
  261. stats = None
  262. if len(valid_data) > 0:
  263. stats = {
  264. 'min': float(np.min(valid_data)),
  265. 'max': float(np.max(valid_data)),
  266. 'mean': float(np.mean(valid_data)),
  267. 'std': float(np.std(valid_data)),
  268. 'valid_pixels': int(len(valid_data)),
  269. 'total_pixels': int(raster.size)
  270. }
  271. self.logger.info(f"统计信息: 有效像素 {stats['valid_pixels']}/{stats['total_pixels']}")
  272. self.logger.info(f"数值范围: {stats['min']:.4f} - {stats['max']:.4f}")
  273. else:
  274. self.logger.warning("没有有效数据")
  275. self.logger.info(f"✓ 成功保存: {output_tif}")
  276. return output_tif, stats
  277. except Exception as e:
  278. self.logger.error(f"矢量转栅格失败: {str(e)}")
  279. raise
  280. def create_raster_map(self, shp_path, tif_path, output_path,
  281. colormap='green_yellow_red_purple', title="Prediction Map",
  282. output_size=12, figsize=None, dpi=300,
  283. resolution_factor=1.0, enable_interpolation=True,
  284. interpolation_method='nearest'):
  285. """
  286. 创建栅格地图
  287. @param shp_path: 输入的矢量数据路径
  288. @param tif_path: 输入的栅格数据路径
  289. @param output_path: 输出图片路径(不包含扩展名)
  290. @param colormap: 色彩方案名称或颜色列表
  291. @param title: 图片标题
  292. @param output_size: 图片尺寸(正方形),如果指定了figsize则忽略此参数
  293. @param figsize: 图片尺寸元组 (width, height),优先级高于output_size
  294. @param dpi: 图片分辨率
  295. @param resolution_factor: 分辨率因子,>1提高分辨率,<1降低分辨率
  296. @param enable_interpolation: 是否启用空间插值,用于处理NaN值或提高分辨率,默认True
  297. @param interpolation_method: 插值方法 ('nearest', 'linear', 'cubic')
  298. @return: 输出图片文件路径
  299. """
  300. try:
  301. self.logger.info(f"开始创建栅格地图: {tif_path}")
  302. self.logger.info(f"分辨率因子: {resolution_factor}, 启用插值: {enable_interpolation}")
  303. # 读取矢量边界
  304. gdf = gpd.read_file(shp_path) if shp_path else None
  305. # 读取并裁剪栅格数据
  306. with rasterio.open(tif_path) as src:
  307. original_transform = src.transform
  308. original_crs = src.crs
  309. if gdf is not None:
  310. # 确保坐标系一致
  311. if gdf.crs != src.crs:
  312. gdf = gdf.to_crs(src.crs)
  313. # 裁剪栅格
  314. geoms = [json.loads(gdf.to_json())["features"][0]["geometry"]]
  315. out_image, out_transform = mask(src, geoms, crop=True)
  316. out_meta = src.meta.copy()
  317. else:
  318. # 如果没有边界文件,使用整个栅格
  319. out_image = src.read()
  320. out_transform = src.transform
  321. out_meta = src.meta.copy()
  322. # 提取数据并处理无效值
  323. raster = out_image[0].astype('float32')
  324. nodata = out_meta.get("nodata", None)
  325. if nodata is not None:
  326. raster[raster == nodata] = np.nan
  327. # 应用分辨率因子重采样
  328. if resolution_factor != 1.0:
  329. self.logger.info(f"应用分辨率因子重采样: {resolution_factor}")
  330. raster, out_transform = self._resample_raster(
  331. raster=raster,
  332. transform=out_transform,
  333. resolution_factor=resolution_factor,
  334. crs=original_crs,
  335. resampling='nearest'
  336. )
  337. # 应用空间插值(如果启用)
  338. if enable_interpolation and np.isnan(raster).any():
  339. self.logger.info(f"使用 {interpolation_method} 方法进行空间插值")
  340. raster = self.interpolate_nan_values(raster, method=interpolation_method)
  341. # 检查是否有有效数据
  342. if np.all(np.isnan(raster)):
  343. raise ValueError("栅格数据中没有有效值")
  344. # 根据分位数分为6个等级
  345. bounds = np.nanpercentile(raster, [0, 20, 40, 60, 80, 90, 100])
  346. norm = BoundaryNorm(bounds, ncolors=len(bounds) - 1)
  347. # 获取色彩方案
  348. if isinstance(colormap, str):
  349. if colormap in COLORMAPS:
  350. color_list = COLORMAPS[colormap]
  351. else:
  352. self.logger.warning(f"未知色彩方案: {colormap},使用默认方案")
  353. color_list = COLORMAPS['green_yellow_red_purple']
  354. else:
  355. color_list = colormap
  356. cmap = ListedColormap(color_list)
  357. # 设置图片尺寸
  358. if figsize is not None:
  359. fig_size = figsize
  360. else:
  361. fig_size = (output_size, output_size)
  362. # 绘图
  363. fig, ax = plt.subplots(figsize=fig_size)
  364. # 如果有边界文件,需要进一步mask边界外的区域;否则使用栅格有效范围
  365. if gdf is not None:
  366. try:
  367. height, width = raster.shape
  368. transform = out_transform
  369. geom_mask = geometry_mask(
  370. [json.loads(gdf.to_json())["features"][0]["geometry"]],
  371. out_shape=(height, width),
  372. transform=transform,
  373. invert=True
  374. )
  375. raster = np.where(geom_mask, raster, np.nan)
  376. except Exception as mask_err:
  377. self.logger.warning(f"边界掩膜应用失败,将继续绘制已裁剪栅格: {str(mask_err)}")
  378. # 显示栅格数据
  379. show(raster, transform=out_transform, ax=ax, cmap=cmap, norm=norm)
  380. # 添加矢量边界
  381. if gdf is not None:
  382. gdf.boundary.plot(ax=ax, color='black', linewidth=1)
  383. # 设置标题和标签
  384. ax.set_title(title, fontsize=20)
  385. ax.set_xlabel("Longitude", fontsize=18)
  386. ax.set_ylabel("Latitude", fontsize=18)
  387. ax.grid(True, linestyle='--', color='gray', alpha=0.5)
  388. ax.tick_params(axis='y', labelrotation=90)
  389. # 添加色带
  390. tick_labels = [f"{bounds[i]:.1f}" for i in range(len(bounds) - 1)]
  391. cbar = plt.colorbar(
  392. plt.cm.ScalarMappable(norm=norm, cmap=cmap),
  393. ax=ax,
  394. ticks=[(bounds[i] + bounds[i+1]) / 2 for i in range(len(bounds) - 1)],
  395. shrink=0.6,
  396. aspect=15
  397. )
  398. cbar.ax.set_yticklabels(tick_labels)
  399. cbar.set_label("Values")
  400. plt.tight_layout()
  401. # 确保输出目录存在
  402. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  403. # 保存图片
  404. output_file = f"{output_path}.jpg"
  405. plt.savefig(output_file, dpi=dpi, format='jpg', bbox_inches='tight')
  406. plt.close()
  407. self.logger.info(f"✓ 栅格地图创建成功: {output_file}")
  408. return output_file
  409. except Exception as e:
  410. self.logger.error(f"栅格地图创建失败: {str(e)}")
  411. raise
  412. def _resample_raster(self, raster, transform, resolution_factor: float, crs, resampling: str = 'nearest'):
  413. """
  414. 按分辨率因子对二维栅格进行重采样,并返回新栅格与更新后的仿射变换
  415. @param raster: 2D numpy 数组
  416. @param transform: 输入栅格的仿射变换
  417. @param resolution_factor: 分辨率因子 (>1 增加像元密度)
  418. @param crs: 坐标参考系
  419. @param resampling: 重采样方式 ('nearest' | 'bilinear' | 'cubic')
  420. @return: (resampled_raster, new_transform)
  421. """
  422. try:
  423. if resolution_factor == 1.0:
  424. return raster, transform
  425. rows, cols = raster.shape
  426. new_rows = max(1, int(rows * resolution_factor))
  427. new_cols = max(1, int(cols * resolution_factor))
  428. # 更新变换(像元变小)
  429. new_transform = rasterio.Affine(
  430. transform.a / resolution_factor,
  431. transform.b,
  432. transform.c,
  433. transform.d,
  434. transform.e / resolution_factor,
  435. transform.f
  436. )
  437. # 选择重采样算法
  438. resampling_map = {
  439. 'nearest': Resampling.nearest,
  440. 'bilinear': Resampling.bilinear,
  441. 'cubic': Resampling.cubic,
  442. }
  443. resampling_enum = resampling_map.get(resampling, Resampling.nearest)
  444. destination = np.full((new_rows, new_cols), np.nan, dtype='float32')
  445. # 使用 reproject 做重采样(坐标系不变,仅分辨率变化)
  446. reproject(
  447. source=raster,
  448. destination=destination,
  449. src_transform=transform,
  450. src_crs=crs,
  451. dst_transform=new_transform,
  452. dst_crs=crs,
  453. src_nodata=np.nan,
  454. dst_nodata=np.nan,
  455. resampling=resampling_enum
  456. )
  457. return destination, new_transform
  458. except Exception as e:
  459. self.logger.error(f"重采样失败: {str(e)}")
  460. # 失败则返回原始数据,避免中断
  461. return raster, transform
  462. def create_histogram(self, file_path, save_path=None, figsize=(10, 6),
  463. xlabel='像元值', ylabel='频率密度', title='数值分布图',
  464. bins=100, dpi=300):
  465. """
  466. 绘制GeoTIFF文件的直方图
  467. @param file_path: GeoTIFF文件路径
  468. @param save_path: 保存路径,如果为None则自动生成
  469. @param figsize: 图像尺寸
  470. @param xlabel: 横坐标标签
  471. @param ylabel: 纵坐标标签
  472. @param title: 图标题
  473. @param bins: 直方图箱数
  474. @param dpi: 图片分辨率
  475. @return: 输出图片文件路径
  476. """
  477. try:
  478. self.logger.info(f"开始创建直方图: {file_path}")
  479. # 设置seaborn样式
  480. sns.set(style='ticks')
  481. # 读取栅格数据
  482. with rasterio.open(file_path) as src:
  483. band = src.read(1)
  484. nodata = src.nodata
  485. # 处理无效值
  486. if nodata is not None:
  487. band = np.where(band == nodata, np.nan, band)
  488. # 展平数据并移除NaN值
  489. band_flat = band.flatten()
  490. band_flat = band_flat[~np.isnan(band_flat)]
  491. if len(band_flat) == 0:
  492. raise ValueError("栅格数据中没有有效值")
  493. # 创建图形
  494. plt.figure(figsize=figsize)
  495. # 绘制直方图和密度曲线
  496. sns.histplot(band_flat, bins=bins, color='steelblue', alpha=0.7,
  497. edgecolor='black', stat='density')
  498. sns.kdeplot(band_flat, color='red', linewidth=2)
  499. # 设置标签和标题
  500. plt.xlabel(xlabel, fontsize=14)
  501. plt.ylabel(ylabel, fontsize=14)
  502. plt.title(title, fontsize=16)
  503. plt.grid(True, linestyle='--', alpha=0.5)
  504. plt.tight_layout()
  505. # 保存图片
  506. if save_path is None:
  507. save_path = file_path.replace('.tif', '_histogram.jpg')
  508. # 确保输出目录存在
  509. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  510. plt.savefig(save_path, dpi=dpi, format='jpg', bbox_inches='tight')
  511. plt.close()
  512. self.logger.info(f"✓ 直方图创建成功: {save_path}")
  513. return save_path
  514. except Exception as e:
  515. self.logger.error(f"直方图创建失败: {str(e)}")
  516. raise
  517. def get_available_colormaps():
  518. """
  519. 获取可用的色彩方案列表
  520. @return: 色彩方案字典
  521. """
  522. return COLORMAPS.copy()
  523. def csv_to_raster_workflow(csv_file, template_tif, output_dir,
  524. boundary_shp=None, resolution_factor=16.0,
  525. interpolation_method='nearest', field_name='Prediction',
  526. lon_col=0, lat_col=1, value_col=2, enable_interpolation=False):
  527. """
  528. 完整的CSV到栅格转换工作流
  529. @param csv_file: CSV文件路径
  530. @param template_tif: 模板GeoTIFF文件路径
  531. @param output_dir: 输出目录
  532. @param boundary_shp: 边界Shapefile文件路径(可选)
  533. @param resolution_factor: 分辨率因子
  534. @param interpolation_method: 插值方法
  535. @param field_name: 字段名称
  536. @param lon_col: 经度列
  537. @param lat_col: 纬度列
  538. @param value_col: 数值列
  539. @param enable_interpolation: 是否启用空间插值,默认False
  540. @return: 输出文件路径字典
  541. """
  542. mapper = MappingUtils()
  543. # 确保输出目录存在
  544. os.makedirs(output_dir, exist_ok=True)
  545. # 生成文件名
  546. base_name = os.path.splitext(os.path.basename(csv_file))[0]
  547. shapefile_path = os.path.join(output_dir, f"{base_name}_points.shp")
  548. raster_path = os.path.join(output_dir, f"{base_name}_raster.tif")
  549. # 1. CSV转Shapefile
  550. mapper.csv_to_shapefile(csv_file, shapefile_path, lon_col, lat_col, value_col)
  551. # 2. Shapefile转栅格
  552. raster_path, stats = mapper.vector_to_raster(
  553. shapefile_path, template_tif, raster_path, field_name,
  554. resolution_factor, boundary_shp, interpolation_method, enable_interpolation
  555. )
  556. return {
  557. 'shapefile': shapefile_path,
  558. 'raster': raster_path,
  559. 'statistics': stats
  560. }