Pārlūkot izejas kodu

删除生成直方图接口;
修复每次运行两个模型bug;
修复图片乱码问题

drggboy 1 dienu atpakaļ
vecāks
revīzija
be02dae1da

+ 75 - 4
Cd_Prediction_Integrated_System/analysis/visualization.py

@@ -40,10 +40,81 @@ class Visualizer:
         """
         设置matplotlib的字体和样式
         """
-        # 设置全局字体为 Arial,并设置中文字体
-        plt.rcParams['font.family'] = 'Arial'
-        plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
-        plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']  # 添加多个中文字体
+        try:
+            # 设置字体,优先尝试常用的中文字体
+            import matplotlib.font_manager as fm
+            
+            # 清理matplotlib字体缓存(解决Windows系统字体问题)
+            try:
+                import matplotlib
+                fm._rebuild()
+                self.logger.info("matplotlib字体缓存已重建")
+            except Exception as cache_error:
+                self.logger.warning(f"字体缓存重建失败: {cache_error}")
+            
+            # 可用的中文字体列表(Windows系统优先)
+            chinese_fonts = [
+                'Microsoft YaHei',      # 微软雅黑 (Windows)
+                'Microsoft YaHei UI',   # 微软雅黑UI (Windows)
+                'SimHei',               # 黑体 (Windows)
+                'SimSun',               # 宋体 (Windows)
+                'KaiTi',                # 楷体 (Windows)
+                'FangSong',             # 仿宋 (Windows)
+                'Microsoft JhengHei',   # 微软正黑体 (Windows)
+                'PingFang SC',          # 苹方(macOS)
+                'Hiragino Sans GB',     # 冬青黑体(macOS)
+                'WenQuanYi Micro Hei',  # 文泉驿微米黑(Linux)
+                'Noto Sans CJK SC',     # 思源黑体(Linux)
+                'Arial Unicode MS',     # Unicode字体
+                'DejaVu Sans'           # 备用字体
+            ]
+            
+            # 查找可用的字体
+            available_fonts = [f.name for f in fm.fontManager.ttflist]
+            selected_font = None
+            
+            self.logger.info(f"系统中可用字体数量: {len(available_fonts)}")
+            
+            for font in chinese_fonts:
+                if font in available_fonts:
+                    selected_font = font
+                    self.logger.info(f"选择字体: {font}")
+                    break
+            
+            if selected_font:
+                plt.rcParams['font.sans-serif'] = [selected_font] + chinese_fonts
+                plt.rcParams['font.family'] = 'sans-serif'
+            else:
+                self.logger.warning("未找到合适的中文字体,将使用系统默认字体")
+                # 使用更安全的字体配置
+                plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
+                plt.rcParams['font.family'] = 'sans-serif'
+            
+            # 解决负号显示问题
+            plt.rcParams['axes.unicode_minus'] = False
+            
+            # 设置图形样式
+            plt.rcParams['figure.figsize'] = (10, 8)
+            plt.rcParams['axes.labelsize'] = 12
+            plt.rcParams['axes.titlesize'] = 14
+            plt.rcParams['xtick.labelsize'] = 10
+            plt.rcParams['ytick.labelsize'] = 10
+            plt.rcParams['legend.fontsize'] = 10
+            
+            # 设置DPI以提高图像质量
+            plt.rcParams['figure.dpi'] = 100
+            plt.rcParams['savefig.dpi'] = 300
+            plt.rcParams['savefig.bbox'] = 'tight'
+            plt.rcParams['savefig.pad_inches'] = 0.1
+            
+            self.logger.info("matplotlib字体和样式设置完成")
+            
+        except Exception as e:
+            self.logger.warning(f"设置matplotlib字体失败: {str(e)},将使用默认配置")
+            # 最基本的安全配置
+            plt.rcParams['font.family'] = 'sans-serif'
+            plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial']
+            plt.rcParams['axes.unicode_minus'] = False
         
     def create_raster_map(self, 
                          shp_path=None, 

+ 21 - 2
Cd_Prediction_Integrated_System/main.py

@@ -68,6 +68,14 @@ def main():
         for model_name, final_data_file in final_data_files.items():
             if model_name == 'combined':  # 跳过合并文件的可视化
                 continue
+            
+            # 只为当前工作流配置中启用的模型生成输出文件
+            if model_name == 'crop_cd' and not config.WORKFLOW_CONFIG.get("run_crop_model", False):
+                logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
+                continue
+            if model_name == 'effective_cd' and not config.WORKFLOW_CONFIG.get("run_effective_model", False):
+                logger.info(f"跳过{model_name}模型的可视化(工作流配置中未启用)")
+                continue
                 
             logger.info(f"为{model_name}模型生成栅格和可视化...")
             
@@ -135,10 +143,21 @@ def main():
                     f"Prediction_frequency_{model_name}.jpg"
                 )
                 
