瀏覽代碼

添加栅格处理配置,支持运行时参数覆盖;更新相关函数以获取和应用栅格配置,增强生成作物Cd预测地图的灵活性。

drggboy 1 周之前
父節點
當前提交
20168b10ef

+ 50 - 0
Cd_Prediction_Integrated_System/config.py

@@ -5,6 +5,19 @@ Configuration file for Cd Prediction Integrated System
 
 import os
 
+# 栅格处理配置
+RASTER_CONFIG = {
+    "enable_interpolation": False,  # 是否启用空间插值
+    "interpolation_method": "nearest",  # 插值方法: nearest, linear, cubic
+    "resolution_factor": 1.0,  # 分辨率因子,越大分辨率越高
+    "field_name": "Prediction",  # 预测值字段名
+    "coordinate_columns": {
+        "longitude": 0,  # 经度列索引
+        "latitude": 1,   # 纬度列索引  
+        "value": 2       # 预测值列索引
+    }
+}
+
 # 项目根目录
 PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
 
@@ -92,6 +105,43 @@ def get_workflow_config():
             pass
     return _DEFAULT_WORKFLOW_CONFIG.copy()
 
+def get_raster_config(override_params=None):
+    """
+    获取栅格配置
+    支持参数覆盖,优先级:直接传参 > 环境变量 > 默认配置
+    
+    @param override_params: API传递的参数字典
+    @returns {dict} 合并后的栅格配置字典
+    """
+    # 从默认配置开始
+    config = RASTER_CONFIG.copy()
+    
+    # 检查环境变量中的覆盖参数
+    import json
+    env_override = os.environ.get('CD_RASTER_CONFIG_OVERRIDE')
+    if env_override:
+        try:
+            env_params = json.loads(env_override)
+            for key, value in env_params.items():
+                if value is not None:  # 只覆盖非None值
+                    if key in config:
+                        config[key] = value
+                    elif key in config.get("coordinate_columns", {}):
+                        config["coordinate_columns"][key] = value
+        except json.JSONDecodeError:
+            pass
+    
+    # 如果有直接传递的覆盖参数,优先级最高
+    if override_params:
+        for key, value in override_params.items():
+            if value is not None:  # 只覆盖非None值
+                if key in config:
+                    config[key] = value
+                elif key in config.get("coordinate_columns", {}):
+                    config["coordinate_columns"][key] = value
+                    
+    return config
+
 # 为了向后兼容,保留WORKFLOW_CONFIG变量(使用默认配置)
 WORKFLOW_CONFIG = _DEFAULT_WORKFLOW_CONFIG.copy()
 

+ 10 - 7
Cd_Prediction_Integrated_System/main.py

@@ -109,18 +109,21 @@ def main():
                 
                 # 使用csv_to_raster_workflow直接生成栅格
                 try:
+                    # 获取栅格配置(支持运行时参数覆盖)
+                    raster_config = config.get_raster_config()
+                    
                     workflow_result = csv_to_raster_workflow(
                         csv_file=final_data_file,
                         template_tif=config.ANALYSIS_CONFIG["template_tif"],
                         output_dir=config.OUTPUT_PATHS["raster_dir"],
                         boundary_shp=config.ANALYSIS_CONFIG.get("boundary_shp"),
-                        resolution_factor=1.0,  # 高分辨率栅格生成
-                        interpolation_method='nearest',
-                        field_name='Prediction',
-                        lon_col=0,
-                        lat_col=1, 
-                        value_col=2,
-                        enable_interpolation=False  # 禁用空间插值
+                        resolution_factor=raster_config["resolution_factor"],
+                        interpolation_method=raster_config["interpolation_method"],
+                        field_name=raster_config["field_name"],
+                        lon_col=raster_config["coordinate_columns"]["longitude"],
+                        lat_col=raster_config["coordinate_columns"]["latitude"], 
+                        value_col=raster_config["coordinate_columns"]["value"],
+                        enable_interpolation=raster_config["enable_interpolation"]
                     )
                     
                     output_raster = workflow_result['raster']

+ 30 - 4
app/api/cd_prediction.py

