config.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. """
  2. Cd预测模型配置
  3. @description: 自包含的配置文件,不依赖外部集成系统
  4. @version: 3.0.0
  5. """
  6. import os
  7. from typing import Dict, Any, Optional
  8. # 获取当前模块的根目录
  9. MODEL_ROOT = os.path.dirname(os.path.abspath(__file__))
  10. # 模型文件路径配置
  11. MODEL_PATHS = {
  12. "crop_cd": {
  13. "model_file": os.path.join(MODEL_ROOT, "crop_cd", "cropCdNN.pth"),
  14. "mean_file": os.path.join(MODEL_ROOT, "crop_cd", "cropCd_mean.npy"),
  15. "scale_file": os.path.join(MODEL_ROOT, "crop_cd", "cropCd_scale.npy"),
  16. "constraint_module": os.path.join(MODEL_ROOT, "crop_cd", "constrained_nn6.py"),
  17. "special_feature_name": "solution", # 特殊特征列名
  18. "special_feature_index": 2 # 默认索引,如果找不到列名则使用此索引
  19. },
  20. "effective_cd": {
  21. "model_file": os.path.join(MODEL_ROOT, "effective_cd", "EffCdNN6C.pth"),
  22. "mean_file": os.path.join(MODEL_ROOT, "effective_cd", "EffCd_mean.npy"),
  23. "scale_file": os.path.join(MODEL_ROOT, "effective_cd", "EffCd_scale.npy"),
  24. "constraint_module": os.path.join(MODEL_ROOT, "effective_cd", "constrained_nn6C.py"),
  25. "special_feature_name": "Cdsolution", # 特殊特征列名
  26. "special_feature_index": 2 # 默认索引,如果找不到列名则使用此索引
  27. },
  28. "common": {
  29. "template_tif": os.path.join(MODEL_ROOT, "common", "template.tif")
  30. }
  31. }
  32. # 栅格处理配置
  33. RASTER_CONFIG = {
  34. "enable_interpolation": False, # 默认关闭插值以提高性能
  35. "interpolation_method": "nearest", # 插值方法: nearest, linear, cubic
  36. "resolution_factor": 1.0, # 分辨率因子,越大分辨率越高
  37. "field_name": "Prediction", # 预测值字段名
  38. "coordinate_columns": {
  39. "longitude": "longitude",
  40. "latitude": "latitude",
  41. "value": "Prediction"
  42. }
  43. }
  44. # 可视化配置
  45. VISUALIZATION_CONFIG = {
  46. "color_maps": {
  47. "viridis": "viridis",
  48. "plasma": "plasma",
  49. "inferno": "inferno",
  50. "colormap1": ['#FFFECE', '#FFF085', '#FEBA17','#BE3D2A','#74512D', '#4E1F00'],
  51. "colormap2": ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60','#2A3335'],
  52. "colormap3": ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'],
  53. "colormap4": ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'],
  54. "colormap5": ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'],
  55. "colormap6": ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F']
  56. },
  57. "default_colormap": "colormap6",
  58. "figure_size": [10, 8],
  59. "dpi": 300
  60. }
  61. def get_model_config(model_type: str) -> Dict[str, Any]:
  62. """
  63. 获取指定模型的配置
  64. @param {str} model_type - 模型类型 ("crop_cd" 或 "effective_cd")
  65. @returns {Dict[str, Any]} 模型配置
  66. """
  67. if model_type not in MODEL_PATHS:
  68. raise ValueError(f"不支持的模型类型: {model_type}")
  69. return MODEL_PATHS[model_type].copy()
  70. def get_raster_config(override_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
  71. """
  72. 获取栅格配置,支持参数覆盖
  73. @param {Optional[Dict[str, Any]]} override_params - 覆盖参数
  74. @returns {Dict[str, Any]} 栅格配置
  75. """
  76. config = RASTER_CONFIG.copy()
  77. if override_params:
  78. for key, value in override_params.items():
  79. if value is not None:
  80. if key in config:
  81. config[key] = value
  82. elif key in config.get("coordinate_columns", {}):
  83. config["coordinate_columns"][key] = value
  84. return config
  85. def get_template_tif_path() -> str:
  86. """
  87. 获取模板栅格文件路径
  88. @returns {str} 模板文件路径
  89. """
  90. return MODEL_PATHS["common"]["template_tif"]
  91. def validate_model_files(model_type: str) -> Dict[str, bool]:
  92. """
  93. 验证模型文件是否存在
  94. @param {str} model_type - 模型类型
  95. @returns {Dict[str, bool]} 文件存在性验证结果
  96. """
  97. config = get_model_config(model_type)
  98. validation_result = {}
  99. for key, file_path in config.items():
  100. if key.endswith('_file') or key.endswith('_module'):
  101. validation_result[key] = os.path.exists(file_path)
  102. return validation_result
  103. def ensure_directories(output_base_dir: str):
  104. """
  105. 确保输出目录存在
  106. @param {str} output_base_dir - 基础输出目录
  107. """
  108. directories = [
  109. os.path.join(output_base_dir, "figures"),
  110. os.path.join(output_base_dir, "raster"),
  111. os.path.join(output_base_dir, "data", "temp"),
  112. os.path.join(output_base_dir, "data", "final")
  113. ]
  114. for directory in directories:
  115. os.makedirs(directory, exist_ok=True)