cd_prediction_wrapper.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. """
  2. Cd预测系统包装器
  3. @description: 封装对原始Cd预测系统的调用,简化集成过程
  4. @author: AcidMap Team
  5. @version: 1.0.0
  6. """
  7. import os
  8. import sys
  9. import logging
  10. import subprocess
  11. from typing import Dict, Any, Optional
  12. from datetime import datetime
  13. class CdPredictionWrapper:
  14. """
  15. Cd预测系统包装器类
  16. @description: 简化对原始Cd预测系统的调用和集成
  17. @example
  18. >>> wrapper = CdPredictionWrapper()
  19. >>> result = wrapper.run_crop_cd_prediction()
  20. """
  21. def __init__(self, cd_system_path: str):
  22. """
  23. 初始化包装器
  24. @param {str} cd_system_path - Cd预测系统的路径
  25. """
  26. self.cd_system_path = cd_system_path
  27. self.logger = logging.getLogger(__name__)
  28. # 验证Cd预测系统是否存在
  29. if not os.path.exists(cd_system_path):
  30. raise FileNotFoundError(f"Cd预测系统路径不存在: {cd_system_path}")
  31. # 检查关键文件是否存在
  32. self._validate_cd_system()
  33. def _validate_cd_system(self):
  34. """
  35. 验证Cd预测系统的完整性
  36. @throws {FileNotFoundError} 当关键文件缺失时抛出
  37. """
  38. required_files = [
  39. "main.py",
  40. "config.py",
  41. "models",
  42. "analysis"
  43. ]
  44. for file_path in required_files:
  45. full_path = os.path.join(self.cd_system_path, file_path)
  46. if not os.path.exists(full_path):
  47. raise FileNotFoundError(f"Cd预测系统缺少必要文件: {file_path}")
  48. self.logger.info("Cd预测系统验证通过")
  49. def run_prediction_script(self, model_type: str = "both") -> Dict[str, Any]:
  50. """
  51. 运行Cd预测脚本
  52. @param {str} model_type - 模型类型 ("crop", "effective", "both")
  53. @returns {Dict[str, Any]} 预测结果信息
  54. @throws {Exception} 当预测过程失败时抛出
  55. """
  56. try:
  57. self.logger.info(f"开始运行Cd预测脚本,模型类型: {model_type}")
  58. # 切换到Cd预测系统目录
  59. original_cwd = os.getcwd()
  60. os.chdir(self.cd_system_path)
  61. try:
  62. # 修改配置文件以只运行指定模型
  63. self._modify_workflow_config(model_type)
  64. # 运行主脚本
  65. result = subprocess.run(
  66. [sys.executable, "main.py"],
  67. capture_output=True,
  68. text=True,
  69. timeout=300 # 5分钟超时
  70. )
  71. if result.returncode != 0:
  72. self.logger.error(f"Cd预测脚本执行失败: {result.stderr}")
  73. raise Exception(f"预测脚本执行失败: {result.stderr}")
  74. self.logger.info("Cd预测脚本执行成功")
  75. # 获取输出文件信息
  76. output_info = self._get_output_files(model_type)
  77. return {
  78. "success": True,
  79. "model_type": model_type,
  80. "output_files": output_info,
  81. "stdout": result.stdout,
  82. "timestamp": datetime.now().isoformat()
  83. }
  84. finally:
  85. # 恢复原始工作目录
  86. os.chdir(original_cwd)
  87. except subprocess.TimeoutExpired:
  88. self.logger.error("Cd预测脚本执行超时")
  89. raise Exception("预测脚本执行超时")
  90. except Exception as e:
  91. self.logger.error(f"运行Cd预测脚本失败: {str(e)}")
  92. raise
  93. def _modify_workflow_config(self, model_type: str):
  94. """
  95. 修改工作流配置
  96. @param {str} model_type - 模型类型
  97. """
  98. config_file = os.path.join(self.cd_system_path, "config.py")
  99. # 根据模型类型设置配置
  100. if model_type == "crop":
  101. workflow_config = {
  102. "run_crop_model": True,
  103. "run_effective_model": False,
  104. "combine_predictions": True,
  105. "generate_raster": True,
  106. "create_visualization": True,
  107. "create_histogram": True
  108. }
  109. elif model_type == "effective":
  110. workflow_config = {
  111. "run_crop_model": False,
  112. "run_effective_model": True,
  113. "combine_predictions": True,
  114. "generate_raster": True,
  115. "create_visualization": True,
  116. "create_histogram": True
  117. }
  118. else: # both
  119. workflow_config = {
  120. "run_crop_model": True,
  121. "run_effective_model": True,
  122. "combine_predictions": True,
  123. "generate_raster": True,
  124. "create_visualization": True,
  125. "create_histogram": True
  126. }
  127. # 读取当前配置文件
  128. with open(config_file, 'r', encoding='utf-8') as f:
  129. config_content = f.read()
  130. # 替换 WORKFLOW_CONFIG
  131. import re
  132. pattern = r'WORKFLOW_CONFIG\s*=\s*\{[^}]*\}'
  133. replacement = f"WORKFLOW_CONFIG = {workflow_config}"
  134. new_content = re.sub(pattern, replacement, config_content, flags=re.MULTILINE | re.DOTALL)
  135. # 写回文件
  136. with open(config_file, 'w', encoding='utf-8') as f:
  137. f.write(new_content)
  138. self.logger.info(f"已更新工作流配置为模型类型: {model_type}")
  139. def _get_output_files(self, model_type: str) -> Dict[str, Any]:
  140. """
  141. 获取输出文件信息
  142. @param {str} model_type - 模型类型
  143. @returns {Dict[str, Any]} 输出文件信息
  144. """
  145. output_dir = os.path.join(self.cd_system_path, "output")
  146. figures_dir = os.path.join(output_dir, "figures")
  147. raster_dir = os.path.join(output_dir, "raster")
  148. output_files = {
  149. "figures_dir": figures_dir,
  150. "raster_dir": raster_dir,
  151. "maps": [],
  152. "histograms": [],
  153. "rasters": []
  154. }
  155. # 查找相关文件
  156. if os.path.exists(figures_dir):
  157. for file in os.listdir(figures_dir):
  158. if file.endswith(('.jpg', '.png')):
  159. file_path = os.path.join(figures_dir, file)
  160. if "Prediction" in file and "results" in file:
  161. output_files["maps"].append(file_path)
  162. elif "frequency" in file.lower() or "histogram" in file.lower():
  163. output_files["histograms"].append(file_path)
  164. if os.path.exists(raster_dir):
  165. for file in os.listdir(raster_dir):
  166. if file.endswith('.tif') and file.startswith('output'):
  167. output_files["rasters"].append(os.path.join(raster_dir, file))
  168. return output_files
  169. def get_latest_outputs(self, output_type: str = "all") -> Dict[str, Optional[str]]:
  170. """
  171. 获取最新的输出文件
  172. @param {str} output_type - 输出类型 ("maps", "histograms", "rasters", "all")
  173. @returns {Dict[str, Optional[str]]} 最新输出文件路径
  174. """
  175. try:
  176. output_dir = os.path.join(self.cd_system_path, "output")
  177. figures_dir = os.path.join(output_dir, "figures")
  178. raster_dir = os.path.join(output_dir, "raster")
  179. latest_files = {}
  180. # 获取最新地图文件
  181. if output_type in ["maps", "all"]:
  182. map_files = []
  183. if os.path.exists(figures_dir):
  184. for file in os.listdir(figures_dir):
  185. if "Prediction" in file and "results" in file and file.endswith(('.jpg', '.png')):
  186. map_files.append(os.path.join(figures_dir, file))
  187. latest_files["latest_map"] = max(map_files, key=os.path.getctime) if map_files else None
  188. # 获取最新直方图文件
  189. if output_type in ["histograms", "all"]:
  190. histogram_files = []
  191. if os.path.exists(figures_dir):
  192. for file in os.listdir(figures_dir):
  193. if ("frequency" in file.lower() or "histogram" in file.lower()) and file.endswith(('.jpg', '.png')):
  194. histogram_files.append(os.path.join(figures_dir, file))
  195. latest_files["latest_histogram"] = max(histogram_files, key=os.path.getctime) if histogram_files else None
  196. # 获取最新栅格文件
  197. if output_type in ["rasters", "all"]:
  198. raster_files = []
  199. if os.path.exists(raster_dir):
  200. for file in os.listdir(raster_dir):
  201. if file.startswith('output') and file.endswith('.tif'):
  202. raster_files.append(os.path.join(raster_dir, file))
  203. latest_files["latest_raster"] = max(raster_files, key=os.path.getctime) if raster_files else None
  204. return latest_files
  205. except Exception as e:
  206. self.logger.error(f"获取最新输出文件失败: {str(e)}")
  207. return {}