123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- """
- Cd预测系统包装器
- @description: 封装对原始Cd预测系统的调用,简化集成过程
- @author: AcidMap Team
- @version: 1.0.0
- """
- import os
- import sys
- import logging
- import subprocess
- from typing import Dict, Any, Optional
- from datetime import datetime
class CdPredictionWrapper:
    """
    Wrapper around the standalone Cd prediction system.

    @description: Runs the original Cd prediction project (its ``main.py``,
        steered by the ``WORKFLOW_CONFIG`` dict in its ``config.py``) in a
        child process and collects the image/raster files found under
        ``<cd_system_path>/output``.
    @example
        wrapper = CdPredictionWrapper("/path/to/Cd_Prediction_System")
        result = wrapper.run_prediction_script("crop")
        latest = wrapper.get_latest_outputs("maps")
    """

    # Files/directories that must exist directly under the system root.
    REQUIRED_ENTRIES = ("main.py", "config.py", "models", "analysis")
    # Default wall-clock limit for a single run of main.py, in seconds.
    DEFAULT_TIMEOUT = 300

    def __init__(self, cd_system_path: str):
        """
        Initialise the wrapper and validate the target system layout.

        @param {str} cd_system_path - Root directory of the Cd prediction
            system. May be relative; it is resolved against the current
            working directory each time it is used (this class never calls
            os.chdir).
        @throws {FileNotFoundError} If the root or any REQUIRED_ENTRIES item
            is missing.
        """
        self.cd_system_path = cd_system_path
        self.logger = logging.getLogger(__name__)

        if not os.path.exists(cd_system_path):
            raise FileNotFoundError(f"Cd预测系统路径不存在: {cd_system_path}")

        self._validate_cd_system()

    def _validate_cd_system(self):
        """
        Check that every entry in REQUIRED_ENTRIES exists under the root.

        @throws {FileNotFoundError} Naming the first missing entry.
        """
        for entry in self.REQUIRED_ENTRIES:
            if not os.path.exists(os.path.join(self.cd_system_path, entry)):
                raise FileNotFoundError(f"Cd预测系统缺少必要文件: {entry}")

        self.logger.info("Cd预测系统验证通过")

    def run_prediction_script(
        self, model_type: str = "both", timeout: float = DEFAULT_TIMEOUT
    ) -> Dict[str, Any]:
        """
        Configure and run the Cd prediction system's main.py.

        @param {str} model_type - "crop", "effective" or "both"; any other
            value is treated like "both".
        @param {float} timeout - Seconds to wait for main.py before the child
            is killed (default 300).
        @returns {Dict[str, Any]} Keys: "success" (True), "model_type",
            "output_files" (see _get_output_files), "stdout" (captured text)
            and "timestamp" (ISO-8601 local time).
        @throws {Exception} If main.py exits non-zero or times out (the
            timeout error chains the original subprocess.TimeoutExpired).
            Errors from rewriting config.py or launching the process are
            logged and re-raised unchanged.
        """
        try:
            self.logger.info(f"开始运行Cd预测脚本,模型类型: {model_type}")

            # Persistently rewrites config.py so main.py runs only the
            # requested model(s); the change is not reverted afterwards.
            self._modify_workflow_config(model_type)

            # Run main.py inside the system root via cwd= instead of
            # os.chdir(): chdir is process-wide (racing with other threads),
            # and with a relative cd_system_path it made every later
            # os.path.join(self.cd_system_path, ...) point at the wrong place.
            result = subprocess.run(
                [sys.executable, "main.py"],
                cwd=self.cd_system_path,
                capture_output=True,
                text=True,
                timeout=timeout,
            )

            if result.returncode != 0:
                self.logger.error(f"Cd预测脚本执行失败: {result.stderr}")
                raise Exception(f"预测脚本执行失败: {result.stderr}")

            self.logger.info("Cd预测脚本执行成功")

            return {
                "success": True,
                "model_type": model_type,
                "output_files": self._get_output_files(model_type),
                "stdout": result.stdout,
                "timestamp": datetime.now().isoformat(),
            }

        except subprocess.TimeoutExpired as exc:
            self.logger.error("Cd预测脚本执行超时")
            raise Exception("预测脚本执行超时") from exc
        except Exception as exc:
            self.logger.error(f"运行Cd预测脚本失败: {str(exc)}")
            raise

    def _modify_workflow_config(self, model_type: str):
        """
        Rewrite the WORKFLOW_CONFIG assignment in the system's config.py.

        The new value is written to disk and is not restored afterwards.

        @param {str} model_type - "crop" enables only the crop model,
            "effective" only the effective model; anything else enables both.
        """
        import re

        if model_type not in ("crop", "effective", "both"):
            self.logger.warning(f"未知的模型类型: {model_type},将同时运行两个模型")

        workflow_config = {
            "run_crop_model": model_type != "effective",
            "run_effective_model": model_type != "crop",
            "combine_predictions": True,
            "generate_raster": True,
            "create_visualization": True,
            "create_histogram": True,
        }

        config_file = os.path.join(self.cd_system_path, "config.py")
        with open(config_file, "r", encoding="utf-8") as f:
            config_content = f.read()

        # [^}]* already spans newlines, so no re flags are needed. Caveat: it
        # stops at the first "}", so a WORKFLOW_CONFIG containing nested
        # braces would only be partially replaced.
        pattern = r"WORKFLOW_CONFIG\s*=\s*\{[^}]*\}"
        replacement = f"WORKFLOW_CONFIG = {workflow_config}"
        # A callable replacement is inserted literally; a plain string would
        # have backslash escapes and group references interpreted by re.
        new_content, count = re.subn(pattern, lambda _match: replacement, config_content)

        if count == 0:
            self.logger.warning(f"config.py 中未找到 WORKFLOW_CONFIG,配置未修改: {config_file}")
            return

        with open(config_file, "w", encoding="utf-8") as f:
            f.write(new_content)

        self._discard_config_bytecode()
        self.logger.info(f"已更新工作流配置为模型类型: {model_type}")

    def _discard_config_bytecode(self):
        """
        Delete cached config.*.pyc files under the system's __pycache__.

        Timestamp-based .pyc files are validated against the source's mtime
        in whole seconds plus its size. The crop and effective configs have
        identical length, so without this a rewrite within the same second
        could let the child import a stale config (assuming main.py imports
        config.py as a module). Best effort: failures are ignored.
        """
        pycache_dir = os.path.join(self.cd_system_path, "__pycache__")
        if not os.path.isdir(pycache_dir):
            return
        for name in os.listdir(pycache_dir):
            if name.startswith("config.") and name.endswith(".pyc"):
                try:
                    os.remove(os.path.join(pycache_dir, name))
                except OSError:
                    # Deliberately best-effort; a leftover .pyc only matters
                    # in the same-second/same-size case described above.
                    pass

    @staticmethod
    def _is_map_file(name: str) -> bool:
        """Return True for .jpg/.png names containing both "Prediction" and "results"."""
        return name.endswith((".jpg", ".png")) and "Prediction" in name and "results" in name

    @staticmethod
    def _is_histogram_file(name: str) -> bool:
        """Return True for .jpg/.png names containing "frequency" or "histogram" (case-insensitive)."""
        lowered = name.lower()
        return name.endswith((".jpg", ".png")) and ("frequency" in lowered or "histogram" in lowered)

    @staticmethod
    def _is_raster_file(name: str) -> bool:
        """Return True for names starting with "output" and ending with ".tif"."""
        return name.startswith("output") and name.endswith(".tif")

    @staticmethod
    def _list_files(directory: str, predicate) -> list:
        """
        List paths in *directory* whose base name satisfies *predicate*.

        @returns {list} Joined paths sorted by name; empty when *directory*
            is not an existing directory.
        """
        if not os.path.isdir(directory):
            return []
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if predicate(name)
        ]

    @staticmethod
    def _newest(paths: list) -> Optional[str]:
        """
        Return the most recently modified path, or None for an empty list.

        Uses mtime: os.path.getctime is the metadata-change time on Unix and
        the creation time on Windows, so it does not track the last write of
        a regenerated file.
        """
        return max(paths, key=os.path.getmtime) if paths else None

    def _get_output_files(self, model_type: str) -> Dict[str, Any]:
        """
        Collect the output files currently present in the system's output dir.

        @param {str} model_type - Accepted for interface compatibility; the
            listing is not filtered by model type.
        @returns {Dict[str, Any]} "figures_dir" and "raster_dir" paths plus
            sorted path lists "maps", "histograms" and "rasters". An image
            that qualifies as both a map and a histogram is listed only
            under "maps".
        """
        output_dir = os.path.join(self.cd_system_path, "output")
        figures_dir = os.path.join(output_dir, "figures")
        raster_dir = os.path.join(output_dir, "raster")

        maps = []
        histograms = []
        for path in self._list_files(figures_dir, lambda _name: True):
            name = os.path.basename(path)
            if self._is_map_file(name):
                maps.append(path)
            elif self._is_histogram_file(name):
                histograms.append(path)

        return {
            "figures_dir": figures_dir,
            "raster_dir": raster_dir,
            "maps": maps,
            "histograms": histograms,
            "rasters": self._list_files(raster_dir, self._is_raster_file),
        }

    def get_latest_outputs(self, output_type: str = "all") -> Dict[str, Optional[str]]:
        """
        Return the most recently modified output file of each requested kind.

        @param {str} output_type - "maps", "histograms", "rasters" or "all".
        @returns {Dict[str, Optional[str]]} Subset of the keys "latest_map",
            "latest_histogram" and "latest_raster" matching *output_type*,
            each a path or None when no file matches. Returns {} (after
            logging) if any error occurs while scanning.
        """
        try:
            output_dir = os.path.join(self.cd_system_path, "output")
            figures_dir = os.path.join(output_dir, "figures")
            raster_dir = os.path.join(output_dir, "raster")

            latest_files: Dict[str, Optional[str]] = {}

            if output_type in ("maps", "all"):
                latest_files["latest_map"] = self._newest(
                    self._list_files(figures_dir, self._is_map_file)
                )

            if output_type in ("histograms", "all"):
                latest_files["latest_histogram"] = self._newest(
                    self._list_files(figures_dir, self._is_histogram_file)
                )

            if output_type in ("rasters", "all"):
                latest_files["latest_raster"] = self._newest(
                    self._list_files(raster_dir, self._is_raster_file)
                )

            return latest_files

        except Exception as e:
            self.logger.error(f"获取最新输出文件失败: {str(e)}")
            return {}
|