Browse Source

更新工作流配置管理,支持从环境变量获取配置;删除不再使用的setup_data.py脚本;更新相关代码以适应新配置方式。

drggboy 1 week ago
parent
commit
0599214c90

+ 1 - 0
.gitignore

@@ -13,3 +13,4 @@ myenv/
 .vscode/
 *.log
 Cd_Prediction_Integrated_System/output/raster/meanTemp.tif.aux.xml
+config.env

+ 3 - 21
Cd_Prediction_Integrated_System/analysis/data_processing.py

@@ -88,29 +88,11 @@ class DataProcessor:
         @return: 当前的工作流配置字典
         """
         try:
-            config_file = os.path.join(config.PROJECT_ROOT, "config.py")
-            
-            # 读取配置文件内容
-            with open(config_file, 'r', encoding='utf-8') as f:
-                config_content = f.read()
-            
-            # 提取WORKFLOW_CONFIG
-            import re
-            pattern = r'WORKFLOW_CONFIG\s*=\s*(\{[^}]*\})'
-            match = re.search(pattern, config_content)
-            
-            if match:
-                # 使用eval安全地解析配置(这里是安全的,因为我们控制配置文件内容)
-                workflow_config_str = match.group(1)
-                workflow_config = eval(workflow_config_str)
-                return workflow_config
-            else:
-                self.logger.warning("无法从配置文件中提取WORKFLOW_CONFIG,使用默认配置")
-                return config.WORKFLOW_CONFIG
-                
+            # 使用新的配置获取方法,优先从环境变量获取
+            return config.get_workflow_config()
         except Exception as e:
             self.logger.error(f"读取工作流配置失败: {str(e)},使用默认配置")
-            return config.WORKFLOW_CONFIG
+            return config._DEFAULT_WORKFLOW_CONFIG.copy()
     
     def load_coordinates(self):
         """

+ 20 - 1
Cd_Prediction_Integrated_System/config.py

@@ -74,7 +74,26 @@ VISUALIZATION_CONFIG = {
 }
 
 # 执行流程配置
-WORKFLOW_CONFIG = {'run_crop_model': False, 'run_effective_model': True, 'combine_predictions': True, 'generate_raster': True, 'create_visualization': True, 'create_histogram': True}
+_DEFAULT_WORKFLOW_CONFIG = {'run_crop_model': True, 'run_effective_model': False, 'combine_predictions': True, 'generate_raster': True, 'create_visualization': True, 'create_histogram': True}
+
+def get_workflow_config():
+    """
+    Return the effective workflow configuration.
+
+    The CD_WORKFLOW_CONFIG environment variable, when set, must hold a JSON
+    object. Its entries are overlaid on a copy of the defaults, so a partial
+    override still yields every default key. The defaults are returned
+    unchanged (as a fresh copy) when the variable is unset or empty, is not
+    valid JSON, or does not decode to a JSON object; the last two cases are
+    logged as warnings instead of being silently ignored.
+
+    @returns {dict} workflow configuration dictionary (always a new dict)
+    """
+    import json
+    import logging
+
+    # Start from a copy so callers can mutate the result without touching
+    # the module-level defaults.
+    workflow_config = _DEFAULT_WORKFLOW_CONFIG.copy()
+
+    env_config = os.environ.get('CD_WORKFLOW_CONFIG')
+    if not env_config:
+        return workflow_config
+
+    log = logging.getLogger(__name__)
+    try:
+        overrides = json.loads(env_config)
+    except json.JSONDecodeError as exc:
+        log.warning(
+            "CD_WORKFLOW_CONFIG is not valid JSON (%s); using default workflow config",
+            exc,
+        )
+        return workflow_config
+
+    # json.loads also accepts lists, strings and numbers; only an object can
+    # be used as a key -> flag mapping.
+    if not isinstance(overrides, dict):
+        log.warning(
+            "CD_WORKFLOW_CONFIG must be a JSON object, got %s; using default workflow config",
+            type(overrides).__name__,
+        )
+        return workflow_config
+
+    # Overlay on the defaults so keys missing from the override keep their
+    # default value instead of raising KeyError in callers.
+    workflow_config.update(overrides)
+    return workflow_config
+
+# 为了向后兼容,保留WORKFLOW_CONFIG变量(使用默认配置)
+WORKFLOW_CONFIG = _DEFAULT_WORKFLOW_CONFIG.copy()
 
 def ensure_directories():
     """确保所有必要的目录存在"""

+ 12 - 8
Cd_Prediction_Integrated_System/main.py

