main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. """
  2. 主执行脚本 - Cd预测集成系统
  3. Main execution script for Cd Prediction Integrated System
  4. Description: 集成运行作物Cd模型、有效态Cd模型和数据分析可视化的完整流程
  5. 架构简化版本:直接调用通用绘图模块,删除中间封装层
  6. """
  7. import os
  8. import sys
  9. import logging
  10. from datetime import datetime
  11. # 添加项目根目录到Python路径
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  13. # 添加上级目录到路径以访问通用绘图模块
  14. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  15. import config
  16. from models.crop_cd_model.predict import CropCdPredictor
  17. from models.effective_cd_model.predict import EffectiveCdPredictor
  18. from analysis.data_processing import DataProcessor
  19. from utils.common import setup_logging, validate_files
  20. # 直接导入通用绘图模块
  21. from app.utils.mapping_utils import csv_to_raster_workflow, MappingUtils
  22. def main():
  23. """
  24. 主执行函数
  25. 执行完整的Cd预测分析流程
  26. """
  27. # 设置日志
  28. log_file = os.path.join(config.OUTPUT_PATHS["reports_dir"],
  29. f"execution_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
  30. setup_logging(log_file)
  31. logger = logging.getLogger(__name__)
  32. logger.info("=" * 60)
  33. logger.info("开始执行Cd预测集成系统 (架构简化版本)")
  34. logger.info("=" * 60)
  35. try:
  36. # 确保目录存在
  37. config.ensure_directories()
  38. logger.info("项目目录结构检查完成")
  39. # 获取当前工作流配置(支持运行时动态配置)
  40. workflow_config = config.get_workflow_config()
  41. logger.info(f"当前工作流配置: {workflow_config}")
  42. # 步骤1: 运行作物Cd模型预测
  43. if workflow_config["run_crop_model"]:
  44. logger.info("步骤1: 运行作物Cd模型预测...")
  45. crop_predictor = CropCdPredictor()
  46. crop_output = crop_predictor.predict()
  47. logger.info(f"作物Cd模型预测完成,输出文件: {crop_output}")
  48. # 步骤2: 运行有效态Cd模型预测
  49. if workflow_config["run_effective_model"]:
  50. logger.info("步骤2: 运行有效态Cd模型预测...")
  51. effective_predictor = EffectiveCdPredictor()
  52. effective_output = effective_predictor.predict()
  53. logger.info(f"有效态Cd模型预测完成,输出文件: {effective_output}")
  54. # 步骤3: 数据整合处理
  55. if workflow_config["combine_predictions"]:
  56. logger.info("步骤3: 整合预测结果与坐标数据...")
  57. data_processor = DataProcessor()
  58. final_data_files = data_processor.combine_predictions_with_coordinates()
  59. logger.info(f"数据整合完成,生成文件: {list(final_data_files.keys())}")
  60. # 初始化通用绘图工具(用于可视化)
  61. mapping_utils = MappingUtils()
  62. # 步骤4: 为每个模型生成栅格文件和可视化图
  63. model_outputs = {}
  64. for model_name, final_data_file in final_data_files.items():
  65. if model_name == 'combined': # 跳过合并文件的可视化
  66. continue
  67. # 只为当前工作流配置中启用的模型生成输出文件
  68. if model_name == 'crop_cd' and not workflow_config.get("run_crop_model", False):
  69. logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
  70. continue
  71. if model_name == 'effective_cd' and not workflow_config.get("run_effective_model", False):
  72. logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
  73. continue
  74. logger.info(f"为{model_name}模型生成栅格和可视化...")
  75. # 为每个模型确定显示名称
  76. if model_name == 'crop_cd':
  77. display_name = "作物Cd"
  78. title_name = "Crop Cd Prediction"
  79. elif model_name == 'effective_cd':
  80. display_name = "有效态Cd"
  81. title_name = "Effective Cd Prediction"
  82. else:
  83. display_name = model_name
  84. title_name = f"{model_name} Prediction"
  85. model_outputs[model_name] = {}
  86. # 步骤4a: 生成栅格文件(直接调用通用绘图模块)
  87. if workflow_config["generate_raster"]:
  88. logger.info(f"步骤4a: 为{display_name}模型生成栅格文件...")
  89. # 使用csv_to_raster_workflow直接生成栅格
  90. try:
  91. # 获取栅格配置(支持运行时参数覆盖)
  92. raster_config = config.get_raster_config()
  93. workflow_result = csv_to_raster_workflow(
  94. csv_file=final_data_file,
  95. template_tif=config.ANALYSIS_CONFIG["template_tif"],
  96. output_dir=config.OUTPUT_PATHS["raster_dir"],
  97. boundary_shp=config.ANALYSIS_CONFIG.get("boundary_shp"),
  98. resolution_factor=raster_config["resolution_factor"],
  99. interpolation_method=raster_config["interpolation_method"],
  100. field_name=raster_config["field_name"],
  101. lon_col=raster_config["coordinate_columns"]["longitude"],
  102. lat_col=raster_config["coordinate_columns"]["latitude"],
  103. value_col=raster_config["coordinate_columns"]["value"],
  104. enable_interpolation=raster_config["enable_interpolation"]
  105. )
  106. output_raster = workflow_result['raster']
  107. # 重命名为特定模型的文件
  108. model_raster_path = os.path.join(
  109. config.OUTPUT_PATHS["raster_dir"],
  110. f"output_{model_name}.tif"
  111. )
  112. if os.path.exists(output_raster) and output_raster != model_raster_path:
  113. import shutil
  114. shutil.move(output_raster, model_raster_path)
  115. output_raster = model_raster_path
  116. model_outputs[model_name]['raster'] = output_raster
  117. logger.info(f"{display_name}模型栅格文件生成完成: {output_raster}")
  118. # 清理中间shapefile文件
  119. try:
  120. shapefile_path = workflow_result.get('shapefile')
  121. if shapefile_path and os.path.exists(shapefile_path):
  122. # 删除shapefile及其相关文件
  123. import glob
  124. base_path = os.path.splitext(shapefile_path)[0]
  125. for ext in ['.shp', '.shx', '.dbf', '.prj', '.cpg']:
  126. file_to_delete = base_path + ext
  127. if os.path.exists(file_to_delete):
  128. os.remove(file_to_delete)
  129. logger.debug(f"已删除中间文件: {file_to_delete}")
  130. logger.info(f"已清理中间shapefile文件: {shapefile_path}")
  131. except Exception as cleanup_error:
  132. logger.warning(f"清理中间文件失败: {str(cleanup_error)}")
  133. except Exception as e:
  134. logger.error(f"栅格生成失败: {str(e)}")
  135. raise
  136. # 步骤4b: 创建可视化地图(直接调用通用绘图模块)
  137. if workflow_config["create_visualization"]:
  138. logger.info(f"步骤4b: 为{display_name}模型创建可视化地图...")
  139. # 为每个模型创建独立的地图输出路径
  140. map_output_path = os.path.join(
  141. config.OUTPUT_PATHS["figures_dir"],
  142. f"Prediction_results_{model_name}"
  143. )
  144. try:
  145. # 直接调用通用绘图模块的create_raster_map
  146. map_output = mapping_utils.create_raster_map(
  147. shp_path=config.ANALYSIS_CONFIG.get("boundary_shp"),
  148. tif_path=output_raster,
  149. output_path=map_output_path,
  150. title=title_name,
  151. colormap=config.VISUALIZATION_CONFIG["color_maps"][config.VISUALIZATION_CONFIG["default_colormap"]],
  152. output_size=config.VISUALIZATION_CONFIG["figure_size"],
  153. dpi=config.VISUALIZATION_CONFIG["dpi"],
  154. resolution_factor=1.0, # 保持栅格原始分辨率
  155. enable_interpolation=False, # 不在可视化阶段插值
  156. interpolation_method='nearest'
  157. )
  158. model_outputs[model_name]['map'] = map_output
  159. logger.info(f"{display_name}模型地图可视化完成: {map_output}")
  160. except Exception as e:
  161. logger.error(f"地图可视化失败: {str(e)}")
  162. raise
  163. # 步骤4c: 创建直方图(直接调用通用绘图模块)
  164. if workflow_config["create_histogram"]:
  165. logger.info(f"步骤4c: 为{display_name}模型创建预测值分布直方图...")
  166. # 为每个模型创建独立的直方图输出路径
  167. histogram_output_path = os.path.join(
  168. config.OUTPUT_PATHS["figures_dir"],
  169. f"Prediction_frequency_{model_name}.jpg"
  170. )
  171. # 为了避免中文字体问题,使用英文标题
  172. if model_name == 'crop_cd':
  173. hist_title = 'Crop Cd Prediction Frequency'
  174. hist_xlabel = 'Crop Cd Content (mg/kg)'
  175. elif model_name == 'effective_cd':
  176. hist_title = 'Effective Cd Prediction Frequency'
  177. hist_xlabel = 'Effective Cd Content (mg/kg)'
  178. else:
  179. hist_title = f'{model_name} Prediction Frequency'
  180. hist_xlabel = f'{model_name} Content'
  181. try:
  182. # 直接调用通用绘图模块的create_histogram
  183. histogram_output = mapping_utils.create_histogram(
  184. file_path=output_raster,
  185. save_path=histogram_output_path,
  186. figsize=(6, 6),
  187. xlabel=hist_xlabel,
  188. ylabel='Frequency',
  189. title=hist_title,
  190. dpi=config.VISUALIZATION_CONFIG["dpi"]
  191. )
  192. model_outputs[model_name]['histogram'] = histogram_output
  193. logger.info(f"{display_name}模型直方图创建完成: {histogram_output}")
  194. except Exception as e:
  195. logger.error(f"直方图生成失败: {str(e)}")
  196. raise
  197. logger.info("=" * 60)
  198. logger.info("所有流程执行完成!")
  199. logger.info("=" * 60)
  200. # 打印结果摘要
  201. print_summary(model_outputs)
  202. except Exception as e:
  203. logger.error(f"执行过程中发生错误: {str(e)}")
  204. logger.exception("详细错误信息:")
  205. raise
  206. def print_summary(model_outputs):
  207. """打印执行结果摘要"""
  208. print("\n" + "=" * 60)
  209. print("[SUCCESS] Cd预测集成系统执行完成! (架构简化版本)")
  210. print("=" * 60)
  211. print(f"[INFO] 总输出目录: {config.OUTPUT_PATHS}")
  212. print()
  213. # 显示每个模型的输出文件
  214. for model_name, outputs in model_outputs.items():
  215. if model_name == 'crop_cd':
  216. display_name = "作物Cd模型"
  217. elif model_name == 'effective_cd':
  218. display_name = "有效态Cd模型"
  219. else:
  220. display_name = f"{model_name}模型"
  221. print(f"[MODEL] {display_name}:")
  222. for output_type, file_path in outputs.items():
  223. if output_type == 'raster':
  224. print(f" [RASTER] 栅格文件: {file_path}")
  225. elif output_type == 'map':
  226. print(f" [MAP] 地图文件: {file_path}")
  227. elif output_type == 'histogram':
  228. print(f" [CHART] 直方图文件: {file_path}")
  229. print()
  230. print(f"[LOG] 日志文件目录: {config.OUTPUT_PATHS['reports_dir']}")
  231. print("=" * 60)
  232. if __name__ == "__main__":
  233. main()