mapping.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. """
  2. 栅格映射模块
  3. Raster Mapping Module
  4. 基于原始02_Transfer_csv_to_geotif.py改进,用于将CSV数据转换为GeoTIFF格式
  5. """
  6. import os
  7. import sys
  8. import logging
  9. import pandas as pd
  10. import geopandas as gpd
  11. from shapely.geometry import Point
  12. import rasterio
  13. from rasterio.features import rasterize
  14. from rasterio.transform import from_origin
  15. import numpy as np
  16. # 添加项目根目录到路径
  17. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  18. import config
  19. class RasterMapper:
  20. """
  21. 栅格映射器
  22. 负责将CSV数据转换为GeoTIFF栅格数据
  23. """
  24. def __init__(self):
  25. """
  26. 初始化栅格映射器
  27. """
  28. self.logger = logging.getLogger(__name__)
  29. def csv_to_shapefile(self, csv_file, shapefile_output):
  30. """
  31. 将CSV文件转换为Shapefile文件
  32. @param csv_file: CSV文件路径
  33. @param shapefile_output: 输出Shapefile文件路径
  34. """
  35. try:
  36. # 读取CSV数据
  37. df = pd.read_csv(csv_file)
  38. # 确保列名正确
  39. if 'longitude' not in df.columns or 'latitude' not in df.columns:
  40. # 尝试自动识别经纬度列
  41. lon_col = None
  42. lat_col = None
  43. for col in df.columns:
  44. col_lower = col.lower()
  45. if any(keyword in col_lower for keyword in ['lon', '经度', 'x']):
  46. lon_col = col
  47. elif any(keyword in col_lower for keyword in ['lat', '纬度', 'y']):
  48. lat_col = col
  49. if lon_col and lat_col:
  50. df = df.rename(columns={lon_col: 'longitude', lat_col: 'latitude'})
  51. self.logger.info(f"自动识别坐标列: {lon_col} -> longitude, {lat_col} -> latitude")
  52. else:
  53. raise ValueError("无法识别经纬度列")
  54. # 创建几何对象
  55. lon = df['longitude']
  56. lat = df['latitude']
  57. # 获取预测值列
  58. value_col = 'Prediction'
  59. if value_col not in df.columns:
  60. # 寻找可能的预测值列
  61. for col in df.columns:
  62. if col not in ['longitude', 'latitude'] and pd.api.types.is_numeric_dtype(df[col]):
  63. value_col = col
  64. break
  65. else:
  66. raise ValueError("无法找到预测值列")
  67. val = df[value_col]
  68. # 创建Point几何对象
  69. geometry = [Point(xy) for xy in zip(lon, lat)]
  70. # 创建GeoDataFrame
  71. gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326")
  72. # 确保输出目录存在
  73. os.makedirs(os.path.dirname(shapefile_output), exist_ok=True)
  74. # 保存为Shapefile
  75. gdf.to_file(shapefile_output, driver="ESRI Shapefile")
  76. self.logger.info(f"Shapefile创建成功: {shapefile_output}")
  77. return shapefile_output
  78. except Exception as e:
  79. self.logger.error(f"CSV转Shapefile失败: {str(e)}")
  80. raise
  81. def vector_to_raster(self, input_shapefile, template_tif, output_tif, field='Prediction'):
  82. """
  83. 将点矢量数据转换为栅格数据
  84. @param input_shapefile: 输入点矢量数据的Shapefile文件路径
  85. @param template_tif: 用作模板的GeoTIFF文件路径
  86. @param output_tif: 输出栅格化后的GeoTIFF文件路径
  87. @param field: 用于栅格化的属性字段名
  88. """
  89. try:
  90. # 读取矢量数据
  91. gdf = gpd.read_file(input_shapefile)
  92. self.logger.info(f"矢量数据加载成功: {input_shapefile}")
  93. # 检查模板文件是否存在
  94. if not os.path.exists(template_tif):
  95. self.logger.warning(f"模板文件不存在: {template_tif},将创建默认栅格")
  96. self._create_default_template(gdf, template_tif)
  97. # 打开模板栅格,获取元信息
  98. with rasterio.open(template_tif) as src:
  99. template_meta = src.meta.copy()
  100. transform = src.transform
  101. width = src.width
  102. height = src.height
  103. crs = src.crs
  104. self.logger.info(f"模板栅格信息 - 宽度: {width}, 高度: {height}, CRS: {crs}")
  105. # 将矢量数据投影到模板的坐标系
  106. if gdf.crs != crs:
  107. gdf = gdf.to_crs(crs)
  108. self.logger.info(f"矢量数据投影到: {crs}")
  109. # 检查字段是否存在
  110. if field not in gdf.columns:
  111. # 尝试寻找替代字段
  112. numeric_cols = [col for col in gdf.columns if pd.api.types.is_numeric_dtype(gdf[col])]
  113. if numeric_cols:
  114. field = numeric_cols[0]
  115. self.logger.warning(f"字段'{field}'不存在,使用'{field}'替代")
  116. else:
  117. raise ValueError(f"无法找到数值字段进行栅格化")
  118. # 栅格化
  119. shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[field]))
  120. raster = rasterize(
  121. shapes=shapes,
  122. out_shape=(height, width),
  123. transform=transform,
  124. fill=np.nan,
  125. dtype='float32'
  126. )
  127. # 确保输出目录存在
  128. os.makedirs(os.path.dirname(output_tif), exist_ok=True)
  129. # 写入GeoTIFF
  130. template_meta.update({
  131. "count": 1,
  132. "dtype": 'float32',
  133. "nodata": np.nan
  134. })
  135. with rasterio.open(output_tif, 'w', **template_meta) as dst:
  136. dst.write(raster, 1)
  137. self.logger.info(f"栅格文件创建成功: {output_tif}")
  138. return output_tif
  139. except Exception as e:
  140. self.logger.error(f"矢量转栅格失败: {str(e)}")
  141. raise
  142. def _create_default_template(self, gdf, template_tif):
  143. """
  144. 创建默认的模板栅格文件
  145. @param gdf: GeoDataFrame
  146. @param template_tif: 模板文件路径
  147. """
  148. try:
  149. # 获取边界
  150. bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
  151. # 设置分辨率(约1公里)
  152. resolution = 0.01 # 度
  153. # 计算栅格尺寸
  154. width = int((bounds[2] - bounds[0]) / resolution)
  155. height = int((bounds[3] - bounds[1]) / resolution)
  156. # 创建变换矩阵
  157. transform = from_origin(bounds[0], bounds[3], resolution, resolution)
  158. # 创建空的栅格数据
  159. data = np.ones((height, width), dtype='float32') * np.nan
  160. # 元数据
  161. meta = {
  162. 'driver': 'GTiff',
  163. 'dtype': 'float32',
  164. 'nodata': np.nan,
  165. 'width': width,
  166. 'height': height,
  167. 'count': 1,
  168. 'crs': gdf.crs,
  169. 'transform': transform
  170. }
  171. # 确保目录存在
  172. os.makedirs(os.path.dirname(template_tif), exist_ok=True)
  173. # 写入模板文件
  174. with rasterio.open(template_tif, 'w', **meta) as dst:
  175. dst.write(data, 1)
  176. self.logger.info(f"默认模板创建成功: {template_tif}")
  177. except Exception as e:
  178. self.logger.error(f"默认模板创建失败: {str(e)}")
  179. raise
  180. def csv_to_raster(self, csv_file, output_raster=None, output_shp=None):
  181. """
  182. 完整的CSV到栅格转换流程
  183. @param csv_file: 输入CSV文件路径
  184. @param output_raster: 输出栅格文件路径,如果为None则使用默认路径
  185. @param output_shp: 输出shapefile路径,如果为None则使用默认路径
  186. @return: 输出栅格文件路径
  187. """
  188. try:
  189. self.logger.info(f"开始CSV到栅格转换: {csv_file}")
  190. # 使用默认路径或提供的自定义路径
  191. if output_raster is None:
  192. output_raster = config.ANALYSIS_CONFIG["output_raster"]
  193. if output_shp is None:
  194. shapefile_path = config.ANALYSIS_CONFIG["temp_shapefile"]
  195. else:
  196. shapefile_path = output_shp
  197. template_tif = config.ANALYSIS_CONFIG["template_tif"]
  198. # 步骤1: CSV转Shapefile
  199. self.csv_to_shapefile(csv_file, shapefile_path)
  200. # 步骤2: Shapefile转栅格
  201. self.vector_to_raster(shapefile_path, template_tif, output_raster)
  202. self.logger.info(f"CSV到栅格转换完成: {output_raster}")
  203. return output_raster
  204. except Exception as e:
  205. self.logger.error(f"CSV到栅格转换失败: {str(e)}")
  206. raise
  207. if __name__ == "__main__":
  208. # 测试代码
  209. mapper = RasterMapper()
  210. # 假设有一个测试CSV文件
  211. test_csv = os.path.join(config.DATA_PATHS["final_dir"], "Final_predictions.csv")
  212. if os.path.exists(test_csv):
  213. output_raster = mapper.csv_to_raster(test_csv)
  214. print(f"栅格转换完成,输出文件: {output_raster}")
  215. else:
  216. print(f"测试文件不存在: {test_csv}")