cd_prediction_wrapper.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. """
  2. Cd预测系统包装器
  3. @description: 封装对原始Cd预测系统的调用,简化集成过程
  4. @author: AcidMap Team
  5. @version: 1.0.0
  6. """
  7. import os
  8. import sys
  9. import logging
  10. import subprocess
  11. import time
  12. from typing import Dict, Any, Optional
  13. from datetime import datetime
  14. class CdPredictionWrapper:
  15. """
  16. Cd预测系统包装器类
  17. @description: 简化对原始Cd预测系统的调用和集成
  18. @example
  19. >>> wrapper = CdPredictionWrapper()
  20. >>> result = wrapper.run_crop_cd_prediction()
  21. """
  22. def __init__(self, cd_system_path: str):
  23. """
  24. 初始化包装器
  25. @param {str} cd_system_path - Cd预测系统的路径
  26. """
  27. self.cd_system_path = cd_system_path
  28. self.logger = logging.getLogger(__name__)
  29. # 验证Cd预测系统是否存在
  30. if not os.path.exists(cd_system_path):
  31. raise FileNotFoundError(f"Cd预测系统路径不存在: {cd_system_path}")
  32. # 检查关键文件是否存在
  33. self._validate_cd_system()
  34. def _validate_cd_system(self):
  35. """
  36. 验证Cd预测系统的完整性
  37. @throws {FileNotFoundError} 当关键文件缺失时抛出
  38. """
  39. required_files = [
  40. "main.py",
  41. "config.py",
  42. "models",
  43. "analysis"
  44. ]
  45. for file_path in required_files:
  46. full_path = os.path.join(self.cd_system_path, file_path)
  47. if not os.path.exists(full_path):
  48. raise FileNotFoundError(f"Cd预测系统缺少必要文件: {file_path}")
  49. self.logger.info("Cd预测系统验证通过")
  50. def run_prediction_script(self, model_type: str = "both") -> Dict[str, Any]:
  51. """
  52. 运行Cd预测脚本
  53. @param {str} model_type - 模型类型 ("crop", "effective", "both")
  54. @returns {Dict[str, Any]} 预测结果信息
  55. @throws {Exception} 当预测过程失败时抛出
  56. """
  57. try:
  58. self.logger.info(f"开始运行Cd预测脚本,模型类型: {model_type}")
  59. # 切换到Cd预测系统目录
  60. original_cwd = os.getcwd()
  61. os.chdir(self.cd_system_path)
  62. try:
  63. # 通过环境变量传递工作流配置,避免修改配置文件
  64. workflow_config = self._get_workflow_config(model_type)
  65. import json
  66. os.environ['CD_WORKFLOW_CONFIG'] = json.dumps(workflow_config)
  67. # 运行主脚本
  68. result = subprocess.run(
  69. [sys.executable, "main.py"],
  70. capture_output=True,
  71. text=True,
  72. timeout=300 # 5分钟超时
  73. )
  74. if result.returncode != 0:
  75. self.logger.error(f"Cd预测脚本执行失败: {result.stderr}")
  76. raise Exception(f"预测脚本执行失败: {result.stderr}")
  77. self.logger.info("Cd预测脚本执行成功")
  78. # 获取输出文件信息
  79. output_info = self._get_output_files(model_type)
  80. return {
  81. "success": True,
  82. "model_type": model_type,
  83. "output_files": output_info,
  84. "stdout": result.stdout,
  85. "timestamp": datetime.now().isoformat()
  86. }
  87. finally:
  88. # 恢复原始工作目录
  89. os.chdir(original_cwd)
  90. # 清理环境变量
  91. if 'CD_WORKFLOW_CONFIG' in os.environ:
  92. del os.environ['CD_WORKFLOW_CONFIG']
  93. except subprocess.TimeoutExpired:
  94. self.logger.error("Cd预测脚本执行超时")
  95. raise Exception("预测脚本执行超时")
  96. except Exception as e:
  97. self.logger.error(f"运行Cd预测脚本失败: {str(e)}")
  98. raise
  99. def _get_workflow_config(self, model_type: str) -> dict:
  100. """
  101. 获取工作流配置
  102. @param {str} model_type - 模型类型
  103. @returns {dict} 工作流配置字典
  104. """
  105. # 根据模型类型设置配置
  106. if model_type == "crop":
  107. workflow_config = {
  108. "run_crop_model": True,
  109. "run_effective_model": False,
  110. "combine_predictions": True,
  111. "generate_raster": True,
  112. "create_visualization": True,
  113. "create_histogram": True
  114. }
  115. elif model_type == "effective":
  116. workflow_config = {
  117. "run_crop_model": False,
  118. "run_effective_model": True,
  119. "combine_predictions": True,
  120. "generate_raster": True,
  121. "create_visualization": True,
  122. "create_histogram": True
  123. }
  124. else: # both
  125. workflow_config = {
  126. "run_crop_model": True,
  127. "run_effective_model": True,
  128. "combine_predictions": True,
  129. "generate_raster": True,
  130. "create_visualization": True,
  131. "create_histogram": True
  132. }
  133. self.logger.info(f"生成工作流配置,模型类型: {model_type}")
  134. return workflow_config
  135. def _get_output_files(self, model_type: str) -> Dict[str, Any]:
  136. """
  137. 获取输出文件信息
  138. @param {str} model_type - 模型类型
  139. @returns {Dict[str, Any]} 输出文件信息
  140. """
  141. output_dir = os.path.join(self.cd_system_path, "output")
  142. figures_dir = os.path.join(output_dir, "figures")
  143. raster_dir = os.path.join(output_dir, "raster")
  144. output_files = {
  145. "figures_dir": figures_dir,
  146. "raster_dir": raster_dir,
  147. "maps": [],
  148. "histograms": [],
  149. "rasters": []
  150. }
  151. # 查找相关文件
  152. if os.path.exists(figures_dir):
  153. for file in os.listdir(figures_dir):
  154. if file.endswith(('.jpg', '.png')):
  155. file_path = os.path.join(figures_dir, file)
  156. if "Prediction" in file and "results" in file:
  157. output_files["maps"].append(file_path)
  158. elif "frequency" in file.lower() or "histogram" in file.lower():
  159. output_files["histograms"].append(file_path)
  160. if os.path.exists(raster_dir):
  161. for file in os.listdir(raster_dir):
  162. if file.endswith('.tif') and file.startswith('output'):
  163. output_files["rasters"].append(os.path.join(raster_dir, file))
  164. return output_files
  165. def get_latest_outputs(self, output_type: str = "all", model_type: str = None) -> Dict[str, Optional[str]]:
  166. """
  167. 获取最新的输出文件
  168. @param {str} output_type - 输出类型 ("maps", "histograms", "rasters", "all")
  169. @param {str} model_type - 模型类型 ("crop", "effective", None为获取所有)
  170. @returns {Dict[str, Optional[str]]} 最新输出文件路径
  171. """
  172. try:
  173. output_dir = os.path.join(self.cd_system_path, "output")
  174. figures_dir = os.path.join(output_dir, "figures")
  175. raster_dir = os.path.join(output_dir, "raster")
  176. latest_files = {}
  177. # 获取最新地图文件
  178. if output_type in ["maps", "all"]:
  179. map_files = []
  180. if os.path.exists(figures_dir):
  181. for file in os.listdir(figures_dir):
  182. if "Prediction" in file and "results" in file and file.endswith(('.jpg', '.png')):
  183. file_path = os.path.join(figures_dir, file)
  184. map_files.append(file_path)
  185. latest_files["latest_map"] = max(map_files, key=os.path.getctime) if map_files else None
  186. # 获取最新直方图文件
  187. if output_type in ["histograms", "all"]:
  188. histogram_files = []
  189. if os.path.exists(figures_dir):
  190. for file in os.listdir(figures_dir):
  191. if ("frequency" in file.lower() or "histogram" in file.lower()) and file.endswith(('.jpg', '.png')):
  192. file_path = os.path.join(figures_dir, file)
  193. histogram_files.append(file_path)
  194. latest_files["latest_histogram"] = max(histogram_files, key=os.path.getctime) if histogram_files else None
  195. # 获取最新栅格文件
  196. if output_type in ["rasters", "all"]:
  197. raster_files = []
  198. if os.path.exists(raster_dir):
  199. for file in os.listdir(raster_dir):
  200. if file.startswith('output') and file.endswith('.tif'):
  201. file_path = os.path.join(raster_dir, file)
  202. raster_files.append(file_path)
  203. latest_files["latest_raster"] = max(raster_files, key=os.path.getctime) if raster_files else None
  204. # 添加调试信息
  205. self.logger.info(f"获取最新输出文件 - 模型类型: {model_type}, 输出类型: {output_type}")
  206. for key, value in latest_files.items():
  207. if value:
  208. self.logger.info(f" {key}: {os.path.basename(value)}")
  209. else:
  210. self.logger.warning(f" {key}: 未找到文件")
  211. return latest_files
  212. except Exception as e:
  213. self.logger.error(f"获取最新输出文件失败: {str(e)}")
  214. return {}