@@ -45,22 +45,26 @@ def main():
         config.ensure_directories()
         logger.info("项目目录结构检查完成")
         
+        # 获取当前工作流配置(支持运行时动态配置)
+        workflow_config = config.get_workflow_config()
+        logger.info(f"当前工作流配置: {workflow_config}")
+        
         # 步骤1: 运行作物Cd模型预测
-        if config.WORKFLOW_CONFIG["run_crop_model"]:
+        if workflow_config["run_crop_model"]:
             logger.info("步骤1: 运行作物Cd模型预测...")
             crop_predictor = CropCdPredictor()
             crop_output = crop_predictor.predict()
             logger.info(f"作物Cd模型预测完成,输出文件: {crop_output}")
         
         # 步骤2: 运行有效态Cd模型预测  
-        if config.WORKFLOW_CONFIG["run_effective_model"]:
+        if workflow_config["run_effective_model"]:
             logger.info("步骤2: 运行有效态Cd模型预测...")
             effective_predictor = EffectiveCdPredictor()
             effective_output = effective_predictor.predict()
             logger.info(f"有效态Cd模型预测完成,输出文件: {effective_output}")
         
         # 步骤3: 数据整合处理
-        if config.WORKFLOW_CONFIG["combine_predictions"]:
+        if workflow_config["combine_predictions"]:
             logger.info("步骤3: 整合预测结果与坐标数据...")
             data_processor = DataProcessor()
             final_data_files = data_processor.combine_predictions_with_coordinates()
@@ -77,10 +81,10 @@ def main():
                 continue
             
             # 只为当前工作流配置中启用的模型生成输出文件
-            if model_name == 'crop_cd' and not config.WORKFLOW_CONFIG.get("run_crop_model", False):
+            if model_name == 'crop_cd' and not workflow_config.get("run_crop_model", False):
                 logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
                 continue
-            if model_name == 'effective_cd' and not config.WORKFLOW_CONFIG.get("run_effective_model", False):
+            if model_name == 'effective_cd' and not workflow_config.get("run_effective_model", False):
                 logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
                 continue
                 
@@ -100,7 +104,7 @@ def main():
             model_outputs[model_name] = {}
             
             # 步骤4a: 生成栅格文件(直接调用通用绘图模块)
-            if config.WORKFLOW_CONFIG["generate_raster"]:
+            if workflow_config["generate_raster"]:
                 logger.info(f"步骤4a: 为{display_name}模型生成栅格文件...")
                 
                 # 使用csv_to_raster_workflow直接生成栅格
@@ -156,7 +160,7 @@ def main():
                     raise
             
             # 步骤4b: 创建可视化地图(直接调用通用绘图模块)
-            if config.WORKFLOW_CONFIG["create_visualization"]:
+            if workflow_config["create_visualization"]:
                 logger.info(f"步骤4b: 为{display_name}模型创建可视化地图...")
                 
                 # 为每个模型创建独立的地图输出路径
@@ -188,7 +192,7 @@ def main():
                     raise
             
             # 步骤4c: 创建直方图(直接调用通用绘图模块)
-            if config.WORKFLOW_CONFIG["create_histogram"]:
+            if workflow_config["create_histogram"]:
                 logger.info(f"步骤4c: 为{display_name}模型创建预测值分布直方图...")
                 
                 # 为每个模型创建独立的直方图输出路径

+ 0 - 236
Cd_Prediction_Integrated_System/setup_data.py

