123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- """
- Cd预测模型配置
- @description: 自包含的配置文件,不依赖外部集成系统
- @version: 3.0.0
- """
- import os
- from typing import Dict, Any, Optional
# Root directory of the current module; all model asset paths are built from it.
MODEL_ROOT = os.path.dirname(os.path.abspath(__file__))

# Per-model asset directories located directly under MODEL_ROOT.
_CROP_CD_DIR = os.path.join(MODEL_ROOT, "crop_cd")
_EFFECTIVE_CD_DIR = os.path.join(MODEL_ROOT, "effective_cd")
_COMMON_DIR = os.path.join(MODEL_ROOT, "common")

# Model file path configuration, keyed by model type.
MODEL_PATHS = {
    "crop_cd": {
        "model_file": os.path.join(_CROP_CD_DIR, "cropCdNN.pth"),
        "mean_file": os.path.join(_CROP_CD_DIR, "cropCd_mean.npy"),
        "scale_file": os.path.join(_CROP_CD_DIR, "cropCd_scale.npy"),
        "constraint_module": os.path.join(_CROP_CD_DIR, "constrained_nn6.py"),
        # Column name of the special feature.
        "special_feature_name": "solution",
        # Default index, used when the column name cannot be found.
        "special_feature_index": 2,
    },
    "effective_cd": {
        "model_file": os.path.join(_EFFECTIVE_CD_DIR, "EffCdNN6C.pth"),
        "mean_file": os.path.join(_EFFECTIVE_CD_DIR, "EffCd_mean.npy"),
        "scale_file": os.path.join(_EFFECTIVE_CD_DIR, "EffCd_scale.npy"),
        "constraint_module": os.path.join(_EFFECTIVE_CD_DIR, "constrained_nn6C.py"),
        # Column name of the special feature.
        "special_feature_name": "Cdsolution",
        # Default index, used when the column name cannot be found.
        "special_feature_index": 2,
    },
    "common": {
        "template_tif": os.path.join(_COMMON_DIR, "template.tif"),
    },
}
# Raster processing configuration.
RASTER_CONFIG = {
    # Interpolation is off by default to improve performance.
    "enable_interpolation": False,
    # Interpolation method: nearest, linear, cubic.
    "interpolation_method": "nearest",
    # Resolution factor; a larger value gives a higher resolution.
    "resolution_factor": 1.0,
    # Name of the field holding the predicted value.
    "field_name": "Prediction",
    "coordinate_columns": {
        "longitude": "longitude",
        "latitude": "latitude",
        "value": "Prediction",
    },
}
# Visualization configuration.
VISUALIZATION_CONFIG = {
    # Each entry is either a colormap name string or a list of six hex colors.
    "color_maps": {
        "viridis": "viridis",
        "plasma": "plasma",
        "inferno": "inferno",
        "colormap1": [
            '#FFFECE', '#FFF085', '#FEBA17', '#BE3D2A', '#74512D', '#4E1F00',
        ],
        "colormap2": [
            '#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60', '#2A3335',
        ],
        "colormap3": [
            '#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E',
        ],
        "colormap4": [
            '#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652',
        ],
        "colormap5": [
            '#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB',
        ],
        "colormap6": [
            '#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F',
        ],
    },
    "default_colormap": "colormap6",
    "figure_size": [10, 8],
    "dpi": 300,
}
def get_model_config(model_type: str) -> Dict[str, Any]:
    """
    Look up the configuration registered for a model type.

    @param {str} model_type - Key into MODEL_PATHS ("crop_cd" or "effective_cd")
    @returns {Dict[str, Any]} Shallow copy of the matching MODEL_PATHS entry
    @raises {ValueError} If model_type is not a key of MODEL_PATHS
    """
    if model_type in MODEL_PATHS:
        # Hand back a copy so callers cannot replace keys in the shared entry.
        return dict(MODEL_PATHS[model_type])

    raise ValueError(f"不支持的模型类型: {model_type}")
def get_raster_config(override_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Build the raster configuration, optionally applying overrides.

    Overrides whose value is None are ignored. A key that matches a top-level
    entry of RASTER_CONFIG replaces that entry; otherwise a key that matches an
    entry of the nested "coordinate_columns" mapping replaces that entry. Keys
    matching neither are silently dropped.

    The module-level RASTER_CONFIG is never modified: the nested
    "coordinate_columns" mapping is copied before any override is written.

    @param {Optional[Dict[str, Any]]} override_params - Override values
    @returns {Dict[str, Any]} Raster configuration independent of RASTER_CONFIG
    """
    config = RASTER_CONFIG.copy()

    # dict.copy() is shallow: without this, writes to config["coordinate_columns"]
    # below would land in the dict shared with RASTER_CONFIG and leak into every
    # subsequent call.
    if "coordinate_columns" in config:
        config["coordinate_columns"] = dict(config["coordinate_columns"])

    if not override_params:
        return config

    for key, value in override_params.items():
        if value is None:
            continue
        if key in config:
            config[key] = value
        elif key in config.get("coordinate_columns", {}):
            config["coordinate_columns"][key] = value

    return config
def get_template_tif_path() -> str:
    """
    Return the path of the template raster file.

    @returns {str} Value stored under MODEL_PATHS["common"]["template_tif"]
    """
    common_paths = MODEL_PATHS["common"]
    return common_paths["template_tif"]
def validate_model_files(model_type: str) -> Dict[str, bool]:
    """
    Check whether the files referenced by a model configuration exist.

    Only configuration keys ending in "_file" or "_module" are checked; every
    other key is left out of the result.

    @param {str} model_type - Model type, forwarded to get_model_config
    @returns {Dict[str, bool]} Maps each checked key to os.path.exists() of its path
    """
    path_key_suffixes = ('_file', '_module')

    return {
        name: os.path.exists(location)
        for name, location in get_model_config(model_type).items()
        if name.endswith(path_key_suffixes)
    }
def ensure_directories(output_base_dir: str):
    """
    Create the output directory layout if it does not exist yet.

    Creates "figures", "raster", "data/temp" and "data/final" beneath the base
    directory; directories that already exist are left as they are.

    @param {str} output_base_dir - Base output directory
    """
    # Each tuple holds the path components of one directory below the base.
    relative_subdirs = (
        ("figures",),
        ("raster",),
        ("data", "temp"),
        ("data", "final"),
    )

    for parts in relative_subdirs:
        os.makedirs(os.path.join(output_base_dir, *parts), exist_ok=True)
|