浏览代码

实现分辨率调整;解决调用模型与绘图标题对应不上问题

drggboy 1 天之前
父节点
当前提交
21ea6ab1a0

+ 66 - 15
Cd_Prediction_Integrated_System/analysis/data_processing.py

@@ -30,30 +30,50 @@ class DataProcessor:
         
     def load_predictions(self):
         """
-        加载模型预测结果
+        加载模型预测结果(根据WORKFLOW_CONFIG配置)
         
         @return: 包含预测结果的字典
         """
         try:
             predictions = {}
             
+            # 动态读取当前的工作流配置(运行时可能被修改)
+            workflow_config = self._get_current_workflow_config()
+            self.logger.info(f"当前工作流配置: {workflow_config}")
+            
+            # 只加载在工作流配置中启用的模型的预测结果
             # 加载作物Cd预测结果
-            crop_cd_path = os.path.join(
-                config.DATA_PATHS["predictions_dir"],
-                config.CROP_CD_MODEL["output_file"]
-            )
-            if os.path.exists(crop_cd_path):
-                predictions['crop_cd'] = pd.read_csv(crop_cd_path)
-                self.logger.info(f"作物Cd预测结果加载成功: {crop_cd_path}")
+            if workflow_config.get("run_crop_model", False):
+                crop_cd_path = os.path.join(
+                    config.DATA_PATHS["predictions_dir"],
+                    config.CROP_CD_MODEL["output_file"]
+                )
+                if os.path.exists(crop_cd_path):
+                    predictions['crop_cd'] = pd.read_csv(crop_cd_path)
+                    self.logger.info(f"作物Cd预测结果加载成功: {crop_cd_path}")
+                else:
+                    self.logger.warning(f"作物Cd预测文件不存在: {crop_cd_path}")
+            else:
+                self.logger.info("跳过作物Cd预测结果加载(工作流配置中未启用)")
             
             # 加载有效态Cd预测结果
-            effective_cd_path = os.path.join(
-                config.DATA_PATHS["predictions_dir"],
-                config.EFFECTIVE_CD_MODEL["output_file"]
-            )
-            if os.path.exists(effective_cd_path):
-                predictions['effective_cd'] = pd.read_csv(effective_cd_path)
-                self.logger.info(f"有效态Cd预测结果加载成功: {effective_cd_path}")
+            if workflow_config.get("run_effective_model", False):
+                effective_cd_path = os.path.join(
+                    config.DATA_PATHS["predictions_dir"],
+                    config.EFFECTIVE_CD_MODEL["output_file"]
+                )
+                if os.path.exists(effective_cd_path):
+                    predictions['effective_cd'] = pd.read_csv(effective_cd_path)
+                    self.logger.info(f"有效态Cd预测结果加载成功: {effective_cd_path}")
+                else:
+                    self.logger.warning(f"有效态Cd预测文件不存在: {effective_cd_path}")
+            else:
+                self.logger.info("跳过有效态Cd预测结果加载(工作流配置中未启用)")
+            
+            if not predictions:
+                self.logger.warning("没有加载到任何预测结果,请检查WORKFLOW_CONFIG配置和预测文件是否存在")
+            else:
+                self.logger.info(f"根据工作流配置,成功加载了 {len(predictions)} 个模型的预测结果: {list(predictions.keys())}")
             
             return predictions
             
@@ -61,6 +81,37 @@ class DataProcessor:
             self.logger.error(f"预测结果加载失败: {str(e)}")
             raise
     
+    def _get_current_workflow_config(self):
+        """
+        动态读取当前的工作流配置
+        
+        @return: 当前的工作流配置字典
+        """
+        try:
+            config_file = os.path.join(config.PROJECT_ROOT, "config.py")
+            
+            # 读取配置文件内容
+            with open(config_file, 'r', encoding='utf-8') as f:
+                config_content = f.read()
+            
+            # 提取WORKFLOW_CONFIG
+            import re
+            pattern = r'WORKFLOW_CONFIG\s*=\s*(\{[^}]*\})'
+            match = re.search(pattern, config_content)
+            
+            if match:
+                # 使用eval安全地解析配置(这里是安全的,因为我们控制配置文件内容)
+                workflow_config_str = match.group(1)
+                workflow_config = eval(workflow_config_str)
+                return workflow_config
+            else:
+                self.logger.warning("无法从配置文件中提取WORKFLOW_CONFIG,使用默认配置")
+                return config.WORKFLOW_CONFIG
+                
+        except Exception as e:
+            self.logger.error(f"读取工作流配置失败: {str(e)},使用默认配置")
+            return config.WORKFLOW_CONFIG
+    
     def load_coordinates(self):
         """
         加载坐标数据

+ 12 - 4
Cd_Prediction_Integrated_System/analysis/visualization.py

@@ -122,7 +122,8 @@ class Visualizer:
                          color_map_name=None,
                          title_name="Prediction Cd",
                          output_path=None,
-                         output_size=None):
+                         output_size=None,
+                         high_res=False):
         """
         创建栅格地图
         
