# cd_prediction_wrapper.py
"""
Cd预测系统包装器
@description: 封装对原始Cd预测系统的调用,简化集成过程
@author: AcidMap Team
@version: 1.0.0
"""

import json
import logging
import os
import subprocess
import sys
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional


  14. class CdPredictionWrapper:
  15. """
  16. Cd预测系统包装器类
  17. @description: 简化对原始Cd预测系统的调用和集成
  18. @example
  19. >>> wrapper = CdPredictionWrapper()
  20. >>> result = wrapper.run_crop_cd_prediction()
  21. """
  22. def __init__(self, cd_system_path: str):
  23. """
  24. 初始化包装器
  25. @param {str} cd_system_path - Cd预测系统的路径
  26. """
  27. self.cd_system_path = cd_system_path
  28. self.logger = logging.getLogger(__name__)
  29. # 验证Cd预测系统是否存在
  30. if not os.path.exists(cd_system_path):
  31. raise FileNotFoundError(f"Cd预测系统路径不存在: {cd_system_path}")
  32. # 检查关键文件是否存在
  33. self._validate_cd_system()
  34. def _validate_cd_system(self):
  35. """
  36. 验证Cd预测系统的完整性
  37. @throws {FileNotFoundError} 当关键文件缺失时抛出
  38. """
  39. required_files = [
  40. "main.py",
  41. "config.py",
  42. "models",
  43. "analysis"
  44. ]
  45. for file_path in required_files:
  46. full_path = os.path.join(self.cd_system_path, file_path)
  47. if not os.path.exists(full_path):
  48. raise FileNotFoundError(f"Cd预测系统缺少必要文件: {file_path}")
  49. self.logger.info("Cd预测系统验证通过")
  50. def run_prediction_script(self, model_type: str = "both", raster_config_override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
  51. """
  52. 运行Cd预测脚本
  53. @param {str} model_type - 模型类型 ("crop", "effective", "both")
  54. @returns {Dict[str, Any]} 预测结果信息
  55. @throws {Exception} 当预测过程失败时抛出
  56. """
  57. try:
  58. self.logger.info(f"开始运行Cd预测脚本,模型类型: {model_type}")
  59. # 切换到Cd预测系统目录
  60. original_cwd = os.getcwd()
  61. os.chdir(self.cd_system_path)
  62. try:
  63. # 通过环境变量传递工作流配置,避免修改配置文件
  64. workflow_config = self._get_workflow_config(model_type)
  65. import json
  66. os.environ['CD_WORKFLOW_CONFIG'] = json.dumps(workflow_config)
  67. # 如果有栅格配置覆盖参数,也通过环境变量传递
  68. if raster_config_override:
  69. os.environ['CD_RASTER_CONFIG_OVERRIDE'] = json.dumps(raster_config_override)
  70. # 运行主脚本
  71. result = subprocess.run(
  72. [sys.executable, "main.py"],
  73. capture_output=True,
  74. text=True,
  75. timeout=300 # 5分钟超时
  76. )
  77. if result.returncode != 0:
  78. self.logger.error(f"Cd预测脚本执行失败: {result.stderr}")
  79. raise Exception(f"预测脚本执行失败: {result.stderr}")
  80. self.logger.info("Cd预测脚本执行成功")
  81. # 获取输出文件信息
  82. output_info = self._get_output_files(model_type)
  83. return {
  84. "success": True,
  85. "model_type": model_type,
  86. "output_files": output_info,
  87. "stdout": result.stdout,
  88. "timestamp": datetime.now().isoformat()
  89. }
  90. finally:
  91. # 恢复原始工作目录
  92. os.chdir(original_cwd)
  93. # 清理环境变量
  94. if 'CD_WORKFLOW_CONFIG' in os.environ:
  95. del os.environ['CD_WORKFLOW_CONFIG']
  96. if 'CD_RASTER_CONFIG_OVERRIDE' in os.environ:
  97. del os.environ['CD_RASTER_CONFIG_OVERRIDE']
  98. except subprocess.TimeoutExpired:
  99. self.logger.error("Cd预测脚本执行超时")
  100. raise Exception("预测脚本执行超时")
  101. except Exception as e:
  102. self.logger.error(f"运行Cd预测脚本失败: {str(e)}")
  103. raise
  104. def _get_workflow_config(self, model_type: str) -> dict:
  105. """
  106. 获取工作流配置
  107. @param {str} model_type - 模型类型
  108. @returns {dict} 工作流配置字典
  109. """
  110. # 根据模型类型设置配置
  111. if model_type == "crop":
  112. workflow_config = {
  113. "run_crop_model": True,
  114. "run_effective_model": False,
  115. "combine_predictions": True,
  116. "generate_raster": True,
  117. "create_visualization": True,
  118. "create_histogram": True
  119. }
  120. elif model_type == "effective":
  121. workflow_config = {
  122. "run_crop_model": False,
  123. "run_effective_model": True,
  124. "combine_predictions": True,
  125. "generate_raster": True,
  126. "create_visualization": True,
  127. "create_histogram": True
  128. }
  129. else: # both
  130. workflow_config = {
  131. "run_crop_model": True,
  132. "run_effective_model": True,
  133. "combine_predictions": True,
  134. "generate_raster": True,
  135. "create_visualization": True,
  136. "create_histogram": True
  137. }
  138. self.logger.info(f"生成工作流配置,模型类型: {model_type}")
  139. return workflow_config
  140. def _get_output_files(self, model_type: str) -> Dict[str, Any]:
  141. """
  142. 获取输出文件信息
  143. @param {str} model_type - 模型类型
  144. @returns {Dict[str, Any]} 输出文件信息
  145. """
  146. output_dir = os.path.join(self.cd_system_path, "output")
  147. figures_dir = os.path.join(output_dir, "figures")
  148. raster_dir = os.path.join(output_dir, "raster")
  149. output_files = {
  150. "figures_dir": figures_dir,
  151. "raster_dir": raster_dir,
  152. "maps": [],
  153. "histograms": [],
  154. "rasters": []
  155. }
  156. # 查找相关文件
  157. if os.path.exists(figures_dir):
  158. for file in os.listdir(figures_dir):
  159. if file.endswith(('.jpg', '.png')):
  160. file_path = os.path.join(figures_dir, file)
  161. if "Prediction" in file and "results" in file:
  162. output_files["maps"].append(file_path)
  163. elif "frequency" in file.lower() or "histogram" in file.lower():
  164. output_files["histograms"].append(file_path)
  165. if os.path.exists(raster_dir):
  166. for file in os.listdir(raster_dir):
  167. if file.endswith('.tif') and file.startswith('output'):
  168. output_files["rasters"].append(os.path.join(raster_dir, file))
  169. return output_files
  170. def get_latest_outputs(self, output_type: str = "all", model_type: str = None) -> Dict[str, Optional[str]]:
  171. """
  172. 获取最新的输出文件
  173. @param {str} output_type - 输出类型 ("maps", "histograms", "rasters", "all")
  174. @param {str} model_type - 模型类型 ("crop", "effective", None为获取所有)
  175. @returns {Dict[str, Optional[str]]} 最新输出文件路径
  176. """
  177. try:
  178. output_dir = os.path.join(self.cd_system_path, "output")
  179. figures_dir = os.path.join(output_dir, "figures")
  180. raster_dir = os.path.join(output_dir, "raster")
  181. latest_files = {}
  182. # 获取最新地图文件
  183. if output_type in ["maps", "all"]:
  184. map_files = []
  185. if os.path.exists(figures_dir):
  186. for file in os.listdir(figures_dir):
  187. if "Prediction" in file and "results" in file and file.endswith(('.jpg', '.png')):
  188. file_path = os.path.join(figures_dir, file)
  189. map_files.append(file_path)
  190. latest_files["latest_map"] = max(map_files, key=os.path.getctime) if map_files else None
  191. # 获取最新直方图文件
  192. if output_type in ["histograms", "all"]:
  193. histogram_files = []
  194. if os.path.exists(figures_dir):
  195. for file in os.listdir(figures_dir):
  196. if ("frequency" in file.lower() or "histogram" in file.lower()) and file.endswith(('.jpg', '.png')):
  197. file_path = os.path.join(figures_dir, file)
  198. histogram_files.append(file_path)
  199. latest_files["latest_histogram"] = max(histogram_files, key=os.path.getctime) if histogram_files else None
  200. # 获取最新栅格文件
  201. if output_type in ["rasters", "all"]:
  202. raster_files = []
  203. if os.path.exists(raster_dir):
  204. for file in os.listdir(raster_dir):
  205. if file.startswith('output') and file.endswith('.tif'):
  206. file_path = os.path.join(raster_dir, file)
  207. raster_files.append(file_path)
  208. latest_files["latest_raster"] = max(raster_files, key=os.path.getctime) if raster_files else None
  209. # 添加调试信息
  210. self.logger.info(f"获取最新输出文件 - 模型类型: {model_type}, 输出类型: {output_type}")
  211. for key, value in latest_files.items():
  212. if value:
  213. self.logger.info(f" {key}: {os.path.basename(value)}")
  214. else:
  215. self.logger.warning(f" {key}: 未找到文件")
  216. return latest_files
  217. except Exception as e:
  218. self.logger.error(f"获取最新输出文件失败: {str(e)}")
  219. return {}