+                # 为了避免中文字体问题,使用英文标题
+                if model_name == 'crop_cd':
+                    hist_title = 'Crop Cd Prediction Frequency'
+                    hist_xlabel = 'Crop Cd Content (mg/kg)'
+                elif model_name == 'effective_cd':
+                    hist_title = 'Effective Cd Prediction Frequency'
+                    hist_xlabel = 'Effective Cd Content (mg/kg)'
+                else:
+                    hist_title = f'{model_name} Prediction Frequency'
+                    hist_xlabel = f'{model_name} Content'
+                
                 histogram_output = visualizer.create_histogram(
                     file_path=output_raster,
-                    title=f'{display_name} Prediction Frequency',
-                    xlabel=f'{display_name} Content',
+                    title=hist_title,
+                    xlabel=hist_xlabel,
                     save_path=histogram_output_path
                 )
                 model_outputs[model_name]['histogram'] = histogram_output

+ 0 - 148
app/api/cd_prediction.py

@@ -198,154 +198,6 @@ async def generate_and_get_effective_cd_map(
             detail=f"为{county_name}一键生成有效态Cd预测地图失败: {str(e)}"
         )
 
-# =============================================================================
-# 一键生成并获取直方图接口
-# =============================================================================
-
-@router.post("/crop-cd/generate-and-get-histogram", 
-            summary="一键生成并获取作物Cd预测直方图", 
-            description="根据县名和CSV数据生成作物Cd预测直方图并直接返回图片文件")
-async def generate_and_get_crop_cd_histogram(
-    county_name: str = Form(..., description="县市名称,如:乐昌市"),
-    data_file: UploadFile = File(..., description="CSV格式的环境因子数据文件,前两列为经纬度,后续列与areatest.csv结构一致")
-):
-    """
-    一键生成并获取作物Cd预测直方图
-    
-    @param county_name: 县市名称
-    @param data_file: CSV数据文件,前两列为经纬度坐标,后面几列和areatest.csv的结构一致
-    @returns {FileResponse} 预测直方图文件
-    """
-    try:
-        logger.info(f"开始为{county_name}一键生成作物Cd预测直方图")
-        
-        # 验证文件格式
-        if not data_file.filename.endswith('.csv'):
-            raise HTTPException(status_code=400, detail="仅支持CSV格式文件")
-        
-        # 读取CSV数据
-        content = await data_file.read()
-        df = pd.read_csv(io.StringIO(content.decode('utf-8')))
-        
-        # 验证数据格式
-        if df.shape[1] < 3:
-            raise HTTPException(
-                status_code=400, 
-                detail="数据至少需要3列:前两列为经纬度,后续列为环境因子"
-            )
-        
-        # 重命名前两列为标准的经纬度列名
-        df.columns = ['longitude', 'latitude'] + list(df.columns[2:])
-        
-        service = CdPredictionService()
-        
-        # 验证数据
-        validation_result = service.validate_input_data(df, county_name)
-        if not validation_result['valid']:
-            raise HTTPException(
-                status_code=400,
-                detail=f"数据验证失败: {', '.join(validation_result['errors'])}"
-            )
-        
-        # 保存临时数据文件
-        temp_file_path = service.save_temp_data(df, county_name)
-        
-        # 生成预测结果
-        result = await service.generate_crop_cd_prediction_for_county(
-            county_name=county_name,
-            data_file=temp_file_path
-        )
-        
-        if not result['histogram_path'] or not os.path.exists(result['histogram_path']):
-            raise HTTPException(status_code=500, detail="直方图文件生成失败")
-        
-        return FileResponse(
-            path=result['histogram_path'],
-            filename=f"{county_name}_crop_cd_prediction_histogram.jpg",
-            media_type="image/jpeg"
-        )
-        
-    except HTTPException:
-        raise
-    except Exception as e:
-        logger.error(f"为{county_name}一键生成作物Cd预测直方图失败: {str(e)}")
-        raise HTTPException(
-            status_code=500, 
-            detail=f"为{county_name}一键生成作物Cd预测直方图失败: {str(e)}"
-        )
-
-@router.post("/effective-cd/generate-and-get-histogram", 
-            summary="一键生成并获取有效态Cd预测直方图", 
-            description="根据县名和CSV数据生成有效态Cd预测直方图并直接返回图片文件")
-async def generate_and_get_effective_cd_histogram(
-    county_name: str = Form(..., description="县市名称,如:乐昌市"),
-    data_file: UploadFile = File(..., description="CSV格式的环境因子数据文件,前两列为经纬度,后续列与areatest.csv结构一致")
-):
-    """
-    一键生成并获取有效态Cd预测直方图
-    
-    @param county_name: 县市名称
-    @param data_file: CSV数据文件,前两列为经纬度坐标,后面几列和areatest.csv的结构一致
-    @returns {FileResponse} 预测直方图文件
-    """
-    try:
-        logger.info(f"开始为{county_name}一键生成有效态Cd预测直方图")
-        
-        # 验证文件格式
-        if not data_file.filename.endswith('.csv'):
-            raise HTTPException(status_code=400, detail="仅支持CSV格式文件")
-        
-        # 读取CSV数据
-        content = await data_file.read()
-        df = pd.read_csv(io.StringIO(content.decode('utf-8')))
-        
-        # 验证数据格式
-        if df.shape[1] < 3:
-            raise HTTPException(
-                status_code=400, 
-                detail="数据至少需要3列:前两列为经纬度,后续列为环境因子"
-            )
-        
-        # 重命名前两列为标准的经纬度列名
-        df.columns = ['longitude', 'latitude'] + list(df.columns[2:])
-        
-        service = CdPredictionService()
-        
-        # 验证数据
-        validation_result = service.validate_input_data(df, county_name)
-        if not validation_result['valid']:
-            raise HTTPException(
-                status_code=400,
-                detail=f"数据验证失败: {', '.join(validation_result['errors'])}"
-            )
-        
-        # 保存临时数据文件
-        temp_file_path = service.save_temp_data(df, county_name)
-        
-        # 生成预测结果
-        result = await service.generate_effective_cd_prediction_for_county(
-            county_name=county_name,
-            data_file=temp_file_path
-        )
-        
-        if not result['histogram_path'] or not os.path.exists(result['histogram_path']):
-            raise HTTPException(status_code=500, detail="直方图文件生成失败")
-        
-        return FileResponse(
-            path=result['histogram_path'],
-            filename=f"{county_name}_effective_cd_prediction_histogram.jpg",
-            media_type="image/jpeg"
-        )
-        
-    except HTTPException:
-        raise
-    except Exception as e:
-        logger.error(f"为{county_name}一键生成有效态Cd预测直方图失败: {str(e)}")
-        raise HTTPException(
-            status_code=500, 
-            detail=f"为{county_name}一键生成有效态Cd预测直方图失败: {str(e)}"
-        )
-
 # =============================================================================
 # 获取最新预测结果接口(无需重新计算)
 # =============================================================================

