main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """
  2. 主执行脚本 - Cd预测集成系统
  3. Main execution script for Cd Prediction Integrated System
  4. Description: 集成运行作物Cd模型、有效态Cd模型和数据分析可视化的完整流程
  5. 架构简化版本:直接调用通用绘图模块,删除中间封装层
  6. """
  7. import os
  8. import sys
  9. import logging
  10. from datetime import datetime
  11. # 添加项目根目录到Python路径
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  13. # 添加上级目录到路径以访问通用绘图模块
  14. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  15. import config
  16. from models.crop_cd_model.predict import CropCdPredictor
  17. from models.effective_cd_model.predict import EffectiveCdPredictor
  18. from analysis.data_processing import DataProcessor
  19. from utils.common import setup_logging, validate_files
  20. # 直接导入通用绘图模块
  21. from app.utils.mapping_utils import csv_to_raster_workflow, MappingUtils
  22. def main():
  23. """
  24. 主执行函数
  25. 执行完整的Cd预测分析流程
  26. """
  27. # 设置日志
  28. log_file = os.path.join(config.OUTPUT_PATHS["reports_dir"],
  29. f"execution_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
  30. setup_logging(log_file)
  31. logger = logging.getLogger(__name__)
  32. logger.info("=" * 60)
  33. logger.info("开始执行Cd预测集成系统 (架构简化版本)")
  34. logger.info("=" * 60)
  35. try:
  36. # 确保目录存在
  37. config.ensure_directories()
  38. logger.info("项目目录结构检查完成")
  39. # 步骤1: 运行作物Cd模型预测
  40. if config.WORKFLOW_CONFIG["run_crop_model"]:
  41. logger.info("步骤1: 运行作物Cd模型预测...")
  42. crop_predictor = CropCdPredictor()
  43. crop_output = crop_predictor.predict()
  44. logger.info(f"作物Cd模型预测完成,输出文件: {crop_output}")
  45. # 步骤2: 运行有效态Cd模型预测
  46. if config.WORKFLOW_CONFIG["run_effective_model"]:
  47. logger.info("步骤2: 运行有效态Cd模型预测...")
  48. effective_predictor = EffectiveCdPredictor()
  49. effective_output = effective_predictor.predict()
  50. logger.info(f"有效态Cd模型预测完成,输出文件: {effective_output}")
  51. # 步骤3: 数据整合处理
  52. if config.WORKFLOW_CONFIG["combine_predictions"]:
  53. logger.info("步骤3: 整合预测结果与坐标数据...")
  54. data_processor = DataProcessor()
  55. final_data_files = data_processor.combine_predictions_with_coordinates()
  56. logger.info(f"数据整合完成,生成文件: {list(final_data_files.keys())}")
  57. # 初始化通用绘图工具(用于可视化)
  58. mapping_utils = MappingUtils()
  59. # 步骤4: 为每个模型生成栅格文件和可视化图
  60. model_outputs = {}
  61. for model_name, final_data_file in final_data_files.items():
  62. if model_name == 'combined': # 跳过合并文件的可视化
  63. continue
  64. # 只为当前工作流配置中启用的模型生成输出文件
  65. if model_name == 'crop_cd' and not config.WORKFLOW_CONFIG.get("run_crop_model", False):
  66. logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
  67. continue
  68. if model_name == 'effective_cd' and not config.WORKFLOW_CONFIG.get("run_effective_model", False):
  69. logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
  70. continue
  71. logger.info(f"为{model_name}模型生成栅格和可视化...")
  72. # 为每个模型确定显示名称
  73. if model_name == 'crop_cd':
  74. display_name = "作物Cd"
  75. title_name = "Crop Cd Prediction"
  76. elif model_name == 'effective_cd':
  77. display_name = "有效态Cd"
  78. title_name = "Effective Cd Prediction"
  79. else:
  80. display_name = model_name
  81. title_name = f"{model_name} Prediction"
  82. model_outputs[model_name] = {}
  83. # 步骤4a: 生成栅格文件(直接调用通用绘图模块)
  84. if config.WORKFLOW_CONFIG["generate_raster"]:
  85. logger.info(f"步骤4a: 为{display_name}模型生成栅格文件...")
  86. # 使用csv_to_raster_workflow直接生成栅格
  87. try:
  88. workflow_result = csv_to_raster_workflow(
  89. csv_file=final_data_file,
  90. template_tif=config.ANALYSIS_CONFIG["template_tif"],
  91. output_dir=config.OUTPUT_PATHS["raster_dir"],
  92. boundary_shp=config.ANALYSIS_CONFIG.get("boundary_shp"),
  93. resolution_factor=1.0, # 高分辨率栅格生成
  94. interpolation_method='nearest',
  95. field_name='Prediction',
  96. lon_col=0,
  97. lat_col=1,
  98. value_col=2,
  99. enable_interpolation=False # 禁用空间插值
  100. )
  101. output_raster = workflow_result['raster']
  102. # 重命名为特定模型的文件
  103. model_raster_path = os.path.join(
  104. config.OUTPUT_PATHS["raster_dir"],
  105. f"output_{model_name}.tif"
  106. )
  107. if os.path.exists(output_raster) and output_raster != model_raster_path:
  108. import shutil
  109. shutil.move(output_raster, model_raster_path)
  110. output_raster = model_raster_path
  111. model_outputs[model_name]['raster'] = output_raster
  112. logger.info(f"{display_name}模型栅格文件生成完成: {output_raster}")
  113. # 清理中间shapefile文件
  114. try:
  115. shapefile_path = workflow_result.get('shapefile')
  116. if shapefile_path and os.path.exists(shapefile_path):
  117. # 删除shapefile及其相关文件
  118. import glob
  119. base_path = os.path.splitext(shapefile_path)[0]
  120. for ext in ['.shp', '.shx', '.dbf', '.prj', '.cpg']:
  121. file_to_delete = base_path + ext
  122. if os.path.exists(file_to_delete):
  123. os.remove(file_to_delete)
  124. logger.debug(f"已删除中间文件: {file_to_delete}")
  125. logger.info(f"已清理中间shapefile文件: {shapefile_path}")
  126. except Exception as cleanup_error:
  127. logger.warning(f"清理中间文件失败: {str(cleanup_error)}")
  128. except Exception as e:
  129. logger.error(f"栅格生成失败: {str(e)}")
  130. raise
  131. # 步骤4b: 创建可视化地图(直接调用通用绘图模块)
  132. if config.WORKFLOW_CONFIG["create_visualization"]:
  133. logger.info(f"步骤4b: 为{display_name}模型创建可视化地图...")
  134. # 为每个模型创建独立的地图输出路径
  135. map_output_path = os.path.join(
  136. config.OUTPUT_PATHS["figures_dir"],
  137. f"Prediction_results_{model_name}"
  138. )
  139. try:
  140. # 直接调用通用绘图模块的create_raster_map
  141. map_output = mapping_utils.create_raster_map(
  142. shp_path=config.ANALYSIS_CONFIG.get("boundary_shp"),
  143. tif_path=output_raster,
  144. output_path=map_output_path,
  145. title=title_name,
  146. colormap=config.VISUALIZATION_CONFIG["color_maps"][config.VISUALIZATION_CONFIG["default_colormap"]],
  147. output_size=config.VISUALIZATION_CONFIG["figure_size"],
  148. dpi=config.VISUALIZATION_CONFIG["dpi"],
  149. resolution_factor=1.0, # 保持栅格原始分辨率
  150. enable_interpolation=False, # 不在可视化阶段插值
  151. interpolation_method='nearest'
  152. )
  153. model_outputs[model_name]['map'] = map_output
  154. logger.info(f"{display_name}模型地图可视化完成: {map_output}")
  155. except Exception as e:
  156. logger.error(f"地图可视化失败: {str(e)}")
  157. raise
  158. # 步骤4c: 创建直方图(直接调用通用绘图模块)
  159. if config.WORKFLOW_CONFIG["create_histogram"]:
  160. logger.info(f"步骤4c: 为{display_name}模型创建预测值分布直方图...")
  161. # 为每个模型创建独立的直方图输出路径
  162. histogram_output_path = os.path.join(
  163. config.OUTPUT_PATHS["figures_dir"],
  164. f"Prediction_frequency_{model_name}.jpg"
  165. )
  166. # 为了避免中文字体问题,使用英文标题
  167. if model_name == 'crop_cd':
  168. hist_title = 'Crop Cd Prediction Frequency'
  169. hist_xlabel = 'Crop Cd Content (mg/kg)'
  170. elif model_name == 'effective_cd':
  171. hist_title = 'Effective Cd Prediction Frequency'
  172. hist_xlabel = 'Effective Cd Content (mg/kg)'
  173. else:
  174. hist_title = f'{model_name} Prediction Frequency'
  175. hist_xlabel = f'{model_name} Content'
  176. try:
  177. # 直接调用通用绘图模块的create_histogram
  178. histogram_output = mapping_utils.create_histogram(
  179. file_path=output_raster,
  180. save_path=histogram_output_path,
  181. figsize=(6, 6),
  182. xlabel=hist_xlabel,
  183. ylabel='Frequency',
  184. title=hist_title,
  185. dpi=config.VISUALIZATION_CONFIG["dpi"]
  186. )
  187. model_outputs[model_name]['histogram'] = histogram_output
  188. logger.info(f"{display_name}模型直方图创建完成: {histogram_output}")
  189. except Exception as e:
  190. logger.error(f"直方图生成失败: {str(e)}")
  191. raise
  192. logger.info("=" * 60)
  193. logger.info("所有流程执行完成!")
  194. logger.info("=" * 60)
  195. # 打印结果摘要
  196. print_summary(model_outputs)
  197. except Exception as e:
  198. logger.error(f"执行过程中发生错误: {str(e)}")
  199. logger.exception("详细错误信息:")
  200. raise
  201. def print_summary(model_outputs):
  202. """打印执行结果摘要"""
  203. print("\n" + "=" * 60)
  204. print("[SUCCESS] Cd预测集成系统执行完成! (架构简化版本)")
  205. print("=" * 60)
  206. print(f"[INFO] 总输出目录: {config.OUTPUT_PATHS}")
  207. print()
  208. # 显示每个模型的输出文件
  209. for model_name, outputs in model_outputs.items():
  210. if model_name == 'crop_cd':
  211. display_name = "作物Cd模型"
  212. elif model_name == 'effective_cd':
  213. display_name = "有效态Cd模型"
  214. else:
  215. display_name = f"{model_name}模型"
  216. print(f"[MODEL] {display_name}:")
  217. for output_type, file_path in outputs.items():
  218. if output_type == 'raster':
  219. print(f" [RASTER] 栅格文件: {file_path}")
  220. elif output_type == 'map':
  221. print(f" [MAP] 地图文件: {file_path}")
  222. elif output_type == 'histogram':
  223. print(f" [CHART] 直方图文件: {file_path}")
  224. print()
  225. print(f"[LOG] 日志文件目录: {config.OUTPUT_PATHS['reports_dir']}")
  226. print("=" * 60)
  227. if __name__ == "__main__":
  228. main()