main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. """
  2. 主执行脚本 - Cd预测集成系统
  3. Main execution script for Cd Prediction Integrated System
  4. Description: 集成运行作物Cd模型、有效态Cd模型和数据分析可视化的完整流程
  5. 架构简化版本:直接调用通用绘图模块,删除中间封装层
  6. """
  7. import os
  8. import sys
  9. import logging
  10. from datetime import datetime
  11. # 添加项目根目录到Python路径
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  13. # 添加上级目录到路径以访问通用绘图模块
  14. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  15. import config
  16. from models.crop_cd_model.predict import CropCdPredictor
  17. from models.effective_cd_model.predict import EffectiveCdPredictor
  18. from analysis.data_processing import DataProcessor
  19. from utils.common import setup_logging, validate_files
  20. # 直接导入通用绘图模块
  21. from app.utils.mapping_utils import csv_to_raster_workflow, MappingUtils
  22. def main():
  23. """
  24. 主执行函数
  25. 执行完整的Cd预测分析流程
  26. """
  27. # 设置日志
  28. log_file = os.path.join(config.OUTPUT_PATHS["reports_dir"],
  29. f"execution_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
  30. setup_logging(log_file)
  31. logger = logging.getLogger(__name__)
  32. logger.info("=" * 60)
  33. logger.info("开始执行Cd预测集成系统 (架构简化版本)")
  34. logger.info("=" * 60)
  35. try:
  36. # 确保目录存在
  37. config.ensure_directories()
  38. logger.info("项目目录结构检查完成")
  39. # 获取当前工作流配置(支持运行时动态配置)
  40. workflow_config = config.get_workflow_config()
  41. logger.info(f"当前工作流配置: {workflow_config}")
  42. # 步骤1: 运行作物Cd模型预测
  43. if workflow_config["run_crop_model"]:
  44. logger.info("步骤1: 运行作物Cd模型预测...")
  45. crop_predictor = CropCdPredictor()
  46. crop_output = crop_predictor.predict()
  47. logger.info(f"作物Cd模型预测完成,输出文件: {crop_output}")
  48. # 步骤2: 运行有效态Cd模型预测
  49. if workflow_config["run_effective_model"]:
  50. logger.info("步骤2: 运行有效态Cd模型预测...")
  51. effective_predictor = EffectiveCdPredictor()
  52. effective_output = effective_predictor.predict()
  53. logger.info(f"有效态Cd模型预测完成,输出文件: {effective_output}")
  54. # 步骤3: 数据整合处理
  55. if workflow_config["combine_predictions"]:
  56. logger.info("步骤3: 整合预测结果与坐标数据...")
  57. data_processor = DataProcessor()
  58. final_data_files = data_processor.combine_predictions_with_coordinates()
  59. logger.info(f"数据整合完成,生成文件: {list(final_data_files.keys())}")
  60. # 初始化通用绘图工具(用于可视化)
  61. mapping_utils = MappingUtils()
  62. # 步骤4: 为每个模型生成栅格文件和可视化图
  63. model_outputs = {}
  64. for model_name, final_data_file in final_data_files.items():
  65. if model_name == 'combined': # 跳过合并文件的可视化
  66. continue
  67. # 只为当前工作流配置中启用的模型生成输出文件
  68. if model_name == 'crop_cd' and not workflow_config.get("run_crop_model", False):
  69. logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
  70. continue
  71. if model_name == 'effective_cd' and not workflow_config.get("run_effective_model", False):
  72. logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
  73. continue
  74. logger.info(f"为{model_name}模型生成栅格和可视化...")
  75. # 为每个模型确定显示名称
  76. if model_name == 'crop_cd':
  77. display_name = "作物Cd"
  78. title_name = "Crop Cd Prediction"
  79. elif model_name == 'effective_cd':
  80. display_name = "有效态Cd"
  81. title_name = "Effective Cd Prediction"
  82. else:
  83. display_name = model_name
  84. title_name = f"{model_name} Prediction"
  85. model_outputs[model_name] = {}
  86. # 步骤4a: 生成栅格文件(直接调用通用绘图模块)
  87. if workflow_config["generate_raster"]:
  88. logger.info(f"步骤4a: 为{display_name}模型生成栅格文件...")
  89. # 使用csv_to_raster_workflow直接生成栅格
  90. try:
  91. workflow_result = csv_to_raster_workflow(
  92. csv_file=final_data_file,
  93. template_tif=config.ANALYSIS_CONFIG["template_tif"],
  94. output_dir=config.OUTPUT_PATHS["raster_dir"],
  95. boundary_shp=config.ANALYSIS_CONFIG.get("boundary_shp"),
  96. resolution_factor=1.0, # 高分辨率栅格生成
  97. interpolation_method='nearest',
  98. field_name='Prediction',
  99. lon_col=0,
  100. lat_col=1,
  101. value_col=2,
  102. enable_interpolation=False # 禁用空间插值
  103. )
  104. output_raster = workflow_result['raster']
  105. # 重命名为特定模型的文件
  106. model_raster_path = os.path.join(
  107. config.OUTPUT_PATHS["raster_dir"],
  108. f"output_{model_name}.tif"
  109. )
  110. if os.path.exists(output_raster) and output_raster != model_raster_path:
  111. import shutil
  112. shutil.move(output_raster, model_raster_path)
  113. output_raster = model_raster_path
  114. model_outputs[model_name]['raster'] = output_raster
  115. logger.info(f"{display_name}模型栅格文件生成完成: {output_raster}")
  116. # 清理中间shapefile文件
  117. try:
  118. shapefile_path = workflow_result.get('shapefile')
  119. if shapefile_path and os.path.exists(shapefile_path):
  120. # 删除shapefile及其相关文件
  121. import glob
  122. base_path = os.path.splitext(shapefile_path)[0]
  123. for ext in ['.shp', '.shx', '.dbf', '.prj', '.cpg']:
  124. file_to_delete = base_path + ext
  125. if os.path.exists(file_to_delete):
  126. os.remove(file_to_delete)
  127. logger.debug(f"已删除中间文件: {file_to_delete}")
  128. logger.info(f"已清理中间shapefile文件: {shapefile_path}")
  129. except Exception as cleanup_error:
  130. logger.warning(f"清理中间文件失败: {str(cleanup_error)}")
  131. except Exception as e:
  132. logger.error(f"栅格生成失败: {str(e)}")
  133. raise
  134. # 步骤4b: 创建可视化地图(直接调用通用绘图模块)
  135. if workflow_config["create_visualization"]:
  136. logger.info(f"步骤4b: 为{display_name}模型创建可视化地图...")
  137. # 为每个模型创建独立的地图输出路径
  138. map_output_path = os.path.join(
  139. config.OUTPUT_PATHS["figures_dir"],
  140. f"Prediction_results_{model_name}"
  141. )
  142. try:
  143. # 直接调用通用绘图模块的create_raster_map
  144. map_output = mapping_utils.create_raster_map(
  145. shp_path=config.ANALYSIS_CONFIG.get("boundary_shp"),
  146. tif_path=output_raster,
  147. output_path=map_output_path,
  148. title=title_name,
  149. colormap=config.VISUALIZATION_CONFIG["color_maps"][config.VISUALIZATION_CONFIG["default_colormap"]],
  150. output_size=config.VISUALIZATION_CONFIG["figure_size"],
  151. dpi=config.VISUALIZATION_CONFIG["dpi"],
  152. resolution_factor=1.0, # 保持栅格原始分辨率
  153. enable_interpolation=False, # 不在可视化阶段插值
  154. interpolation_method='nearest'
  155. )
  156. model_outputs[model_name]['map'] = map_output
  157. logger.info(f"{display_name}模型地图可视化完成: {map_output}")
  158. except Exception as e:
  159. logger.error(f"地图可视化失败: {str(e)}")
  160. raise
  161. # 步骤4c: 创建直方图(直接调用通用绘图模块)
  162. if workflow_config["create_histogram"]:
  163. logger.info(f"步骤4c: 为{display_name}模型创建预测值分布直方图...")
  164. # 为每个模型创建独立的直方图输出路径
  165. histogram_output_path = os.path.join(
  166. config.OUTPUT_PATHS["figures_dir"],
  167. f"Prediction_frequency_{model_name}.jpg"
  168. )
  169. # 为了避免中文字体问题,使用英文标题
  170. if model_name == 'crop_cd':
  171. hist_title = 'Crop Cd Prediction Frequency'
  172. hist_xlabel = 'Crop Cd Content (mg/kg)'
  173. elif model_name == 'effective_cd':
  174. hist_title = 'Effective Cd Prediction Frequency'
  175. hist_xlabel = 'Effective Cd Content (mg/kg)'
  176. else:
  177. hist_title = f'{model_name} Prediction Frequency'
  178. hist_xlabel = f'{model_name} Content'
  179. try:
  180. # 直接调用通用绘图模块的create_histogram
  181. histogram_output = mapping_utils.create_histogram(
  182. file_path=output_raster,
  183. save_path=histogram_output_path,
  184. figsize=(6, 6),
  185. xlabel=hist_xlabel,
  186. ylabel='Frequency',
  187. title=hist_title,
  188. dpi=config.VISUALIZATION_CONFIG["dpi"]
  189. )
  190. model_outputs[model_name]['histogram'] = histogram_output
  191. logger.info(f"{display_name}模型直方图创建完成: {histogram_output}")
  192. except Exception as e:
  193. logger.error(f"直方图生成失败: {str(e)}")
  194. raise
  195. logger.info("=" * 60)
  196. logger.info("所有流程执行完成!")
  197. logger.info("=" * 60)
  198. # 打印结果摘要
  199. print_summary(model_outputs)
  200. except Exception as e:
  201. logger.error(f"执行过程中发生错误: {str(e)}")
  202. logger.exception("详细错误信息:")
  203. raise
  204. def print_summary(model_outputs):
  205. """打印执行结果摘要"""
  206. print("\n" + "=" * 60)
  207. print("[SUCCESS] Cd预测集成系统执行完成! (架构简化版本)")
  208. print("=" * 60)
  209. print(f"[INFO] 总输出目录: {config.OUTPUT_PATHS}")
  210. print()
  211. # 显示每个模型的输出文件
  212. for model_name, outputs in model_outputs.items():
  213. if model_name == 'crop_cd':
  214. display_name = "作物Cd模型"
  215. elif model_name == 'effective_cd':
  216. display_name = "有效态Cd模型"
  217. else:
  218. display_name = f"{model_name}模型"
  219. print(f"[MODEL] {display_name}:")
  220. for output_type, file_path in outputs.items():
  221. if output_type == 'raster':
  222. print(f" [RASTER] 栅格文件: {file_path}")
  223. elif output_type == 'map':
  224. print(f" [MAP] 地图文件: {file_path}")
  225. elif output_type == 'histogram':
  226. print(f" [CHART] 直方图文件: {file_path}")
  227. print()
  228. print(f"[LOG] 日志文件目录: {config.OUTPUT_PATHS['reports_dir']}")
  229. print("=" * 60)
  230. if __name__ == "__main__":
  231. main()