+ 54 - 19
app/services/cd_prediction_service.py

@@ -62,20 +62,23 @@ class CdPredictionService:
         
         @returns {Dict[str, Dict]} 支持的县市配置信息
         """
+        # 获取Cd预测系统的基础路径
+        cd_system_base = self.config.get_cd_system_path()
+        
         return {
             "乐昌市": {
-                "boundary_file": "output/raster/lechang.shp",
-                "template_file": "output/raster/meanTemp.tif",
-                "coordinate_file": "data/coordinates/坐标.csv",
+                "boundary_file": os.path.join(cd_system_base, "output/raster/lechang.shp"),
+                "template_file": os.path.join(cd_system_base, "output/raster/meanTemp.tif"),
+                "coordinate_file": os.path.join(cd_system_base, "data/coordinates/坐标.csv"),
                 "region_code": "440282",
                 "display_name": "乐昌市",
                 "province": "广东省"
             },
             # 可扩展添加更多县市
             # "韶关市": {
-            #     "boundary_file": "output/raster/shaoguan.shp",
-            #     "template_file": "output/raster/shaoguan_template.tif",
-            #     "coordinate_file": "data/coordinates/韶关_坐标.csv",
+            #     "boundary_file": os.path.join(cd_system_base, "output/raster/shaoguan.shp"),
+            #     "template_file": os.path.join(cd_system_base, "output/raster/shaoguan_template.tif"),
+            #     "coordinate_file": os.path.join(cd_system_base, "data/coordinates/韶关_坐标.csv"),
             #     "region_code": "440200",
             #     "display_name": "韶关市",
             #     "province": "广东省"
@@ -319,9 +322,8 @@ class CdPredictionService:
             
             # 如果提供了自定义数据文件,使用它替换默认数据
             if data_file:
-                # 这里需要将自定义数据文件复制到Cd预测系统的输入目录
-                # 并更新配置以使用新的数据文件
-                self._prepare_custom_data(data_file, county_name)
+                # 准备作物Cd模型的自定义数据
+                self._prepare_crop_cd_custom_data(data_file, county_name)
             
             # 在线程池中运行CPU密集型任务
             loop = asyncio.get_event_loop()
@@ -358,7 +360,8 @@ class CdPredictionService:
             
             # 如果提供了自定义数据文件,使用它替换默认数据
             if data_file:
-                self._prepare_custom_data(data_file, county_name)
+                # 准备有效态Cd模型的自定义数据
+                self._prepare_effective_cd_custom_data(data_file, county_name)
             
             # 在线程池中运行CPU密集型任务
             loop = asyncio.get_event_loop()
@@ -374,9 +377,9 @@ class CdPredictionService:
             self.logger.error(f"为{county_name}生成有效态Cd预测失败: {str(e)}")
             raise
     
-    def _prepare_custom_data(self, data_file: str, county_name: str):
+    def _prepare_crop_cd_custom_data(self, data_file: str, county_name: str):
         """