@@ -132,6 +133,7 @@ class Visualizer:
         @param title_name: 输出数据的图的名称
         @param output_path: 输出保存的图片的路径
         @param output_size: 图片尺寸
+        @param high_res: 是否使用高分辨率输出(默认False,DPI=300)
         @return: 输出图片路径
         """
         try:
@@ -222,7 +224,9 @@ class Visualizer:
             
             # 保存图片
             output_file = f"{output_path}.jpg"
-            plt.savefig(output_file, dpi=config.VISUALIZATION_CONFIG["dpi"])
+            # 根据high_res参数决定使用的DPI
+            output_dpi = 600 if high_res else config.VISUALIZATION_CONFIG["dpi"]
+            plt.savefig(output_file, dpi=output_dpi, format='jpg', bbox_inches='tight')
             plt.close()
             
             self.logger.info(f"栅格地图创建成功: {output_file}")
@@ -238,7 +242,8 @@ class Visualizer:
                         xlabel='Cd content',
                         ylabel='Frequency',
                         title='County level Cd Frequency',
-                        save_path=None):
+                        save_path=None,
+                        high_res=False):
         """
         绘制GeoTIFF文件的直方图
         
@@ -248,6 +253,7 @@ class Visualizer:
         @param ylabel: 纵坐标标签
         @param title: 图标题
         @param save_path: 可选,保存图片路径(含文件名和扩展名)
+        @param high_res: 是否使用高分辨率输出(默认False,DPI=300)
         @return: 输出图片路径
         """
         try:
@@ -303,7 +309,9 @@ class Visualizer:
             os.makedirs(os.path.dirname(save_path), exist_ok=True)
             
             # 保存图片
-            plt.savefig(save_path, dpi=config.VISUALIZATION_CONFIG["dpi"], 
+            # 根据high_res参数决定使用的DPI
+            output_dpi = 600 if high_res else config.VISUALIZATION_CONFIG["dpi"]
+            plt.savefig(save_path, dpi=output_dpi, 
                        format='jpg', bbox_inches='tight')
             plt.close()
             

+ 1 - 1
Cd_Prediction_Integrated_System/config.py

@@ -70,7 +70,7 @@ VISUALIZATION_CONFIG = {
     },
     "default_colormap": "colormap6",
     "figure_size": 12,
-    "dpi": 300
+    "dpi": 600
 }
 
 # 执行流程配置

+ 23 - 0
PROJECT_RULES.md

@@ -198,6 +198,29 @@ from .services import RasterService
 
 #### 3.3.3 迁移管理
 - 使用Alembic进行数据库迁移
+
+### 3.4 可视化规范 ✨新增
+
+#### 3.4.1 分辨率设置
+- **默认DPI**: 300 (标准分辨率输出)
+- **高分辨率DPI**: 600 (高质量输出) 
+- **高分辨率模式**: 支持通过`high_res=True`参数启用600 DPI输出
+- **图像格式**: JPG格式,支持bbox_inches='tight'自动裁切
+
+#### 3.4.2 分辨率配置层级
+1. **配置文件级别**: `config.VISUALIZATION_CONFIG["dpi"] = 300`
+2. **matplotlib全局设置**: `plt.rcParams['savefig.dpi'] = 300`
+3. **方法参数级别**: `high_res`参数可动态控制输出DPI
+
+#### 3.4.3 可视化方法接口
+- **栅格地图**: `create_raster_map(high_res=False)` - 默认300 DPI标准分辨率输出
+- **直方图**: `create_histogram(high_res=False)` - 默认300 DPI标准分辨率输出
+- **图像尺寸**: 可通过`figure_size`和`figsize`参数调整
+
+#### 3.4.4 字体配置
+- 优先使用Windows系统中文字体: Microsoft YaHei, SimHei等
+- 支持跨平台字体回退机制
+- 自动字体缓存重建解决字体问题
 - 保持迁移文件版本控制
 - 提供数据库重置脚本
 

+ 8 - 8
app/services/cd_prediction_service.py

@@ -510,8 +510,8 @@ class CdPredictionService:
             self.logger.info(f"为{county_name}执行作物Cd预测")
             prediction_result = self.wrapper.run_prediction_script("crop")
             
-            # 获取输出文件
-            latest_outputs = self.wrapper.get_latest_outputs("all")
+            # 获取输出文件(指定作物Cd模型类型)
+            latest_outputs = self.wrapper.get_latest_outputs("all", "crop")
             
             # 复制文件到API输出目录
             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
@@ -549,8 +549,8 @@ class CdPredictionService:
             self.logger.info(f"为{county_name}执行有效态Cd预测")
             prediction_result = self.wrapper.run_prediction_script("effective")
             
-            # 获取输出文件
-            latest_outputs = self.wrapper.get_latest_outputs("all")
+            # 获取输出文件(指定有效态Cd模型类型)
+            latest_outputs = self.wrapper.get_latest_outputs("all", "effective")
             
             # 复制文件到API输出目录
             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
@@ -640,8 +640,8 @@ class CdPredictionService:
             self.logger.info("执行作物Cd预测")
             prediction_result = self.wrapper.run_prediction_script("crop")
             
-            # 获取输出文件
-            latest_outputs = self.wrapper.get_latest_outputs("all")
+            # 获取输出文件(指定作物Cd模型类型)
+            latest_outputs = self.wrapper.get_latest_outputs("all", "crop")
             
             # 复制文件到API输出目录
             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
@@ -675,8 +675,8 @@ class CdPredictionService:
             self.logger.info("执行有效态Cd预测")
             prediction_result = self.wrapper.run_prediction_script("effective")
             
-            # 获取输出文件
-            latest_outputs = self.wrapper.get_latest_outputs("all")
+            # 获取输出文件(指定有效态Cd模型类型)
+            latest_outputs = self.wrapper.get_latest_outputs("all", "effective")
             
             # 复制文件到API输出目录
             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

+ 19 - 5
app/utils/cd_prediction_wrapper.py

@@ -9,6 +9,7 @@ import os
 import sys
 import logging
 import subprocess
+import time
 from typing import Dict, Any, Optional
 from datetime import datetime
 
@@ -203,11 +204,12 @@ class CdPredictionWrapper:
         
         return output_files
     
-    def get_latest_outputs(self, output_type: str = "all") -> Dict[str, Optional[str]]:
+    def get_latest_outputs(self, output_type: str = "all", model_type: str = None) -> Dict[str, Optional[str]]:
         """
         获取最新的输出文件
         
         @param {str} output_type - 输出类型 ("maps", "histograms", "rasters", "all")
