config.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. """
  2. 配置文件 - Cd预测集成系统
  3. Configuration file for Cd Prediction Integrated System
  4. """
  5. import os
  6. # 栅格处理配置
  7. RASTER_CONFIG = {
  8. "enable_interpolation": True, # 是否启用空间插值
  9. "interpolation_method": "linear", # 插值方法: nearest, linear, cubic
  10. "resolution_factor": 4.0, # 分辨率因子,越大分辨率越高
  11. "field_name": "Prediction", # 预测值字段名
  12. "coordinate_columns": {
  13. "longitude": 0, # 经度列索引
  14. "latitude": 1, # 纬度列索引
  15. "value": 2 # 预测值列索引
  16. }
  17. }
  18. # 项目根目录
  19. PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
  20. # 模型相关路径
  21. MODELS_DIR = os.path.join(PROJECT_ROOT, "models")
  22. # 作物Cd模型配置
  23. CROP_CD_MODEL = {
  24. "model_dir": os.path.join(MODELS_DIR, "crop_cd_model"),
  25. "model_file": "cropCdNN.pth",
  26. "mean_file": "cropCd_mean.npy",
  27. "scale_file": "cropCd_scale.npy",
  28. "constraint_module": "constrained_nn6",
  29. "input_data": "areatest.csv",
  30. "output_file": "combined_pH.csv",
  31. "special_feature": "solution"
  32. }
  33. # 有效态Cd模型配置
  34. EFFECTIVE_CD_MODEL = {
  35. "model_dir": os.path.join(MODELS_DIR, "effective_cd_model"),
  36. "model_file": "EffCdNN6C.pth",
  37. "mean_file": "EffCd_mean.npy",
  38. "scale_file": "EffCd_scale.npy",
  39. "constraint_module": "constrained_nn6C",
  40. "input_data": "areatest.csv",
  41. "output_file": "pHcombined.csv",
  42. "special_feature": "Cdsolution"
  43. }
  44. # 数据路径配置
  45. DATA_PATHS = {
  46. "coordinates_dir": os.path.join(PROJECT_ROOT, "data", "coordinates"),
  47. "predictions_dir": os.path.join(PROJECT_ROOT, "data", "predictions"),
  48. "final_dir": os.path.join(PROJECT_ROOT, "data", "final"),
  49. "coordinate_file": "坐标.csv"
  50. }
  51. # 分析模块配置
  52. ANALYSIS_CONFIG = {
  53. "template_tif": os.path.join(PROJECT_ROOT, "output", "raster", "meanTemp.tif"),
  54. # 允许通过环境变量 CD_BOUNDARY_FILE 覆盖边界文件(支持 .geojson/.shp)
  55. "boundary_shp": os.environ.get('CD_BOUNDARY_FILE', os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp")),
  56. "output_raster": os.path.join(PROJECT_ROOT, "output", "raster", "output.tif"),
  57. "temp_shapefile": os.path.join(PROJECT_ROOT, "output", "raster", "points666.shp")
  58. }
  59. # 输出路径配置
  60. OUTPUT_PATHS = {
  61. "raster_dir": os.path.join(PROJECT_ROOT, "output", "raster"),
  62. "figures_dir": os.path.join(PROJECT_ROOT, "output", "figures"),
  63. "reports_dir": os.path.join(PROJECT_ROOT, "output", "reports")
  64. }
  65. # 可视化配置
  66. VISUALIZATION_CONFIG = {
  67. "color_maps": {
  68. "colormap1": ['#FFFECE', '#FFF085', '#FEBA17','#BE3D2A','#74512D', '#4E1F00'],
  69. "colormap2": ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60','#2A3335'],
  70. "colormap3": ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'],
  71. "colormap4": ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'],
  72. "colormap5": ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'],
  73. "colormap6": ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F']
  74. },
  75. "default_colormap": "colormap6",
  76. "figure_size": 12,
  77. "dpi": 600
  78. }
  79. # 执行流程配置
  80. _DEFAULT_WORKFLOW_CONFIG = {'run_crop_model': True, 'run_effective_model': False, 'combine_predictions': True, 'generate_raster': True, 'create_visualization': True, 'create_histogram': True}
  81. def get_workflow_config():
  82. """
  83. 获取工作流配置
  84. 优先从环境变量获取,如果没有则使用默认配置
  85. @returns {dict} 工作流配置字典
  86. """
  87. import json
  88. env_config = os.environ.get('CD_WORKFLOW_CONFIG')
  89. if env_config:
  90. try:
  91. return json.loads(env_config)
  92. except json.JSONDecodeError:
  93. pass
  94. return _DEFAULT_WORKFLOW_CONFIG.copy()
  95. def get_raster_config(override_params=None):
  96. """
  97. 获取栅格配置
  98. 支持参数覆盖,优先级:直接传参 > 环境变量 > 默认配置
  99. @param override_params: API传递的参数字典
  100. @returns {dict} 合并后的栅格配置字典
  101. """
  102. # 从默认配置开始
  103. config = RASTER_CONFIG.copy()
  104. # 检查环境变量中的覆盖参数
  105. import json
  106. env_override = os.environ.get('CD_RASTER_CONFIG_OVERRIDE')
  107. if env_override:
  108. try:
  109. env_params = json.loads(env_override)
  110. for key, value in env_params.items():
  111. if value is not None: # 只覆盖非None值
  112. if key in config:
  113. config[key] = value
  114. elif key in config.get("coordinate_columns", {}):
  115. config["coordinate_columns"][key] = value
  116. except json.JSONDecodeError:
  117. pass
  118. # 如果有直接传递的覆盖参数,优先级最高
  119. if override_params:
  120. for key, value in override_params.items():
  121. if value is not None: # 只覆盖非None值
  122. if key in config:
  123. config[key] = value
  124. elif key in config.get("coordinate_columns", {}):
  125. config["coordinate_columns"][key] = value
  126. return config
  127. # 为了向后兼容,保留WORKFLOW_CONFIG变量(使用默认配置)
  128. WORKFLOW_CONFIG = _DEFAULT_WORKFLOW_CONFIG.copy()
  129. def ensure_directories():
  130. """确保所有必要的目录存在"""
  131. directories = [
  132. MODELS_DIR,
  133. DATA_PATHS["coordinates_dir"],
  134. DATA_PATHS["predictions_dir"],
  135. DATA_PATHS["final_dir"],
  136. OUTPUT_PATHS["raster_dir"],
  137. OUTPUT_PATHS["figures_dir"],
  138. OUTPUT_PATHS["reports_dir"]
  139. ]
  140. for directory in directories:
  141. os.makedirs(directory, exist_ok=True)