-        准备自定义数据文件
+        准备作物Cd模型的自定义数据文件
         
         @param {str} data_file - 数据文件路径
         @param {str} county_name - 县市名称
@@ -402,7 +405,7 @@ class CdPredictionService:
             coordinates_df.to_csv(coord_file_path, index=False, encoding='utf-8-sig')
             self.logger.info(f"坐标文件已保存: {coord_file_path}")
             
-            # 2. 准备作物Cd模型的训练数据(包含所有列,用于预测)
+            # 2. 准备作物Cd模型的训练数据
             crop_cd_data_dir = os.path.join(cd_system_path, "models", "crop_cd_model", "data")
             crop_target_file = os.path.join(crop_cd_data_dir, "areatest.csv")
             
@@ -410,6 +413,7 @@ class CdPredictionService:
             if os.path.exists(crop_target_file):
                 backup_file = f"{crop_target_file}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                 shutil.copy2(crop_target_file, backup_file)
+                self.logger.info(f"作物Cd模型原始数据已备份: {backup_file}")
             
             # 提取环境因子数据(去掉前两列的经纬度)
             environmental_data = df.iloc[:, 2:].copy()  # 从第3列开始的所有列
@@ -418,7 +422,41 @@ class CdPredictionService:
             environmental_data.to_csv(crop_target_file, index=False, encoding='utf-8-sig')
             self.logger.info(f"作物Cd模型数据文件已保存: {crop_target_file}, 数据形状: {environmental_data.shape}")
             
-            # 3. 准备有效态Cd模型的训练数据
+            self.logger.info(f"作物Cd模型自定义数据文件已准备完成,县市: {county_name}")
+            
+        except Exception as e:
+            self.logger.error(f"准备作物Cd模型自定义数据文件失败: {str(e)}")
+            raise
+    
+    def _prepare_effective_cd_custom_data(self, data_file: str, county_name: str):
+        """
+        准备有效态Cd模型的自定义数据文件
+        
+        @param {str} data_file - 数据文件路径
+        @param {str} county_name - 县市名称
+        """
+        try:
+            import pandas as pd
+            
+            # 读取用户上传的CSV文件
+            df = pd.read_csv(data_file, encoding='utf-8')
+            
+            # 获取Cd预测系统的数据目录
+            cd_system_path = self.config.get_cd_system_path()
+            
+            # 1. 提取坐标信息并保存为独立的坐标文件
+            coordinates_df = pd.DataFrame({
+                'longitude': df.iloc[:, 0],  # 第一列为经度
+                'latitude': df.iloc[:, 1]    # 第二列为纬度
+            })
+            
+            # 保存坐标文件到系统数据目录
+            coord_file_path = os.path.join(cd_system_path, "data", "coordinates", "坐标.csv")
+            os.makedirs(os.path.dirname(coord_file_path), exist_ok=True)
+            coordinates_df.to_csv(coord_file_path, index=False, encoding='utf-8-sig')
+            self.logger.info(f"坐标文件已保存: {coord_file_path}")
+            
+            # 2. 准备有效态Cd模型的训练数据
             effective_cd_data_dir = os.path.join(cd_system_path, "models", "effective_cd_model", "data")
             effective_target_file = os.path.join(effective_cd_data_dir, "areatest.csv")
             
@@ -452,13 +490,10 @@ class CdPredictionService:
                 else:
                     self.logger.warning(f"用户数据环境因子列数不足且未找到备份文件,继续使用当前数据")
             
-            # 4. 更新系统配置,确保使用新的坐标文件
-            # 这里可能需要根据底层系统的配置方式进行调整
-            
-            self.logger.info(f"自定义数据文件已准备完成,县市: {county_name}")
+            self.logger.info(f"有效态Cd模型自定义数据文件已准备完成,县市: {county_name}")
             
         except Exception as e:
-            self.logger.error(f"准备自定义数据文件失败: {str(e)}")
+            self.logger.error(f"准备有效态Cd模型自定义数据文件失败: {str(e)}")
             raise
     
     def _run_crop_cd_prediction_with_county(self, county_name: str,