@@ -1,236 +0,0 @@
-"""
-数据迁移脚本
-Data Migration Script
-
-用于将现有的三个项目文件夹中的数据和模型文件复制到新的集成项目结构中
-"""
-
-import os
-import shutil
-import logging
-from pathlib import Path
-
-def setup_logging():
-    """设置日志"""
-    logging.basicConfig(
-        level=logging.INFO,
-        format='%(asctime)s - %(levelname)s - %(message)s',
-        handlers=[
-            logging.FileHandler('setup_data.log', encoding='utf-8'),
-            logging.StreamHandler()
-        ]
-    )
-
-def copy_file_safe(src, dst):
-    """
-    安全复制文件
-    
-    @param src: 源文件路径
-    @param dst: 目标文件路径
-    """
-    try:
-        # 确保目标目录存在
-        os.makedirs(os.path.dirname(dst), exist_ok=True)
-        
-        if os.path.exists(src):
-            shutil.copy2(src, dst)
-            logging.info(f"复制成功: {src} -> {dst}")
-            return True
-        else:
-            logging.warning(f"源文件不存在: {src}")
-            return False
-    except Exception as e:
-        logging.error(f"复制失败: {src} -> {dst}, 错误: {str(e)}")
-        return False
-
-def setup_crop_cd_model():
-    """设置作物Cd模型文件"""
-    logging.info("=" * 50)
-    logging.info("设置作物Cd模型文件...")
-    
-    # 源目录
-    src_dir = "../作物Cd模型文件与数据/作物Cd模型文件与数据"
-    
-    # 目标目录
-    model_files_dir = "models/crop_cd_model/model_files"
-    data_dir = "models/crop_cd_model/data"
-    
-    # 复制模型文件
-    model_files = [
-        "cropCdNN.pth",
-        "cropCd_mean.npy",
-        "cropCd_scale.npy",
-        "constrained_nn6.py"
-    ]
-    
-    for file in model_files:
-        src = os.path.join(src_dir, file)
-        dst = os.path.join(model_files_dir, file)
-        copy_file_safe(src, dst)
-    
-    # 复制数据文件
-    data_files = [
-        "areatest.csv"
-    ]
-    
-    for file in data_files:
-        src = os.path.join(src_dir, file)
-        dst = os.path.join(data_dir, file)
-        copy_file_safe(src, dst)
-    
-    # 复制坐标文件到共享数据目录
-    coord_src = os.path.join(src_dir, "坐标.csv")
-    coord_dst = "data/coordinates/坐标.csv"
-    copy_file_safe(coord_src, coord_dst)
-
-def setup_effective_cd_model():
-    """设置有效态Cd模型文件"""
-    logging.info("=" * 50)
-    logging.info("设置有效态Cd模型文件...")
-    
-    # 源目录
-    src_dir = "../有效态Cd模型文件与数据/有效态Cd模型文件与数据"
-    
-    # 目标目录
-    model_files_dir = "models/effective_cd_model/model_files"
-    data_dir = "models/effective_cd_model/data"
-    
-    # 复制模型文件
-    model_files = [
-        "EffCdNN6C.pth",
-        "EffCd_mean.npy",
-        "EffCd_scale.npy",
-        "constrained_nn6C.py"
-    ]
-    
-    for file in model_files:
-        src = os.path.join(src_dir, file)
-        dst = os.path.join(model_files_dir, file)
-        copy_file_safe(src, dst)
-    
-    # 复制数据文件
-    data_files = [
-        "areatest.csv"
-    ]
-    
-    for file in data_files:
-        src = os.path.join(src_dir, file)
-        dst = os.path.join(data_dir, file)
-        copy_file_safe(src, dst)
-
-def setup_irrigation_water_files():
-    """设置灌溉水项目文件"""
-    logging.info("=" * 50)
-    logging.info("设置灌溉水项目文件...")
-    
-    # 源目录
-    src_dir = "../Irrigation_Water/Irrigation_Water"
-    
-    # 复制栅格文件
-    raster_files = [
-        "Raster/meanTemp.tif",
-        "Raster/lechang.shp",
-        "Raster/lechang.shx",
-        "Raster/lechang.dbf",
-        "Raster/lechang.prj"
-    ]
-    
-    for file in raster_files:
-        src = os.path.join(src_dir, file)
-        dst = os.path.join("output/raster", os.path.basename(file))
-        copy_file_safe(src, dst)
-    
-    # 复制示例数据文件
-    data_files = [
-        "Data/Final_predictions.csv"
-    ]
-    
-    for file in data_files:
-        src = os.path.join(src_dir, file)
-        dst = os.path.join("data/final", os.path.basename(file))
-        copy_file_safe(src, dst)
-
-def create_directory_structure():
-    """创建目录结构"""
-    logging.info("=" * 50)
-    logging.info("创建目录结构...")
-    
-    directories = [
-        "models/crop_cd_model/model_files",
-        "models/crop_cd_model/data",
-        "models/effective_cd_model/model_files", 
-        "models/effective_cd_model/data",
-        "data/coordinates",
-        "data/predictions",
-        "data/final",
-        "output/raster",
-        "output/figures",
-        "output/reports",
-        "analysis",
-        "utils"
-    ]
-    
-    for directory in directories:
-        os.makedirs(directory, exist_ok=True)
-        logging.info(f"创建目录: {directory}")
-
-def create_requirements_txt():
-    """创建requirements.txt文件"""
-    logging.info("=" * 50)
-    logging.info("创建requirements.txt文件...")
-    
-    requirements = """# Cd预测集成系统依赖包
-numpy>=1.21.0
-pandas>=1.3.0
-torch>=1.9.0
-scikit-learn>=1.0.0
-geopandas>=0.10.0
-rasterio>=1.2.0
-matplotlib>=3.4.0
-seaborn>=0.11.0
-shapely>=1.7.0
-"""
-    
-    with open("requirements.txt", "w", encoding="utf-8") as f:
-        f.write(requirements)
-    
-    logging.info("requirements.txt 创建完成")
-
-def main():
-    """主函数"""
-    setup_logging()
-    
-    logging.info("开始设置Cd预测集成系统...")
-    logging.info("=" * 60)
-    
-    try:
-        # 创建目录结构
-        create_directory_structure()
-        
-        # 设置各个模型的文件
-        setup_crop_cd_model()
-        setup_effective_cd_model()
-        setup_irrigation_water_files()
-        
-        # 创建requirements.txt
-        create_requirements_txt()
-        
-        logging.info("=" * 60)
-        logging.info("🎉 Cd预测集成系统设置完成!")
-        logging.info("=" * 60)
-        
-        print("\n" + "=" * 60)
-        print("🎉 数据迁移完成!")
-        print("=" * 60)
-        print("下一步操作:")
-        print("1. 安装依赖包:pip install -r requirements.txt")
-        print("2. 运行主程序:python main.py")
-        print("3. 查看输出结果在 output/ 目录中")
-        print("=" * 60)
-        
-    except Exception as e:
-        logging.error(f"设置过程中发生错误: {str(e)}")
-        raise
-
-if __name__ == "__main__":
-    main() 