@@ -59,7 +59,10 @@ async def get_supported_counties() -> Dict[str, Any]:
             description="根据县名和CSV数据生成作物Cd预测地图并直接返回图片文件")
 async def generate_and_get_crop_cd_map(
     county_name: str = Form(..., description="县市名称,如:乐昌市"),
-    data_file: UploadFile = File(..., description="CSV格式的环境因子数据文件,前两列为经纬度,后续列与areatest.csv结构一致")
+    data_file: UploadFile = File(..., description="CSV格式的环境因子数据文件,前两列为经纬度,后续列与areatest.csv结构一致"),
+    enable_interpolation: Optional[bool] = Form(None, description="是否启用空间插值"),
+    interpolation_method: Optional[str] = Form(None, description="插值方法: nearest, linear, cubic"), 
+    resolution_factor: Optional[float] = Form(None, description="分辨率因子,越大分辨率越高")
 ):
     """
     一键生成并获取作物Cd预测地图
@@ -102,10 +105,20 @@ async def generate_and_get_crop_cd_map(
         # 保存临时数据文件
         temp_file_path = service.save_temp_data(df, county_name)
         
+        # 构建栅格配置参数
+        raster_params = {}
+        if enable_interpolation is not None:
+            raster_params['enable_interpolation'] = enable_interpolation
+        if interpolation_method is not None:
+            raster_params['interpolation_method'] = interpolation_method
+        if resolution_factor is not None:
+            raster_params['resolution_factor'] = resolution_factor
+        
         # 生成预测结果
         result = await service.generate_crop_cd_prediction_for_county(
             county_name=county_name,
-            data_file=temp_file_path
+            data_file=temp_file_path,
+            raster_config_override=raster_params if raster_params else None
         )
         
         if not result['map_path'] or not os.path.exists(result['map_path']):
@@ -131,7 +144,10 @@ async def generate_and_get_crop_cd_map(
             description="根据县名和CSV数据生成有效态Cd预测地图并直接返回图片文件")
 async def generate_and_get_effective_cd_map(
     county_name: str = Form(..., description="县市名称,如:乐昌市"),
-    data_file: UploadFile = File(..., description="CSV格式的环境因子数据文件,前两列为经纬度,后续列与areatest.csv结构一致")
+    data_file: UploadFile = File(..., description="CSV格式的环境因子数据文件,前两列为经纬度,后续列与areatest.csv结构一致"),
+    enable_interpolation: Optional[bool] = Form(None, description="是否启用空间插值"),
+    interpolation_method: Optional[str] = Form(None, description="插值方法: nearest, linear, cubic"), 
+    resolution_factor: Optional[float] = Form(None, description="分辨率因子,越大分辨率越高")
 ):
     """
     一键生成并获取有效态Cd预测地图
