data_processing.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. """
  2. 数据处理模块
  3. Data Processing Module
  4. 用于整合模型预测结果与坐标数据,生成最终的分析数据
  5. """
  6. import os
  7. import sys
  8. import logging
  9. import pandas as pd
  10. import numpy as np
  11. # 添加项目根目录到路径
  12. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  13. import config
  14. class DataProcessor:
  15. """
  16. 数据处理器
  17. 负责整合预测结果与坐标数据
  18. """
  19. def __init__(self):
  20. """
  21. 初始化数据处理器
  22. """
  23. self.logger = logging.getLogger(__name__)
  24. def load_predictions(self):
  25. """
  26. 加载模型预测结果
  27. @return: 包含预测结果的字典
  28. """
  29. try:
  30. predictions = {}
  31. # 加载作物Cd预测结果
  32. crop_cd_path = os.path.join(
  33. config.DATA_PATHS["predictions_dir"],
  34. config.CROP_CD_MODEL["output_file"]
  35. )
  36. if os.path.exists(crop_cd_path):
  37. predictions['crop_cd'] = pd.read_csv(crop_cd_path)
  38. self.logger.info(f"作物Cd预测结果加载成功: {crop_cd_path}")
  39. # 加载有效态Cd预测结果
  40. effective_cd_path = os.path.join(
  41. config.DATA_PATHS["predictions_dir"],
  42. config.EFFECTIVE_CD_MODEL["output_file"]
  43. )
  44. if os.path.exists(effective_cd_path):
  45. predictions['effective_cd'] = pd.read_csv(effective_cd_path)
  46. self.logger.info(f"有效态Cd预测结果加载成功: {effective_cd_path}")
  47. return predictions
  48. except Exception as e:
  49. self.logger.error(f"预测结果加载失败: {str(e)}")
  50. raise
  51. def load_coordinates(self):
  52. """
  53. 加载坐标数据
  54. @return: 坐标数据DataFrame
  55. """
  56. try:
  57. # 首先尝试从数据目录加载
  58. coord_path = os.path.join(
  59. config.DATA_PATHS["coordinates_dir"],
  60. config.DATA_PATHS["coordinate_file"]
  61. )
  62. if not os.path.exists(coord_path):
  63. # 如果数据目录中没有,尝试从原始模型目录复制
  64. self.logger.info("坐标文件不存在,尝试从原始模型目录复制...")
  65. self._copy_coordinates_from_models()
  66. coordinates = pd.read_csv(coord_path)
  67. self.logger.info(f"坐标数据加载成功: {coord_path}, 数据形状: {coordinates.shape}")
  68. return coordinates
  69. except Exception as e:
  70. self.logger.error(f"坐标数据加载失败: {str(e)}")
  71. raise
  72. def _copy_coordinates_from_models(self):
  73. """
  74. 从原始模型目录复制坐标文件
  75. """
  76. try:
  77. # 从作物Cd模型目录复制坐标文件
  78. source_path = os.path.join(
  79. "..", "作物Cd模型文件与数据", "作物Cd模型文件与数据",
  80. config.DATA_PATHS["coordinate_file"]
  81. )
  82. if os.path.exists(source_path):
  83. target_path = os.path.join(
  84. config.DATA_PATHS["coordinates_dir"],
  85. config.DATA_PATHS["coordinate_file"]
  86. )
  87. # 确保目标目录存在
  88. os.makedirs(os.path.dirname(target_path), exist_ok=True)
  89. # 复制文件
  90. import shutil
  91. shutil.copy2(source_path, target_path)
  92. self.logger.info(f"坐标文件复制成功: {source_path} -> {target_path}")
  93. else:
  94. raise FileNotFoundError(f"坐标文件不存在: {source_path}")
  95. except Exception as e:
  96. self.logger.error(f"坐标文件复制失败: {str(e)}")
  97. raise
  98. def _detect_coordinate_columns(self, coordinates):
  99. """
  100. 自动检测坐标数据的列名
  101. @param coordinates: 坐标数据DataFrame
  102. @return: (经度列名, 纬度列名)
  103. """
  104. columns = coordinates.columns.tolist()
  105. # 可能的经度列名
  106. lon_candidates = ['lon', 'longitude', '经度', 'x', 'lng']
  107. # 可能的纬度列名
  108. lat_candidates = ['lat', 'latitude', '纬度', 'y', 'lan']
  109. lon_col = None
  110. lat_col = None
  111. # 查找经度列
  112. for col in columns:
  113. for candidate in lon_candidates:
  114. if candidate.lower() in col.lower():
  115. lon_col = col
  116. break
  117. if lon_col:
  118. break
  119. # 查找纬度列
  120. for col in columns:
  121. for candidate in lat_candidates:
  122. if candidate.lower() in col.lower():
  123. lat_col = col
  124. break
  125. if lat_col:
  126. break
  127. if not lon_col or not lat_col:
  128. raise ValueError(f"无法识别坐标列。现有列名: {columns}")
  129. self.logger.info(f"检测到坐标列: 经度={lon_col}, 纬度={lat_col}")
  130. return lon_col, lat_col
  131. def combine_predictions_with_coordinates(self):
  132. """
  133. 将预测结果与坐标数据合并,为每个模型生成独立的最终数据文件
  134. @return: 最终数据文件路径字典
  135. """
  136. try:
  137. self.logger.info("开始整合预测结果与坐标数据...")
  138. # 加载数据
  139. predictions = self.load_predictions()
  140. coordinates = self.load_coordinates()
  141. # 自动检测坐标列名
  142. lon_col, lat_col = self._detect_coordinate_columns(coordinates)
  143. final_files = {}
  144. # 为每个模型创建独立的最终数据文件
  145. for model_name, prediction_data in predictions.items():
  146. # 创建单独的数据DataFrame
  147. final_data = coordinates[[lon_col, lat_col]].copy()
  148. final_data['Prediction'] = prediction_data.iloc[:, 0]
  149. # 重命名列以匹配期望的格式
  150. column_mapping = {
  151. lon_col: 'longitude',
  152. lat_col: 'latitude'
  153. }
  154. final_data = final_data.rename(columns=column_mapping)
  155. # 确定输出文件名
  156. if model_name == 'crop_cd':
  157. output_filename = "Final_predictions_crop_cd.csv"
  158. model_display_name = "作物Cd模型"
  159. elif model_name == 'effective_cd':
  160. output_filename = "Final_predictions_effective_cd.csv"
  161. model_display_name = "有效态Cd模型"
  162. else:
  163. output_filename = f"Final_predictions_{model_name}.csv"
  164. model_display_name = f"{model_name}模型"
  165. # 保存最终数据
  166. output_path = os.path.join(
  167. config.DATA_PATHS["final_dir"],
  168. output_filename
  169. )
  170. # 确保输出目录存在
  171. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  172. final_data.to_csv(output_path, index=False)
  173. final_files[model_name] = output_path
  174. self.logger.info(f"{model_display_name}数据整合完成,文件保存至: {output_path}")
  175. self.logger.info(f"{model_display_name}数据形状: {final_data.shape}")
  176. # 打印统计信息
  177. pred_stats = final_data['Prediction'].describe()
  178. self.logger.info(f"{model_display_name}预测值统计:\n{pred_stats}")
  179. # 如果同时有两个模型,也创建一个合并的文件供参考
  180. if len(predictions) > 1:
  181. combined_data = coordinates[[lon_col, lat_col]].copy()
  182. # 添加所有预测结果
  183. for model_name, prediction_data in predictions.items():
  184. if model_name == 'crop_cd':
  185. combined_data['Crop_Cd_Prediction'] = prediction_data.iloc[:, 0]
  186. elif model_name == 'effective_cd':
  187. combined_data['Effective_Cd_Prediction'] = prediction_data.iloc[:, 0]
  188. # 创建一个平均预测列
  189. prediction_columns = [col for col in combined_data.columns if 'Prediction' in col]
  190. if len(prediction_columns) > 0:
  191. combined_data['Average_Prediction'] = combined_data[prediction_columns].mean(axis=1)
  192. # 重命名坐标列
  193. column_mapping = {
  194. lon_col: 'longitude',
  195. lat_col: 'latitude'
  196. }
  197. combined_data = combined_data.rename(columns=column_mapping)
  198. # 保存合并文件
  199. combined_output_path = os.path.join(
  200. config.DATA_PATHS["final_dir"],
  201. "Final_predictions_combined.csv"
  202. )
  203. combined_data.to_csv(combined_output_path, index=False)
  204. final_files['combined'] = combined_output_path
  205. self.logger.info(f"合并数据文件保存至: {combined_output_path}")
  206. self.logger.info(f"合并数据形状: {combined_data.shape}")
  207. return final_files
  208. except Exception as e:
  209. self.logger.error(f"数据整合失败: {str(e)}")
  210. raise
  211. def validate_final_data(self, file_path):
  212. """
  213. 验证最终数据的格式和内容
  214. @param file_path: 最终数据文件路径
  215. @return: 验证是否通过
  216. """
  217. try:
  218. data = pd.read_csv(file_path)
  219. # 检查必要的列
  220. required_columns = ['longitude', 'latitude', 'Prediction']
  221. missing_columns = [col for col in required_columns if col not in data.columns]
  222. if missing_columns:
  223. self.logger.error(f"缺少必要的列: {missing_columns}")
  224. return False
  225. # 检查数据完整性
  226. if data.isnull().any().any():
  227. self.logger.warning("数据中存在空值")
  228. # 检查坐标范围(假设为合理的地理坐标)
  229. if not (data['longitude'].between(-180, 180).all() and
  230. data['latitude'].between(-90, 90).all()):
  231. self.logger.warning("坐标值超出合理范围")
  232. self.logger.info("数据验证通过")
  233. return True
  234. except Exception as e:
  235. self.logger.error(f"数据验证失败: {str(e)}")
  236. return False
  237. if __name__ == "__main__":
  238. # 测试代码
  239. processor = DataProcessor()
  240. final_files = processor.combine_predictions_with_coordinates()
  241. for model_name, final_file in final_files.items():
  242. processor.validate_final_data(final_file)
  243. print(f"{model_name}数据处理完成,最终文件: {final_file}")