"""
Cd prediction system wrapper.

@description: Wraps invocations of the original Cd prediction system to simplify integration.
@author: AcidMap Team
@version: 1.0.0
"""
import os
import re
import sys
import logging
import subprocess
from typing import Any, Callable, Dict, List, Optional, Tuple
from datetime import datetime


class CdPredictionWrapper:
    """
    Wrapper around the external Cd prediction system.

    @description: Runs the system's ``main.py`` in a child process (after rewriting the
        ``WORKFLOW_CONFIG`` dict in its ``config.py``) and collects the files found under
        ``<cd_system_path>/output/figures`` and ``<cd_system_path>/output/raster``.
    @example
        wrapper = CdPredictionWrapper("/path/to/cd_prediction_system")
        result = wrapper.run_prediction_script("crop")
    """

    # Entries that must exist directly under cd_system_path.
    REQUIRED_ENTRIES: Tuple[str, ...] = ("main.py", "config.py", "models", "analysis")

    # Default number of seconds to wait for main.py (5 minutes).
    DEFAULT_TIMEOUT: float = 300

    # model_type -> (run_crop_model, run_effective_model)
    _MODEL_FLAGS: Dict[str, Tuple[bool, bool]] = {
        "crop": (True, False),
        "effective": (False, True),
        "both": (True, True),
    }

    # Matches the WORKFLOW_CONFIG dict literal in config.py. `[^}]*` stops at the first
    # closing brace, so the dict is assumed to contain no nested braces.
    _WORKFLOW_CONFIG_RE = re.compile(r"WORKFLOW_CONFIG\s*=\s*\{[^}]*\}")

    # Figure file extensions recognised as maps/histograms.
    _IMAGE_EXTENSIONS: Tuple[str, ...] = (".jpg", ".png")

    def __init__(self, cd_system_path: str):
        """
        Initialise the wrapper and validate the target system.

        @param {str} cd_system_path - path to the Cd prediction system directory
        @throws {FileNotFoundError} when the path or one of REQUIRED_ENTRIES is missing
        """
        self.cd_system_path = cd_system_path
        self.logger = logging.getLogger(__name__)

        if not os.path.exists(cd_system_path):
            raise FileNotFoundError(f"Cd预测系统路径不存在: {cd_system_path}")

        self._validate_cd_system()

    def _validate_cd_system(self) -> None:
        """
        Check that every entry in REQUIRED_ENTRIES exists under cd_system_path.

        @throws {FileNotFoundError} naming the first missing entry
        """
        for entry in self.REQUIRED_ENTRIES:
            if not os.path.exists(os.path.join(self.cd_system_path, entry)):
                raise FileNotFoundError(f"Cd预测系统缺少必要文件: {entry}")

        self.logger.info("Cd预测系统验证通过")

    def run_prediction_script(
        self, model_type: str = "both", timeout: float = DEFAULT_TIMEOUT
    ) -> Dict[str, Any]:
        """
        Configure and run the Cd prediction script.

        @param {str} model_type - model(s) to run: "crop", "effective" or "both"; any other
            value is treated as "both" (a warning is logged)
        @param {float} timeout - seconds to wait for main.py before it is killed
        @returns {Dict[str, Any]} success flag, model type, output file listing
            (see _get_output_files), captured stdout and an ISO-8601 local timestamp
        @throws {RuntimeError} when main.py exits non-zero or exceeds ``timeout``
        @throws {Exception} any other failure (e.g. reading/writing config.py) is logged
            and re-raised unchanged
        """
        try:
            self.logger.info(f"开始运行Cd预测脚本,模型类型: {model_type}")

            # Rewrites config.py in place so main.py only runs the requested model(s).
            self._modify_workflow_config(model_type)

            # cwd= scopes the working-directory change to the child process; os.chdir()
            # would change it for the whole process, including any other threads.
            result = subprocess.run(
                [sys.executable, "main.py"],
                cwd=self.cd_system_path,
                capture_output=True,
                text=True,
                timeout=timeout,
            )

            if result.returncode != 0:
                self.logger.error(f"Cd预测脚本执行失败: {result.stderr}")
                raise RuntimeError(f"预测脚本执行失败: {result.stderr}")

            self.logger.info("Cd预测脚本执行成功")

            output_info = self._get_output_files(model_type)
            return {
                "success": True,
                "model_type": model_type,
                "output_files": output_info,
                "stdout": result.stdout,
                "timestamp": datetime.now().isoformat(),
            }
        except subprocess.TimeoutExpired as exc:
            self.logger.error("Cd预测脚本执行超时")
            raise RuntimeError("预测脚本执行超时") from exc
        except Exception as exc:
            self.logger.error(f"运行Cd预测脚本失败: {exc}")
            raise

    def _modify_workflow_config(self, model_type: str) -> None:
        """
        Rewrite the WORKFLOW_CONFIG dict in config.py for the given model type.

        Only the two model switches depend on ``model_type``; the post-processing steps
        are always enabled. If no WORKFLOW_CONFIG literal is found, a warning is logged
        and the file is left untouched.

        @param {str} model_type - "crop", "effective" or "both" (unknown values -> "both")
        """
        config_file = os.path.join(self.cd_system_path, "config.py")

        flags = self._MODEL_FLAGS.get(model_type)
        if flags is None:
            self.logger.warning(f"未知的模型类型 {model_type!r},按 'both' 处理")
            flags = self._MODEL_FLAGS["both"]
        run_crop, run_effective = flags

        # Key order is preserved so the written text stays the same as before.
        workflow_config = {
            "run_crop_model": run_crop,
            "run_effective_model": run_effective,
            "combine_predictions": True,
            "generate_raster": True,
            "create_visualization": True,
            "create_histogram": True,
        }

        with open(config_file, "r", encoding="utf-8") as f:
            config_content = f.read()

        replacement = f"WORKFLOW_CONFIG = {workflow_config}"
        # A callable replacement is inserted verbatim; a string one would have its
        # backslash escapes and group references interpreted by re.
        new_content, n_replaced = self._WORKFLOW_CONFIG_RE.subn(
            lambda _match: replacement, config_content
        )

        if n_replaced == 0:
            self.logger.warning(f"配置文件中未找到 WORKFLOW_CONFIG,未修改: {config_file}")
            return

        if new_content != config_content:
            with open(config_file, "w", encoding="utf-8") as f:
                f.write(new_content)

        self.logger.info(f"已更新工作流配置为模型类型: {model_type}")

    def _output_dirs(self) -> Tuple[str, str]:
        """Return (figures_dir, raster_dir) under ``<cd_system_path>/output``."""
        output_dir = os.path.join(self.cd_system_path, "output")
        return os.path.join(output_dir, "figures"), os.path.join(output_dir, "raster")

    @classmethod
    def _is_map_file(cls, name: str) -> bool:
        """True for image files whose name contains both "Prediction" and "results"."""
        return name.endswith(cls._IMAGE_EXTENSIONS) and "Prediction" in name and "results" in name

    @classmethod
    def _is_histogram_file(cls, name: str) -> bool:
        """True for image files whose lower-cased name contains "frequency" or "histogram"."""
        lowered = name.lower()
        return name.endswith(cls._IMAGE_EXTENSIONS) and (
            "frequency" in lowered or "histogram" in lowered
        )

    @staticmethod
    def _is_raster_file(name: str) -> bool:
        """True for ``.tif`` files whose name starts with "output"."""
        return name.startswith("output") and name.endswith(".tif")

    @staticmethod
    def _list_matching(directory: str, predicate: Callable[[str], bool]) -> List[str]:
        """Return sorted full paths of entries in ``directory`` accepted by ``predicate``
        (empty list if the directory does not exist)."""
        if not os.path.isdir(directory):
            return []
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if predicate(name)
        ]

    @staticmethod
    def _newest(paths: List[str]) -> Optional[str]:
        """Return the path with the latest modification time, or None if there is none.

        getmtime is used because on Unix getctime is the inode-change time, not the
        creation time. Files that disappear before they can be stat'ed are skipped.
        """
        stamped = []
        for path in paths:
            try:
                stamped.append((os.path.getmtime(path), path))
            except OSError:
                continue
        return max(stamped)[1] if stamped else None

    def _get_output_files(self, model_type: str) -> Dict[str, Any]:
        """
        List the output files currently present.

        @param {str} model_type - currently unused; every matching file is listed
            regardless of which model produced it
        @returns {Dict[str, Any]} the two directory paths plus sorted lists of full paths
            under "maps", "histograms" and "rasters"
        """
        figures_dir, raster_dir = self._output_dirs()
        return {
            "figures_dir": figures_dir,
            "raster_dir": raster_dir,
            "maps": self._list_matching(figures_dir, self._is_map_file),
            # A figure matching both patterns is reported as a map only.
            "histograms": self._list_matching(
                figures_dir,
                lambda name: self._is_histogram_file(name) and not self._is_map_file(name),
            ),
            "rasters": self._list_matching(raster_dir, self._is_raster_file),
        }

    def get_latest_outputs(self, output_type: str = "all") -> Dict[str, Optional[str]]:
        """
        Get the most recently modified output file of each requested kind.

        @param {str} output_type - "maps", "histograms", "rasters" or "all"; any other
            value yields an empty dict
        @returns {Dict[str, Optional[str]]} keys among "latest_map", "latest_histogram",
            "latest_raster" mapped to a path or None; {} if an unexpected error occurs
            (the error is logged)
        """
        try:
            figures_dir, raster_dir = self._output_dirs()
            latest_files: Dict[str, Optional[str]] = {}

            if output_type in ("maps", "all"):
                latest_files["latest_map"] = self._newest(
                    self._list_matching(figures_dir, self._is_map_file)
                )

            if output_type in ("histograms", "all"):
                latest_files["latest_histogram"] = self._newest(
                    self._list_matching(figures_dir, self._is_histogram_file)
                )

            if output_type in ("rasters", "all"):
                latest_files["latest_raster"] = self._newest(
                    self._list_matching(raster_dir, self._is_raster_file)
                )

            return latest_files
        except Exception as exc:
            self.logger.error(f"获取最新输出文件失败: {exc}")
            return {}