mapping.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. """
  2. 栅格映射模块
  3. Raster Mapping Module
  4. 基于通用绘图模块 app.utils.mapping_utils 的封装
  5. 提供与原有接口兼容的栅格转换功能,包含空间插值处理
  6. """
  7. import os
  8. import sys
  9. import logging
  10. import pandas as pd
  11. import geopandas as gpd
  12. from shapely.geometry import Point
  13. import rasterio
  14. from rasterio.features import rasterize
  15. from rasterio.transform import from_origin
  16. import numpy as np
  17. # 添加项目根目录到路径
  18. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  19. # 导入通用绘图模块
  20. from app.utils.mapping_utils import MappingUtils
  21. import config
  22. class RasterMapper:
  23. """
  24. 栅格映射器
  25. 负责将CSV数据转换为GeoTIFF栅格数据
  26. """
  27. def __init__(self):
  28. """
  29. 初始化栅格映射器
  30. """
  31. self.logger = logging.getLogger(__name__)
  32. # 初始化通用绘图模块
  33. self.mapping_utils = MappingUtils()
  34. def csv_to_shapefile(self, csv_file, shapefile_output):
  35. """
  36. 将CSV文件转换为Shapefile文件
  37. @param csv_file: CSV文件路径
  38. @param shapefile_output: 输出Shapefile文件路径
  39. """
  40. try:
  41. # 读取CSV数据
  42. df = pd.read_csv(csv_file)
  43. # 确保列名正确
  44. if 'longitude' not in df.columns or 'latitude' not in df.columns:
  45. # 尝试自动识别经纬度列
  46. lon_col = None
  47. lat_col = None
  48. for col in df.columns:
  49. col_lower = col.lower()
  50. if any(keyword in col_lower for keyword in ['lon', '经度', 'x']):
  51. lon_col = col
  52. elif any(keyword in col_lower for keyword in ['lat', '纬度', 'y']):
  53. lat_col = col
  54. if lon_col and lat_col:
  55. df = df.rename(columns={lon_col: 'longitude', lat_col: 'latitude'})
  56. self.logger.info(f"自动识别坐标列: {lon_col} -> longitude, {lat_col} -> latitude")
  57. else:
  58. raise ValueError("无法识别经纬度列")
  59. # 创建几何对象
  60. lon = df['longitude']
  61. lat = df['latitude']
  62. # 获取预测值列
  63. value_col = 'Prediction'
  64. if value_col not in df.columns:
  65. # 寻找可能的预测值列
  66. for col in df.columns:
  67. if col not in ['longitude', 'latitude'] and pd.api.types.is_numeric_dtype(df[col]):
  68. value_col = col
  69. break
  70. else:
  71. raise ValueError("无法找到预测值列")
  72. val = df[value_col]
  73. # 创建Point几何对象
  74. geometry = [Point(xy) for xy in zip(lon, lat)]
  75. # 创建GeoDataFrame
  76. gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326")
  77. # 确保输出目录存在
  78. os.makedirs(os.path.dirname(shapefile_output), exist_ok=True)
  79. # 保存为Shapefile
  80. gdf.to_file(shapefile_output, driver="ESRI Shapefile")
  81. self.logger.info(f"Shapefile创建成功: {shapefile_output}")
  82. return shapefile_output
  83. except Exception as e:
  84. self.logger.error(f"CSV转Shapefile失败: {str(e)}")
  85. raise
  86. def vector_to_raster(self, input_shapefile, template_tif, output_tif, field='Prediction', resolution_factor=16.0):
  87. """
  88. 将点矢量数据转换为栅格数据(使用通用绘图模块)
  89. @param input_shapefile: 输入点矢量数据的Shapefile文件路径
  90. @param template_tif: 用作模板的GeoTIFF文件路径
  91. @param output_tif: 输出栅格化后的GeoTIFF文件路径
  92. @param field: 用于栅格化的属性字段名
  93. @param resolution_factor: 分辨率倍数因子,16.0表示分辨率提高16倍(像素单元变为1/16)
  94. """
  95. try:
  96. self.logger.info(f"开始矢量转栅格: {input_shapefile}")
  97. self.logger.info(f"分辨率因子: {resolution_factor}")
  98. # 获取边界文件
  99. boundary_shp = config.ANALYSIS_CONFIG.get("boundary_shp")
  100. # 调用通用绘图模块的矢量转栅格方法(包含空间插值)
  101. output_path, stats = self.mapping_utils.vector_to_raster(
  102. input_shapefile=input_shapefile,
  103. template_tif=template_tif,
  104. output_tif=output_tif,
  105. field=field,
  106. resolution_factor=resolution_factor,
  107. boundary_shp=boundary_shp,
  108. interpolation_method='nearest'
  109. )
  110. self.logger.info(f"栅格文件创建成功: {output_path}")
  111. if stats:
  112. self.logger.info(f"统计信息: 有效像素 {stats.get('valid_pixels', 0)}/{stats.get('total_pixels', 0)}")
  113. return output_path
  114. except Exception as e:
  115. self.logger.error(f"矢量转栅格失败: {str(e)}")
  116. raise
  117. def _create_default_template(self, gdf, template_tif):
  118. """
  119. 创建默认的模板栅格文件
  120. @param gdf: GeoDataFrame
  121. @param template_tif: 模板文件路径
  122. """
  123. try:
  124. # 获取边界
  125. bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
  126. # 设置分辨率(约1公里)
  127. resolution = 0.01 # 度
  128. # 计算栅格尺寸
  129. width = int((bounds[2] - bounds[0]) / resolution)
  130. height = int((bounds[3] - bounds[1]) / resolution)
  131. # 创建变换矩阵
  132. transform = from_origin(bounds[0], bounds[3], resolution, resolution)
  133. # 创建空的栅格数据
  134. data = np.ones((height, width), dtype='float32') * np.nan
  135. # 元数据
  136. meta = {
  137. 'driver': 'GTiff',
  138. 'dtype': 'float32',
  139. 'nodata': np.nan,
  140. 'width': width,
  141. 'height': height,
  142. 'count': 1,
  143. 'crs': gdf.crs,
  144. 'transform': transform
  145. }
  146. # 确保目录存在
  147. os.makedirs(os.path.dirname(template_tif), exist_ok=True)
  148. # 写入模板文件
  149. with rasterio.open(template_tif, 'w', **meta) as dst:
  150. dst.write(data, 1)
  151. self.logger.info(f"默认模板创建成功: {template_tif}")
  152. except Exception as e:
  153. self.logger.error(f"默认模板创建失败: {str(e)}")
  154. raise
  155. def csv_to_raster(self, csv_file, output_raster=None, output_shp=None):
  156. """
  157. 完整的CSV到栅格转换流程
  158. @param csv_file: 输入CSV文件路径
  159. @param output_raster: 输出栅格文件路径,如果为None则使用默认路径
  160. @param output_shp: 输出shapefile路径,如果为None则使用默认路径
  161. @return: 输出栅格文件路径
  162. """
  163. try:
  164. self.logger.info(f"开始CSV到栅格转换: {csv_file}")
  165. # 使用默认路径或提供的自定义路径
  166. if output_raster is None:
  167. output_raster = config.ANALYSIS_CONFIG["output_raster"]
  168. if output_shp is None:
  169. shapefile_path = config.ANALYSIS_CONFIG["temp_shapefile"]
  170. else:
  171. shapefile_path = output_shp
  172. template_tif = config.ANALYSIS_CONFIG["template_tif"]
  173. # 步骤1: CSV转Shapefile
  174. self.csv_to_shapefile(csv_file, shapefile_path)
  175. # 步骤2: Shapefile转栅格
  176. self.vector_to_raster(shapefile_path, template_tif, output_raster)
  177. self.logger.info(f"CSV到栅格转换完成: {output_raster}")
  178. return output_raster
  179. except Exception as e:
  180. self.logger.error(f"CSV到栅格转换失败: {str(e)}")
  181. raise
  182. if __name__ == "__main__":
  183. # 测试代码
  184. mapper = RasterMapper()
  185. # 假设有一个测试CSV文件
  186. test_csv = os.path.join(config.DATA_PATHS["final_dir"], "Final_predictions.csv")
  187. if os.path.exists(test_csv):
  188. output_raster = mapper.csv_to_raster(test_csv)
  189. print(f"栅格转换完成,输出文件: {output_raster}")
  190. else:
  191. print(f"测试文件不存在: {test_csv}")