+        @param {str} model_type - 模型类型 ("crop", "effective", None为获取所有)
         @returns {Dict[str, Optional[str]]} 最新输出文件路径
         """
         try:
@@ -223,7 +225,8 @@ class CdPredictionWrapper:
                 if os.path.exists(figures_dir):
                     for file in os.listdir(figures_dir):
                         if "Prediction" in file and "results" in file and file.endswith(('.jpg', '.png')):
-                            map_files.append(os.path.join(figures_dir, file))
+                            file_path = os.path.join(figures_dir, file)
+                            map_files.append(file_path)
                 
                 latest_files["latest_map"] = max(map_files, key=os.path.getctime) if map_files else None
             
@@ -233,7 +236,8 @@ class CdPredictionWrapper:
                 if os.path.exists(figures_dir):
                     for file in os.listdir(figures_dir):
                         if ("frequency" in file.lower() or "histogram" in file.lower()) and file.endswith(('.jpg', '.png')):
-                            histogram_files.append(os.path.join(figures_dir, file))
+                            file_path = os.path.join(figures_dir, file)
+                            histogram_files.append(file_path)
                 
                 latest_files["latest_histogram"] = max(histogram_files, key=os.path.getctime) if histogram_files else None
             
@@ -243,12 +247,22 @@ class CdPredictionWrapper:
                 if os.path.exists(raster_dir):
                     for file in os.listdir(raster_dir):
                         if file.startswith('output') and file.endswith('.tif'):
-                            raster_files.append(os.path.join(raster_dir, file))
+                            file_path = os.path.join(raster_dir, file)
+                            raster_files.append(file_path)
                 
                 latest_files["latest_raster"] = max(raster_files, key=os.path.getctime) if raster_files else None
             
+            # 添加调试信息
+            self.logger.info(f"获取最新输出文件 - 模型类型: {model_type}, 输出类型: {output_type}")
+            for key, value in latest_files.items():
+                if value:
+                    self.logger.info(f"  {key}: {os.path.basename(value)}")
+                else:
+                    self.logger.warning(f"  {key}: 未找到文件")
+            
             return latest_files
             
         except Exception as e:
             self.logger.error(f"获取最新输出文件失败: {str(e)}")
-            return {} 
+            return {}
+