data_processing.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. """
  2. 数据处理模块
  3. Data Processing Module
  4. 用于整合模型预测结果与坐标数据,生成最终的分析数据
  5. """
  6. import os
  7. import sys
  8. import logging
  9. import pandas as pd
  10. import numpy as np
  11. # 添加项目根目录到路径
  12. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  13. import config
  14. class DataProcessor:
  15. """
  16. 数据处理器
  17. 负责整合预测结果与坐标数据
  18. """
  19. def __init__(self):
  20. """
  21. 初始化数据处理器
  22. """
  23. self.logger = logging.getLogger(__name__)
  24. def load_predictions(self):
  25. """
  26. 加载模型预测结果(根据WORKFLOW_CONFIG配置)
  27. @return: 包含预测结果的字典
  28. """
  29. try:
  30. predictions = {}
  31. # 动态读取当前的工作流配置(运行时可能被修改)
  32. workflow_config = self._get_current_workflow_config()
  33. self.logger.info(f"当前工作流配置: {workflow_config}")
  34. # 只加载在工作流配置中启用的模型的预测结果
  35. # 加载作物Cd预测结果
  36. if workflow_config.get("run_crop_model", False):
  37. crop_cd_path = os.path.join(
  38. config.DATA_PATHS["predictions_dir"],
  39. config.CROP_CD_MODEL["output_file"]
  40. )
  41. if os.path.exists(crop_cd_path):
  42. predictions['crop_cd'] = pd.read_csv(crop_cd_path)
  43. self.logger.info(f"作物Cd预测结果加载成功: {crop_cd_path}")
  44. else:
  45. self.logger.warning(f"作物Cd预测文件不存在: {crop_cd_path}")
  46. else:
  47. self.logger.info("跳过作物Cd预测结果加载(工作流配置中未启用)")
  48. # 加载有效态Cd预测结果
  49. if workflow_config.get("run_effective_model", False):
  50. effective_cd_path = os.path.join(
  51. config.DATA_PATHS["predictions_dir"],
  52. config.EFFECTIVE_CD_MODEL["output_file"]
  53. )
  54. if os.path.exists(effective_cd_path):
  55. predictions['effective_cd'] = pd.read_csv(effective_cd_path)
  56. self.logger.info(f"有效态Cd预测结果加载成功: {effective_cd_path}")
  57. else:
  58. self.logger.warning(f"有效态Cd预测文件不存在: {effective_cd_path}")
  59. else:
  60. self.logger.info("跳过有效态Cd预测结果加载(工作流配置中未启用)")
  61. if not predictions:
  62. self.logger.warning("没有加载到任何预测结果,请检查WORKFLOW_CONFIG配置和预测文件是否存在")
  63. else:
  64. self.logger.info(f"根据工作流配置,成功加载了 {len(predictions)} 个模型的预测结果: {list(predictions.keys())}")
  65. return predictions
  66. except Exception as e:
  67. self.logger.error(f"预测结果加载失败: {str(e)}")
  68. raise
  69. def _get_current_workflow_config(self):
  70. """
  71. 动态读取当前的工作流配置
  72. @return: 当前的工作流配置字典
  73. """
  74. try:
  75. config_file = os.path.join(config.PROJECT_ROOT, "config.py")
  76. # 读取配置文件内容
  77. with open(config_file, 'r', encoding='utf-8') as f:
  78. config_content = f.read()
  79. # 提取WORKFLOW_CONFIG
  80. import re
  81. pattern = r'WORKFLOW_CONFIG\s*=\s*(\{[^}]*\})'
  82. match = re.search(pattern, config_content)
  83. if match:
  84. # 使用eval安全地解析配置(这里是安全的,因为我们控制配置文件内容)
  85. workflow_config_str = match.group(1)
  86. workflow_config = eval(workflow_config_str)
  87. return workflow_config
  88. else:
  89. self.logger.warning("无法从配置文件中提取WORKFLOW_CONFIG,使用默认配置")
  90. return config.WORKFLOW_CONFIG
  91. except Exception as e:
  92. self.logger.error(f"读取工作流配置失败: {str(e)},使用默认配置")
  93. return config.WORKFLOW_CONFIG
  94. def load_coordinates(self):
  95. """
  96. 加载坐标数据
  97. @return: 坐标数据DataFrame
  98. """
  99. try:
  100. # 首先尝试从数据目录加载
  101. coord_path = os.path.join(
  102. config.DATA_PATHS["coordinates_dir"],
  103. config.DATA_PATHS["coordinate_file"]
  104. )
  105. if not os.path.exists(coord_path):
  106. # 如果数据目录中没有,尝试从原始模型目录复制
  107. self.logger.info("坐标文件不存在,尝试从原始模型目录复制...")
  108. self._copy_coordinates_from_models()
  109. coordinates = pd.read_csv(coord_path)
  110. self.logger.info(f"坐标数据加载成功: {coord_path}, 数据形状: {coordinates.shape}")
  111. return coordinates
  112. except Exception as e:
  113. self.logger.error(f"坐标数据加载失败: {str(e)}")
  114. raise
  115. def _copy_coordinates_from_models(self):
  116. """
  117. 从原始模型目录复制坐标文件
  118. """
  119. try:
  120. # 从作物Cd模型目录复制坐标文件
  121. source_path = os.path.join(
  122. "..", "作物Cd模型文件与数据", "作物Cd模型文件与数据",
  123. config.DATA_PATHS["coordinate_file"]
  124. )
  125. if os.path.exists(source_path):
  126. target_path = os.path.join(
  127. config.DATA_PATHS["coordinates_dir"],
  128. config.DATA_PATHS["coordinate_file"]
  129. )
  130. # 确保目标目录存在
  131. os.makedirs(os.path.dirname(target_path), exist_ok=True)
  132. # 复制文件
  133. import shutil
  134. shutil.copy2(source_path, target_path)
  135. self.logger.info(f"坐标文件复制成功: {source_path} -> {target_path}")
  136. else:
  137. raise FileNotFoundError(f"坐标文件不存在: {source_path}")
  138. except Exception as e:
  139. self.logger.error(f"坐标文件复制失败: {str(e)}")
  140. raise
  141. def _detect_coordinate_columns(self, coordinates):
  142. """
  143. 自动检测坐标数据的列名
  144. @param coordinates: 坐标数据DataFrame
  145. @return: (经度列名, 纬度列名)
  146. """
  147. columns = coordinates.columns.tolist()
  148. # 可能的经度列名
  149. lon_candidates = ['lon', 'longitude', '经度', 'x', 'lng']
  150. # 可能的纬度列名
  151. lat_candidates = ['lat', 'latitude', '纬度', 'y', 'lan']
  152. lon_col = None
  153. lat_col = None
  154. # 查找经度列
  155. for col in columns:
  156. for candidate in lon_candidates:
  157. if candidate.lower() in col.lower():
  158. lon_col = col
  159. break
  160. if lon_col:
  161. break
  162. # 查找纬度列
  163. for col in columns:
  164. for candidate in lat_candidates:
  165. if candidate.lower() in col.lower():
  166. lat_col = col
  167. break
  168. if lat_col:
  169. break
  170. if not lon_col or not lat_col:
  171. raise ValueError(f"无法识别坐标列。现有列名: {columns}")
  172. self.logger.info(f"检测到坐标列: 经度={lon_col}, 纬度={lat_col}")
  173. return lon_col, lat_col
  174. def combine_predictions_with_coordinates(self):
  175. """
  176. 将预测结果与坐标数据合并,为每个模型生成独立的最终数据文件
  177. @return: 最终数据文件路径字典
  178. """
  179. try:
  180. self.logger.info("开始整合预测结果与坐标数据...")
  181. # 加载数据
  182. predictions = self.load_predictions()
  183. coordinates = self.load_coordinates()
  184. # 自动检测坐标列名
  185. lon_col, lat_col = self._detect_coordinate_columns(coordinates)
  186. final_files = {}
  187. # 为每个模型创建独立的最终数据文件
  188. for model_name, prediction_data in predictions.items():
  189. # 创建单独的数据DataFrame
  190. final_data = coordinates[[lon_col, lat_col]].copy()
  191. final_data['Prediction'] = prediction_data.iloc[:, 0]
  192. # 重命名列以匹配期望的格式
  193. column_mapping = {
  194. lon_col: 'longitude',
  195. lat_col: 'latitude'
  196. }
  197. final_data = final_data.rename(columns=column_mapping)
  198. # 确定输出文件名
  199. if model_name == 'crop_cd':
  200. output_filename = "Final_predictions_crop_cd.csv"
  201. model_display_name = "作物Cd模型"
  202. elif model_name == 'effective_cd':
  203. output_filename = "Final_predictions_effective_cd.csv"
  204. model_display_name = "有效态Cd模型"
  205. else:
  206. output_filename = f"Final_predictions_{model_name}.csv"
  207. model_display_name = f"{model_name}模型"
  208. # 保存最终数据
  209. output_path = os.path.join(
  210. config.DATA_PATHS["final_dir"],
  211. output_filename
  212. )
  213. # 确保输出目录存在
  214. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  215. final_data.to_csv(output_path, index=False)
  216. final_files[model_name] = output_path
  217. self.logger.info(f"{model_display_name}数据整合完成,文件保存至: {output_path}")
  218. self.logger.info(f"{model_display_name}数据形状: {final_data.shape}")
  219. # 打印统计信息
  220. pred_stats = final_data['Prediction'].describe()
  221. self.logger.info(f"{model_display_name}预测值统计:\n{pred_stats}")
  222. # 如果同时有两个模型,也创建一个合并的文件供参考
  223. if len(predictions) > 1:
  224. combined_data = coordinates[[lon_col, lat_col]].copy()
  225. # 添加所有预测结果
  226. for model_name, prediction_data in predictions.items():
  227. if model_name == 'crop_cd':
  228. combined_data['Crop_Cd_Prediction'] = prediction_data.iloc[:, 0]
  229. elif model_name == 'effective_cd':
  230. combined_data['Effective_Cd_Prediction'] = prediction_data.iloc[:, 0]
  231. # 创建一个平均预测列
  232. prediction_columns = [col for col in combined_data.columns if 'Prediction' in col]
  233. if len(prediction_columns) > 0:
  234. combined_data['Average_Prediction'] = combined_data[prediction_columns].mean(axis=1)
  235. # 重命名坐标列
  236. column_mapping = {
  237. lon_col: 'longitude',
  238. lat_col: 'latitude'
  239. }
  240. combined_data = combined_data.rename(columns=column_mapping)
  241. # 保存合并文件
  242. combined_output_path = os.path.join(
  243. config.DATA_PATHS["final_dir"],
  244. "Final_predictions_combined.csv"
  245. )
  246. combined_data.to_csv(combined_output_path, index=False)
  247. final_files['combined'] = combined_output_path
  248. self.logger.info(f"合并数据文件保存至: {combined_output_path}")
  249. self.logger.info(f"合并数据形状: {combined_data.shape}")
  250. return final_files
  251. except Exception as e:
  252. self.logger.error(f"数据整合失败: {str(e)}")
  253. raise
  254. def validate_final_data(self, file_path):
  255. """
  256. 验证最终数据的格式和内容
  257. @param file_path: 最终数据文件路径
  258. @return: 验证是否通过
  259. """
  260. try:
  261. data = pd.read_csv(file_path)
  262. # 检查必要的列
  263. required_columns = ['longitude', 'latitude', 'Prediction']
  264. missing_columns = [col for col in required_columns if col not in data.columns]
  265. if missing_columns:
  266. self.logger.error(f"缺少必要的列: {missing_columns}")
  267. return False
  268. # 检查数据完整性
  269. if data.isnull().any().any():
  270. self.logger.warning("数据中存在空值")
  271. # 检查坐标范围(假设为合理的地理坐标)
  272. if not (data['longitude'].between(-180, 180).all() and
  273. data['latitude'].between(-90, 90).all()):
  274. self.logger.warning("坐标值超出合理范围")
  275. self.logger.info("数据验证通过")
  276. return True
  277. except Exception as e:
  278. self.logger.error(f"数据验证失败: {str(e)}")
  279. return False
  280. if __name__ == "__main__":
  281. # 测试代码
  282. processor = DataProcessor()
  283. final_files = processor.combine_predictions_with_coordinates()
  284. for model_name, final_file in final_files.items():
  285. processor.validate_final_data(final_file)
  286. print(f"{model_name}数据处理完成,最终文件: {final_file}")