@@ -174,10 +190,20 @@ async def generate_and_get_effective_cd_map(
         # 保存临时数据文件
         temp_file_path = service.save_temp_data(df, county_name)
         
+        # 构建栅格配置参数
+        raster_params = {}
+        if enable_interpolation is not None:
+            raster_params['enable_interpolation'] = enable_interpolation
+        if interpolation_method is not None:
+            raster_params['interpolation_method'] = interpolation_method
+        if resolution_factor is not None:
+            raster_params['resolution_factor'] = resolution_factor
+        
         # 生成预测结果
         result = await service.generate_effective_cd_prediction_for_county(
             county_name=county_name,
-            data_file=temp_file_path
+            data_file=temp_file_path,
+            raster_config_override=raster_params if raster_params else None
         )
         
         if not result['map_path'] or not os.path.exists(result['map_path']):

+ 12 - 8
app/services/cd_prediction_service.py

@@ -313,7 +313,8 @@ class CdPredictionService:
     async def generate_crop_cd_prediction_for_county(
         self, 
         county_name: str, 
-        data_file: Optional[str] = None
+        data_file: Optional[str] = None,
+        raster_config_override: Optional[Dict[str, Any]] = None
     ) -> Dict[str, Any]:
         """
         为指定县市生成作物Cd预测
@@ -339,7 +340,7 @@ class CdPredictionService:
             result = await loop.run_in_executor(
                 None, 
                 self._run_crop_cd_prediction_with_county,
-                county_name, county_config
+                county_name, county_config, raster_config_override
             )
             
             return result
@@ -351,7 +352,8 @@ class CdPredictionService:
     async def generate_effective_cd_prediction_for_county(
         self, 
         county_name: str, 
-        data_file: Optional[str] = None
+        data_file: Optional[str] = None,
+        raster_config_override: Optional[Dict[str, Any]] = None
     ) -> Dict[str, Any]:
         """
         为指定县市生成有效态Cd预测
@@ -377,7 +379,7 @@ class CdPredictionService:
             result = await loop.run_in_executor(
                 None, 
                 self._run_effective_cd_prediction_with_county,
-                county_name, county_config
+                county_name, county_config, raster_config_override
             )
             
             return result
@@ -527,7 +529,8 @@ class CdPredictionService:
             raise
     
     def _run_crop_cd_prediction_with_county(self, county_name: str, 
-                                          county_config: Dict[str, Any]) -> Dict[str, Any]:
+                                          county_config: Dict[str, Any],
+                                          raster_config_override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """
         执行指定县市的作物Cd预测
         
@@ -538,7 +541,7 @@ class CdPredictionService:
         try:
             # 运行作物Cd预测
             self.logger.info(f"为{county_name}执行作物Cd预测")
-            prediction_result = self.wrapper.run_prediction_script("crop")
+            prediction_result = self.wrapper.run_prediction_script("crop", raster_config_override)
             
             # 获取输出文件(指定作物Cd模型类型)
             latest_outputs = self.wrapper.get_latest_outputs("all", "crop")
@@ -566,7 +569,8 @@ class CdPredictionService:
             raise
     
     def _run_effective_cd_prediction_with_county(self, county_name: str, 
-                                               county_config: Dict[str, Any]) -> Dict[str, Any]:
+                                               county_config: Dict[str, Any],
+                                               raster_config_override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """
         执行指定县市的有效态Cd预测
         
@@ -577,7 +581,7 @@ class CdPredictionService:
         try:
             # 运行有效态Cd预测
             self.logger.info(f"为{county_name}执行有效态Cd预测")
-            prediction_result = self.wrapper.run_prediction_script("effective")
+            prediction_result = self.wrapper.run_prediction_script("effective", raster_config_override)
             
             # 获取输出文件(指定有效态Cd模型类型)
             latest_outputs = self.wrapper.get_latest_outputs("all", "effective")

+ 7 - 1
app/utils/cd_prediction_wrapper.py

@@ -59,7 +59,7 @@ class CdPredictionWrapper:
         
         self.logger.info("Cd预测系统验证通过")
     
-    def run_prediction_script(self, model_type: str = "both") -> Dict[str, Any]:
+    def run_prediction_script(self, model_type: str = "both", raster_config_override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """
         运行Cd预测脚本
         
@@ -80,6 +80,10 @@ class CdPredictionWrapper:
                 import json
                 os.environ['CD_WORKFLOW_CONFIG'] = json.dumps(workflow_config)
                 
+                # 如果有栅格配置覆盖参数,也通过环境变量传递
+                if raster_config_override:
+                    os.environ['CD_RASTER_CONFIG_OVERRIDE'] = json.dumps(raster_config_override)
+                
                 # 运行主脚本
                 result = subprocess.run(
                     [sys.executable, "main.py"],
@@ -112,6 +116,8 @@ class CdPredictionWrapper:
                 # 清理环境变量
                 if 'CD_WORKFLOW_CONFIG' in os.environ:
                     del os.environ['CD_WORKFLOW_CONFIG']
+                if 'CD_RASTER_CONFIG_OVERRIDE' in os.environ:
+                    del os.environ['CD_RASTER_CONFIG_OVERRIDE']
                 
         except subprocess.TimeoutExpired:
             self.logger.error("Cd预测脚本执行超时")

+ 0 - 0
reset_db.py → scripts/demos/reset_db.py


+ 0 - 193
tests/integration/test_unit_grouping.py

@@ -1,193 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-单元分类功能集成测试
-
-该测试文件用于验证单元分类功能是否正确集成到项目中
-"""
-import requests
-import json
-from typing import Dict, Any
-
-def test_unit_grouping_api():
-    """
-    测试单元分类API接口
-    """
-    base_url = "http://localhost:8000"
-    
-    print("=" * 60)
-    print("单元分类功能集成测试")
-    print("=" * 60)
-    
-    # 测试1: 获取单元h_xtfx分类结果
-    print("\n1. 测试获取单元h_xtfx分类结果")
-    print("-" * 40)
-    
-    try:
-        response = requests.get(f"{base_url}/api/unit-grouping/h_xtfx")
-        
-        if response.status_code == 200:
-            data = response.json()
-            print(f"✓ 请求成功")
-            print(f"✓ 成功状态: {data.get('success', False)}")
-            
-            if data.get('success', False):
-                statistics = data.get('statistics', {})
-                print(f"✓ 总单元数: {statistics.get('total_units', 0)}")
-                print(f"✓ 有数据的单元数: {statistics.get('units_with_data', 0)}")
-                print(f"✓ 无数据的单元数: {statistics.get('units_without_data', 0)}")
-                
-                category_dist = statistics.get('category_distribution', {})
-                print(f"✓ 类别分布:")
-                for category, count in category_dist.items():
-                    print(f"  - {category}: {count}")
-                    
-                # 显示前5个单元的结果
-                unit_data = data.get('data', {})
-                if unit_data:
-                    print(f"✓ 前5个单元结果样例:")
-                    for i, (unit_id, h_xtfx) in enumerate(list(unit_data.items())[:5]):
-                        print(f"  - 单元 {unit_id}: {h_xtfx}")
-            else:
-                print(f"✗ 服务返回失败: {data.get('error', 'Unknown error')}")
-        else:
-            print(f"✗ 请求失败,状态码: {response.status_code}")
-            print(f"✗ 错误信息: {response.text}")
-            
-    except requests.exceptions.ConnectionError:
-        print("✗ 连接失败,请确保服务正在运行 (uvicorn main:app --reload)")
-    except Exception as e:
-        print(f"✗ 测试异常: {str(e)}")
-    
-    # 测试2: 获取统计信息
-    print("\n2. 测试获取统计信息")
-    print("-" * 40)
-    
-    try:
-        response = requests.get(f"{base_url}/api/unit-grouping/statistics")
-        
-        if response.status_code == 200:
-            data = response.json()
-            print(f"✓ 请求成功")
-            print(f"✓ 成功状态: {data.get('success', False)}")
-            
-            if data.get('success', False):
-                statistics = data.get('statistics', {})
-                print(f"✓ 统计信息获取成功")
-                print(f"  - 总单元数: {statistics.get('total_units', 0)}")
-                print(f"  - 有数据的单元数: {statistics.get('units_with_data', 0)}")
-                print(f"  - 无数据的单元数: {statistics.get('units_without_data', 0)}")
-        else:
-            print(f"✗ 请求失败,状态码: {response.status_code}")
-            print(f"✗ 错误信息: {response.text}")
-            
-    except requests.exceptions.ConnectionError:
-        print("✗ 连接失败,请确保服务正在运行")
-    except Exception as e:
-        print(f"✗ 测试异常: {str(e)}")
-    
-    # 测试3: 获取特定单元的h_xtfx值
-    print("\n3. 测试获取特定单元的h_xtfx值")
-    print("-" * 40)
-    
-    try:
-        # 首先获取一个存在的单元ID
-        response = requests.get(f"{base_url}/api/unit-grouping/h_xtfx")
-        if response.status_code == 200:
-            data = response.json()
-            if data.get('success', False):
-                unit_data = data.get('data', {})
-                if unit_data:
-                    # 取第一个单元进行测试
-                    test_unit_id = list(unit_data.keys())[0]
-                    
-                    # 测试获取特定单元
-                    response = requests.get(f"{base_url}/api/unit-grouping/unit/{test_unit_id}")
-                    
-                    if response.status_code == 200:
-                        unit_result = response.json()
-                        print(f"✓ 请求成功")
-                        print(f"✓ 单元 {test_unit_id} 的h_xtfx值: {unit_result.get('h_xtfx')}")
-                    else:
-                        print(f"✗ 请求失败,状态码: {response.status_code}")
-                else:
-                    print("✗ 没有可用的单元数据进行测试")
-            else:
-                print("✗ 无法获取单元数据进行测试")
-        else:
-            print("✗ 无法获取单元数据进行测试")
-            
-    except requests.exceptions.ConnectionError:
-        print("✗ 连接失败,请确保服务正在运行")
-    except Exception as e:
-        print(f"✗ 测试异常: {str(e)}")
-    
-    print("\n" + "=" * 60)
-    print("测试完成")
-    print("=" * 60)
-    
-    # 测试4: 测试新的ORM接口
-    print("\n4. 测试新的ORM接口功能")
-    print("-" * 40)
-    
-    # 测试点位统计信息
-    try:
-        response = requests.get(f"{base_url}/api/unit-grouping/points/statistics")
-        
-        if response.status_code == 200:
-            data = response.json()
-            print(f"✓ 点位统计信息获取成功")
-            distribution = data.get('distribution', {})
-            print(f"✓ 总点位数: {data.get('total_points', 0)}")
-            for category, stats in distribution.items():
-                print(f"  - {category}: {stats['count']} ({stats['percentage']}%)")
-        else:
-            print(f"✗ 获取点位统计信息失败,状态码: {response.status_code}")
-    except Exception as e:
-        print(f"✗ 测试点位统计信息异常: {str(e)}")
-    
-    # 测试数据库摘要信息
-    try:
-        response = requests.get(f"{base_url}/api/unit-grouping/database/summary")
-        
-        if response.status_code == 200:
-            data = response.json()
-            print(f"✓ 数据库摘要信息获取成功")
-            summary = data.get('summary', {})
-            print(f"  - 总单元数: {summary.get('total_units', 0)}")
-            print(f"  - 总点位数: {summary.get('total_points', 0)}")
-            print(f"  - h_xtfx分类数: {summary.get('h_xtfx_categories', 0)}")
-        else:
-            print(f"✗ 获取数据库摘要信息失败,状态码: {response.status_code}")
-    except Exception as e:
-        print(f"✗ 测试数据库摘要信息异常: {str(e)}")
-    
-    # 测试批量获取单元信息
-    try:
-        response = requests.get(f"{base_url}/api/unit-grouping/units/batch?unit_ids=1&unit_ids=2&unit_ids=3")
-        
-        if response.status_code == 200:
-            data = response.json()
-            print(f"✓ 批量获取单元信息成功")
-            print(f"  - 请求数量: {data.get('total_requested', 0)}")
-            print(f"  - 找到数量: {data.get('total_found', 0)}")
-        else:
-            print(f"✗ 批量获取单元信息失败,状态码: {response.status_code}")
-    except Exception as e:
-        print(f"✗ 测试批量获取单元信息异常: {str(e)}")
-    
-    # 输出使用说明
-    print("\n使用说明:")
-    print("1. 启动服务: uvicorn main:app --reload")
-    print("2. 访问API文档: http://localhost:8000/docs")
-    print("3. 主要接口:")
-    print("   - GET /api/unit-grouping/h_xtfx - 获取所有单元的h_xtfx分类结果")
-    print("   - GET /api/unit-grouping/statistics - 获取统计信息")
-    print("   - GET /api/unit-grouping/unit/{unit_id} - 获取特定单元的h_xtfx值")
-    print("   - GET /api/unit-grouping/points/statistics - 获取点位统计信息(ORM)")
-    print("   - GET /api/unit-grouping/database/summary - 获取数据库摘要信息(ORM)")
-    print("   - GET /api/unit-grouping/units/batch - 批量获取单元信息(ORM)")
-    print("   - GET /api/unit-grouping/points/by-area - 按区域获取点位数据(ORM)")
-
-if __name__ == "__main__":
-    test_unit_grouping_api()