+ 13 - 22
app/utils/cd_prediction_wrapper.py

@@ -75,8 +75,10 @@ class CdPredictionWrapper:
             os.chdir(self.cd_system_path)
             
             try:
-                # 修改配置文件以只运行指定模型
-                self._modify_workflow_config(model_type)
+                # 通过环境变量传递工作流配置,避免修改配置文件
+                workflow_config = self._get_workflow_config(model_type)
+                import json
+                os.environ['CD_WORKFLOW_CONFIG'] = json.dumps(workflow_config)
                 
                 # 运行主脚本
                 result = subprocess.run(
@@ -107,6 +109,10 @@ class CdPredictionWrapper:
                 # 恢复原始工作目录
                 os.chdir(original_cwd)
                 
+                # 清理环境变量
+                if 'CD_WORKFLOW_CONFIG' in os.environ:
+                    del os.environ['CD_WORKFLOW_CONFIG']
+                
         except subprocess.TimeoutExpired:
             self.logger.error("Cd预测脚本执行超时")
             raise Exception("预测脚本执行超时")
@@ -114,14 +120,13 @@ class CdPredictionWrapper:
             self.logger.error(f"运行Cd预测脚本失败: {str(e)}")
             raise
     
-    def _modify_workflow_config(self, model_type: str):
+    def _get_workflow_config(self, model_type: str) -> dict:
         """
-        修改工作流配置
+        获取工作流配置
         
         @param {str} model_type - 模型类型
+        @returns {dict} 工作流配置字典
         """
-        config_file = os.path.join(self.cd_system_path, "config.py")
-        
         # 根据模型类型设置配置
         if model_type == "crop":
             workflow_config = {
@@ -151,22 +156,8 @@ class CdPredictionWrapper:
                 "create_histogram": True
             }
         
-        # 读取当前配置文件
-        with open(config_file, 'r', encoding='utf-8') as f:
-            config_content = f.read()
-        
-        # 替换 WORKFLOW_CONFIG
-        import re
-        pattern = r'WORKFLOW_CONFIG\s*=\s*\{[^}]*\}'
-        replacement = f"WORKFLOW_CONFIG = {workflow_config}"
-        
-        new_content = re.sub(pattern, replacement, config_content, flags=re.MULTILINE | re.DOTALL)
-        
-        # 写回文件
-        with open(config_file, 'w', encoding='utf-8') as f:
-            f.write(new_content)
-        
-        self.logger.info(f"已更新工作流配置为模型类型: {model_type}")
+        self.logger.info(f"生成工作流配置,模型类型: {model_type}")
+        return workflow_config
     
     def _get_output_files(self, model_type: str) -> Dict[str, Any]:
         """

+ 0 - 5
config.env

@@ -1,5 +0,0 @@
-DB_HOST=localhost
-DB_PORT=5432
-DB_NAME=data_db
-DB_USER=postgres
-DB_PASSWORD=scau2025