cd_prediction_wrapper.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """
  2. Cd预测系统包装器
  3. @description: 封装对原始Cd预测系统的调用,简化集成过程
  4. @author: AcidMap Team
  5. @version: 1.0.0
  6. """
  7. import os
  8. import sys
  9. import logging
  10. import subprocess
  11. import time
  12. from typing import Dict, Any, Optional
  13. from datetime import datetime
  14. class CdPredictionWrapper:
  15. """
  16. Cd预测系统包装器类
  17. @description: 简化对原始Cd预测系统的调用和集成
  18. @example
  19. >>> wrapper = CdPredictionWrapper("/path/to/cd_prediction_system")
  20. >>> result = wrapper.run_crop_cd_prediction()
  21. """
  22. def __init__(self, cd_system_path: str):
  23. """
  24. 初始化包装器
  25. @param {str} cd_system_path - Cd预测系统的路径
  26. """
  27. self.cd_system_path = cd_system_path
  28. self.logger = logging.getLogger(__name__)
  29. # 验证Cd预测系统是否存在
  30. if not os.path.exists(cd_system_path):
  31. raise FileNotFoundError(f"Cd预测系统路径不存在: {cd_system_path}")
  32. # 检查关键文件是否存在
  33. self._validate_cd_system()
  34. def _validate_cd_system(self):
  35. """
  36. 验证Cd预测系统的完整性
  37. @throws {FileNotFoundError} 当关键文件缺失时抛出
  38. """
  39. required_files = [
  40. "main.py",
  41. "config.py",
  42. "models",
  43. "analysis"
  44. ]
  45. for file_path in required_files:
  46. full_path = os.path.join(self.cd_system_path, file_path)
  47. if not os.path.exists(full_path):
  48. raise FileNotFoundError(f"Cd预测系统缺少必要文件: {file_path}")
  49. self.logger.info("Cd预测系统验证通过")
  50. def run_prediction_script(self, model_type: str = "both") -> Dict[str, Any]:
  51. """
  52. 运行Cd预测脚本
  53. @param {str} model_type - 模型类型 ("crop", "effective", "both")
  54. @returns {Dict[str, Any]} 预测结果信息
  55. @throws {Exception} 当预测过程失败时抛出
  56. """
  57. try:
  58. self.logger.info(f"开始运行Cd预测脚本,模型类型: {model_type}")
  59. # 切换到Cd预测系统目录
  60. original_cwd = os.getcwd()
  61. os.chdir(self.cd_system_path)
  62. try:
  63. # 修改配置文件以只运行指定模型
  64. self._modify_workflow_config(model_type)
  65. # 运行主脚本
  66. result = subprocess.run(
  67. [sys.executable, "main.py"],
  68. capture_output=True,
  69. text=True,
  70. timeout=300 # 5分钟超时
  71. )
  72. if result.returncode != 0:
  73. self.logger.error(f"Cd预测脚本执行失败: {result.stderr}")
  74. raise Exception(f"预测脚本执行失败: {result.stderr}")
  75. self.logger.info("Cd预测脚本执行成功")
  76. # 获取输出文件信息
  77. output_info = self._get_output_files(model_type)
  78. return {
  79. "success": True,
  80. "model_type": model_type,
  81. "output_files": output_info,
  82. "stdout": result.stdout,
  83. "timestamp": datetime.now().isoformat()
  84. }
  85. finally:
  86. # 恢复原始工作目录
  87. os.chdir(original_cwd)
  88. except subprocess.TimeoutExpired:
  89. self.logger.error("Cd预测脚本执行超时")
  90. raise Exception("预测脚本执行超时")
  91. except Exception as e:
  92. self.logger.error(f"运行Cd预测脚本失败: {str(e)}")
  93. raise
  94. def _modify_workflow_config(self, model_type: str):
  95. """
  96. 修改工作流配置
  97. @param {str} model_type - 模型类型
  98. """
  99. config_file = os.path.join(self.cd_system_path, "config.py")
  100. # 根据模型类型设置配置
  101. if model_type == "crop":
  102. workflow_config = {
  103. "run_crop_model": True,
  104. "run_effective_model": False,
  105. "combine_predictions": True,
  106. "generate_raster": True,
  107. "create_visualization": True,
  108. "create_histogram": True
  109. }
  110. elif model_type == "effective":
  111. workflow_config = {
  112. "run_crop_model": False,
  113. "run_effective_model": True,
  114. "combine_predictions": True,
  115. "generate_raster": True,
  116. "create_visualization": True,
  117. "create_histogram": True
  118. }
  119. else: # both
  120. workflow_config = {
  121. "run_crop_model": True,
  122. "run_effective_model": True,
  123. "combine_predictions": True,
  124. "generate_raster": True,
  125. "create_visualization": True,
  126. "create_histogram": True
  127. }
  128. # 读取当前配置文件
  129. with open(config_file, 'r', encoding='utf-8') as f:
  130. config_content = f.read()
  131. # 替换 WORKFLOW_CONFIG
  132. import re
  133. pattern = r'WORKFLOW_CONFIG\s*=\s*\{[^}]*\}'
  134. replacement = f"WORKFLOW_CONFIG = {workflow_config}"
  135. new_content = re.sub(pattern, replacement, config_content, flags=re.MULTILINE | re.DOTALL)
  136. # 写回文件
  137. with open(config_file, 'w', encoding='utf-8') as f:
  138. f.write(new_content)
  139. self.logger.info(f"已更新工作流配置为模型类型: {model_type}")
  140. def _get_output_files(self, model_type: str) -> Dict[str, Any]:
  141. """
  142. 获取输出文件信息
  143. @param {str} model_type - 模型类型
  144. @returns {Dict[str, Any]} 输出文件信息
  145. """
  146. output_dir = os.path.join(self.cd_system_path, "output")
  147. figures_dir = os.path.join(output_dir, "figures")
  148. raster_dir = os.path.join(output_dir, "raster")
  149. output_files = {
  150. "figures_dir": figures_dir,
  151. "raster_dir": raster_dir,
  152. "maps": [],
  153. "histograms": [],
  154. "rasters": []
  155. }
  156. # 查找相关文件
  157. if os.path.exists(figures_dir):
  158. for file in os.listdir(figures_dir):
  159. if file.endswith(('.jpg', '.png')):
  160. file_path = os.path.join(figures_dir, file)
  161. if "Prediction" in file and "results" in file:
  162. output_files["maps"].append(file_path)
  163. elif "frequency" in file.lower() or "histogram" in file.lower():
  164. output_files["histograms"].append(file_path)
  165. if os.path.exists(raster_dir):
  166. for file in os.listdir(raster_dir):
  167. if file.endswith('.tif') and file.startswith('output'):
  168. output_files["rasters"].append(os.path.join(raster_dir, file))
  169. return output_files
  170. def get_latest_outputs(self, output_type: str = "all", model_type: str = None) -> Dict[str, Optional[str]]:
  171. """
  172. 获取最新的输出文件
  173. @param {str} output_type - 输出类型 ("maps", "histograms", "rasters", "all")
  174. @param {str} model_type - 模型类型 ("crop", "effective", None为获取所有)
  175. @returns {Dict[str, Optional[str]]} 最新输出文件路径
  176. """
  177. try:
  178. output_dir = os.path.join(self.cd_system_path, "output")
  179. figures_dir = os.path.join(output_dir, "figures")
  180. raster_dir = os.path.join(output_dir, "raster")
  181. latest_files = {}
  182. # 获取最新地图文件
  183. if output_type in ["maps", "all"]:
  184. map_files = []
  185. if os.path.exists(figures_dir):
  186. for file in os.listdir(figures_dir):
  187. if "Prediction" in file and "results" in file and file.endswith(('.jpg', '.png')):
  188. file_path = os.path.join(figures_dir, file)
  189. map_files.append(file_path)
  190. latest_files["latest_map"] = max(map_files, key=os.path.getctime) if map_files else None
  191. # 获取最新直方图文件
  192. if output_type in ["histograms", "all"]:
  193. histogram_files = []
  194. if os.path.exists(figures_dir):
  195. for file in os.listdir(figures_dir):
  196. if ("frequency" in file.lower() or "histogram" in file.lower()) and file.endswith(('.jpg', '.png')):
  197. file_path = os.path.join(figures_dir, file)
  198. histogram_files.append(file_path)
  199. latest_files["latest_histogram"] = max(histogram_files, key=os.path.getctime) if histogram_files else None
  200. # 获取最新栅格文件
  201. if output_type in ["rasters", "all"]:
  202. raster_files = []
  203. if os.path.exists(raster_dir):
  204. for file in os.listdir(raster_dir):
  205. if file.startswith('output') and file.endswith('.tif'):
  206. file_path = os.path.join(raster_dir, file)
  207. raster_files.append(file_path)
  208. latest_files["latest_raster"] = max(raster_files, key=os.path.getctime) if raster_files else None
  209. # 添加调试信息
  210. self.logger.info(f"获取最新输出文件 - 模型类型: {model_type}, 输出类型: {output_type}")
  211. for key, value in latest_files.items():
  212. if value:
  213. self.logger.info(f" {key}: {os.path.basename(value)}")
  214. else:
  215. self.logger.warning(f" {key}: 未找到文件")
  216. return latest_files
  217. except Exception as e:
  218. self.logger.error(f"获取最新输出文件失败: {str(e)}")
  219. return {}