config.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. """
  2. 配置文件 - Cd预测集成系统
  3. Configuration file for Cd Prediction Integrated System
  4. """
  5. import os
  # Raster processing settings
  RASTER_CONFIG = dict(
      enable_interpolation=False,      # whether spatial interpolation is enabled
      interpolation_method="nearest",  # interpolation method: nearest, linear, cubic
      resolution_factor=1.0,           # resolution factor; larger values mean finer resolution
      field_name="Prediction",         # name of the predicted-value field
      coordinate_columns=dict(
          longitude=0,                 # column index of the longitude
          latitude=1,                  # column index of the latitude
          value=2,                     # column index of the predicted value
      ),
  )
  # Project root: the directory containing this file (absolute path)
  PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]
  # Directory holding the model artifacts
  MODELS_DIR = os.path.join(PROJECT_ROOT, "models")
  # Crop Cd model settings.
  # NOTE(review): the *_file entries are bare file names, presumably resolved
  # against model_dir by the loader -- confirm with the consumer.
  CROP_CD_MODEL = dict(
      model_dir=os.path.join(MODELS_DIR, "crop_cd_model"),
      model_file="cropCdNN.pth",
      mean_file="cropCd_mean.npy",
      scale_file="cropCd_scale.npy",
      constraint_module="constrained_nn6",
      input_data="areatest.csv",
      output_file="combined_pH.csv",
      special_feature="solution",
  )
  # Effective (available) Cd model settings.
  # NOTE(review): the *_file entries are bare file names, presumably resolved
  # against model_dir by the loader -- confirm with the consumer.
  EFFECTIVE_CD_MODEL = dict(
      model_dir=os.path.join(MODELS_DIR, "effective_cd_model"),
      model_file="EffCdNN6C.pth",
      mean_file="EffCd_mean.npy",
      scale_file="EffCd_scale.npy",
      constraint_module="constrained_nn6C",
      input_data="areatest.csv",
      output_file="pHcombined.csv",
      special_feature="Cdsolution",
  )
  # Data locations: three directories under <project>/data plus the
  # coordinate CSV file name
  DATA_PATHS = dict(
      coordinates_dir=os.path.join(PROJECT_ROOT, "data", "coordinates"),
      predictions_dir=os.path.join(PROJECT_ROOT, "data", "predictions"),
      final_dir=os.path.join(PROJECT_ROOT, "data", "final"),
      coordinate_file="坐标.csv",
  )
  # Analysis-module files; every path lives under <project>/output/raster
  ANALYSIS_CONFIG = dict(
      template_tif=os.path.join(PROJECT_ROOT, "output", "raster", "meanTemp.tif"),
      boundary_shp=os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp"),
      output_raster=os.path.join(PROJECT_ROOT, "output", "raster", "output.tif"),
      temp_shapefile=os.path.join(PROJECT_ROOT, "output", "raster", "points666.shp"),
  )
  # Output directories under <project>/output
  OUTPUT_PATHS = dict(
      raster_dir=os.path.join(PROJECT_ROOT, "output", "raster"),
      figures_dir=os.path.join(PROJECT_ROOT, "output", "figures"),
      reports_dir=os.path.join(PROJECT_ROOT, "output", "reports"),
  )
  # Visualization settings: six named 6-colour palettes, the palette used by
  # default, figure size and output DPI
  VISUALIZATION_CONFIG = dict(
      color_maps=dict(
          colormap1=["#FFFECE", "#FFF085", "#FEBA17", "#BE3D2A", "#74512D", "#4E1F00"],
          colormap2=["#F6F8D5", "#98D2C0", "#4F959D", "#205781", "#143D60", "#2A3335"],
          colormap3=["#FFEFC8", "#F8ED8C", "#D3E671", "#89AC46", "#5F8B4C", "#355F2E"],
          colormap4=["#F0F1C5", "#BBD8A3", "#6F826A", "#BF9264", "#735557", "#604652"],
          colormap5=["#FCFAEE", "#FBF3B9", "#FFDCCC", "#FDB7EA", "#B7B1F2", "#8D77AB"],
          colormap6=["#15B392", "#73EC8B", "#FFEB55", "#EE66A6", "#D91656", "#640D5F"],
      ),
      default_colormap="colormap6",
      figure_size=12,
      dpi=600,
  )
  # Default workflow (pipeline step) switches; get_workflow_config() falls
  # back to a copy of this dict
  _DEFAULT_WORKFLOW_CONFIG = {
      'run_crop_model': True,
      'run_effective_model': False,
      'combine_predictions': True,
      'generate_raster': True,
      'create_visualization': True,
      'create_histogram': True,
  }
  def get_workflow_config():
      """
      Return the workflow (pipeline step) configuration.

      If the ``CD_WORKFLOW_CONFIG`` environment variable holds a JSON object,
      that object is returned as-is. Otherwise a fresh copy of
      ``_DEFAULT_WORKFLOW_CONFIG`` is returned. The fallback cases are:
      variable unset or empty, invalid JSON, or JSON that is not an object.
      In the two error cases a warning is logged.

      @returns {dict} workflow configuration dictionary
      """
      import json
      import logging

      env_config = os.environ.get('CD_WORKFLOW_CONFIG')
      if env_config:
          try:
              parsed = json.loads(env_config)
          except json.JSONDecodeError as err:
              logging.getLogger(__name__).warning(
                  "Ignoring CD_WORKFLOW_CONFIG: invalid JSON (%s)", err)
          else:
              # Valid JSON is not necessarily an object ("[1]", "true", "3");
              # only an object matches the documented dict return type.
              if isinstance(parsed, dict):
                  return parsed
              logging.getLogger(__name__).warning(
                  "Ignoring CD_WORKFLOW_CONFIG: expected a JSON object, got %s",
                  type(parsed).__name__)
      # Copy so callers can modify the result without touching the defaults.
      return _DEFAULT_WORKFLOW_CONFIG.copy()
  def _merge_raster_overrides(config, params):
      """
      Apply the non-None entries of ``params`` onto ``config`` in place.

      A key matching a top-level setting replaces it. A key matching an entry
      of the nested ``coordinate_columns`` dict (e.g. ``longitude``) replaces
      that entry. Unknown keys are ignored.
      """
      for key, value in params.items():
          if value is None:  # None means "not provided": keep the current value
              continue
          if key in config:
              config[key] = value
          else:
              # Looked up per key, so a coordinate_columns replacement made
              # earlier in the same params is the one that gets updated.
              columns = config.get("coordinate_columns", {})
              if key in columns:
                  columns[key] = value


  def get_raster_config(override_params=None):
      """
      Build the raster configuration with overrides applied.

      Precedence, highest first:
        1. ``override_params``
        2. the JSON object in ``CD_RASTER_CONFIG_OVERRIDE``
        3. ``RASTER_CONFIG``

      Override keys may name a top-level setting or an entry of
      ``coordinate_columns``. None values and unknown keys are ignored. An
      environment value that is not valid JSON, or not a JSON object, is
      ignored with a logged warning.

      @param override_params: dict of parameters passed in by the API (optional)
      @returns {dict} merged raster configuration. It is always an independent
          copy, so modifying it never affects ``RASTER_CONFIG``.
      """
      import copy
      import json
      import logging

      # Deep copy: a shallow copy would share the nested coordinate_columns
      # dict with RASTER_CONFIG, so column overrides would leak into the
      # module defaults and persist across calls.
      config = copy.deepcopy(RASTER_CONFIG)

      env_override = os.environ.get('CD_RASTER_CONFIG_OVERRIDE')
      if env_override:
          try:
              env_params = json.loads(env_override)
          except json.JSONDecodeError as err:
              logging.getLogger(__name__).warning(
                  "Ignoring CD_RASTER_CONFIG_OVERRIDE: invalid JSON (%s)", err)
          else:
              if isinstance(env_params, dict):
                  _merge_raster_overrides(config, env_params)
              else:
                  logging.getLogger(__name__).warning(
                      "Ignoring CD_RASTER_CONFIG_OVERRIDE: expected a JSON object, got %s",
                      type(env_params).__name__)

      # Directly passed parameters have the highest priority.
      if override_params:
          _merge_raster_overrides(config, override_params)
      return config
  # Kept for backward compatibility: a module-level snapshot of the default
  # workflow configuration (taken from the defaults, not from the environment)
  WORKFLOW_CONFIG = dict(_DEFAULT_WORKFLOW_CONFIG)
  def ensure_directories():
      """Create the model, data and output directories if they are missing.

      Existing directories are left as they are (``exist_ok=True``).
      """
      required = (
          MODELS_DIR,
          *(DATA_PATHS[name] for name in ("coordinates_dir", "predictions_dir", "final_dir")),
          *(OUTPUT_PATHS[name] for name in ("raster_dir", "figures_dir", "reports_dir")),
      )
      for path in required:
          os.makedirs(path, exist_ok=True)