""" Cd预测模型配置 @description: 自包含的配置文件,不依赖外部集成系统 @version: 3.0.0 """ import os from typing import Dict, Any, Optional # 获取当前模块的根目录 MODEL_ROOT = os.path.dirname(os.path.abspath(__file__)) # 模型文件路径配置 MODEL_PATHS = { "crop_cd": { "model_file": os.path.join(MODEL_ROOT, "crop_cd", "cropCdNN.pth"), "mean_file": os.path.join(MODEL_ROOT, "crop_cd", "cropCd_mean.npy"), "scale_file": os.path.join(MODEL_ROOT, "crop_cd", "cropCd_scale.npy"), "constraint_module": os.path.join(MODEL_ROOT, "crop_cd", "constrained_nn6.py"), "special_feature_name": "solution", # 特殊特征列名 "special_feature_index": 2 # 默认索引,如果找不到列名则使用此索引 }, "effective_cd": { "model_file": os.path.join(MODEL_ROOT, "effective_cd", "EffCdNN6C.pth"), "mean_file": os.path.join(MODEL_ROOT, "effective_cd", "EffCd_mean.npy"), "scale_file": os.path.join(MODEL_ROOT, "effective_cd", "EffCd_scale.npy"), "constraint_module": os.path.join(MODEL_ROOT, "effective_cd", "constrained_nn6C.py"), "special_feature_name": "Cdsolution", # 特殊特征列名 "special_feature_index": 2 # 默认索引,如果找不到列名则使用此索引 }, "common": { "template_tif": os.path.join(MODEL_ROOT, "common", "template.tif") } } # 栅格处理配置 RASTER_CONFIG = { "enable_interpolation": False, # 默认关闭插值以提高性能 "interpolation_method": "nearest", # 插值方法: nearest, linear, cubic "resolution_factor": 1.0, # 分辨率因子,越大分辨率越高 "field_name": "Prediction", # 预测值字段名 "coordinate_columns": { "longitude": "longitude", "latitude": "latitude", "value": "Prediction" } } # 可视化配置 VISUALIZATION_CONFIG = { "color_maps": { "viridis": "viridis", "plasma": "plasma", "inferno": "inferno", "colormap1": ['#FFFECE', '#FFF085', '#FEBA17','#BE3D2A','#74512D', '#4E1F00'], "colormap2": ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60','#2A3335'], "colormap3": ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'], "colormap4": ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'], "colormap5": ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'], "colormap6": ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F'] }, "default_colormap": "colormap6", "figure_size": [10, 8], "dpi": 300 } def get_model_config(model_type: str) -> Dict[str, Any]: """ 获取指定模型的配置 @param {str} model_type - 模型类型 ("crop_cd" 或 "effective_cd") @returns {Dict[str, Any]} 模型配置 """ if model_type not in MODEL_PATHS: raise ValueError(f"不支持的模型类型: {model_type}") return MODEL_PATHS[model_type].copy() def get_raster_config(override_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ 获取栅格配置,支持参数覆盖 @param {Optional[Dict[str, Any]]} override_params - 覆盖参数 @returns {Dict[str, Any]} 栅格配置 """ config = RASTER_CONFIG.copy() if override_params: for key, value in override_params.items(): if value is not None: if key in config: config[key] = value elif key in config.get("coordinate_columns", {}): config["coordinate_columns"][key] = value return config def get_template_tif_path() -> str: """ 获取模板栅格文件路径 @returns {str} 模板文件路径 """ return MODEL_PATHS["common"]["template_tif"] def validate_model_files(model_type: str) -> Dict[str, bool]: """ 验证模型文件是否存在 @param {str} model_type - 模型类型 @returns {Dict[str, bool]} 文件存在性验证结果 """ config = get_model_config(model_type) validation_result = {} for key, file_path in config.items(): if key.endswith('_file') or key.endswith('_module'): validation_result[key] = os.path.exists(file_path) return validation_result def ensure_directories(output_base_dir: str): """ 确保输出目录存在 @param {str} output_base_dir - 基础输出目录 """ directories = [ os.path.join(output_base_dir, "figures"), os.path.join(output_base_dir, "raster"), os.path.join(output_base_dir, "data", "temp"), os.path.join(output_base_dir, "data", "final") ] for directory in directories: os.makedirs(directory, exist_ok=True)