config.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. """
  2. 配置文件 - Cd预测集成系统
  3. Configuration file for Cd Prediction Integrated System
  4. """
  5. import os
  6. # 项目根目录
  7. PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
  8. # 模型相关路径
  9. MODELS_DIR = os.path.join(PROJECT_ROOT, "models")
  10. # 作物Cd模型配置
  11. CROP_CD_MODEL = {
  12. "model_dir": os.path.join(MODELS_DIR, "crop_cd_model"),
  13. "model_file": "cropCdNN.pth",
  14. "mean_file": "cropCd_mean.npy",
  15. "scale_file": "cropCd_scale.npy",
  16. "constraint_module": "constrained_nn6",
  17. "input_data": "areatest.csv",
  18. "output_file": "combined_pH.csv",
  19. "special_feature": "solution"
  20. }
  21. # 有效态Cd模型配置
  22. EFFECTIVE_CD_MODEL = {
  23. "model_dir": os.path.join(MODELS_DIR, "effective_cd_model"),
  24. "model_file": "EffCdNN6C.pth",
  25. "mean_file": "EffCd_mean.npy",
  26. "scale_file": "EffCd_scale.npy",
  27. "constraint_module": "constrained_nn6C",
  28. "input_data": "areatest.csv",
  29. "output_file": "pHcombined.csv",
  30. "special_feature": "Cdsolution"
  31. }
  32. # 数据路径配置
  33. DATA_PATHS = {
  34. "coordinates_dir": os.path.join(PROJECT_ROOT, "data", "coordinates"),
  35. "predictions_dir": os.path.join(PROJECT_ROOT, "data", "predictions"),
  36. "final_dir": os.path.join(PROJECT_ROOT, "data", "final"),
  37. "coordinate_file": "坐标.csv"
  38. }
  39. # 分析模块配置
  40. ANALYSIS_CONFIG = {
  41. "template_tif": os.path.join(PROJECT_ROOT, "output", "raster", "meanTemp.tif"),
  42. "boundary_shp": os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp"),
  43. "output_raster": os.path.join(PROJECT_ROOT, "output", "raster", "output.tif"),
  44. "temp_shapefile": os.path.join(PROJECT_ROOT, "output", "raster", "points666.shp")
  45. }
  46. # 输出路径配置
  47. OUTPUT_PATHS = {
  48. "raster_dir": os.path.join(PROJECT_ROOT, "output", "raster"),
  49. "figures_dir": os.path.join(PROJECT_ROOT, "output", "figures"),
  50. "reports_dir": os.path.join(PROJECT_ROOT, "output", "reports")
  51. }
  52. # 可视化配置
  53. VISUALIZATION_CONFIG = {
  54. "color_maps": {
  55. "colormap1": ['#FFFECE', '#FFF085', '#FEBA17','#BE3D2A','#74512D', '#4E1F00'],
  56. "colormap2": ['#F6F8D5', '#98D2C0', '#4F959D', '#205781', '#143D60','#2A3335'],
  57. "colormap3": ['#FFEFC8', '#F8ED8C', '#D3E671', '#89AC46', '#5F8B4C', '#355F2E'],
  58. "colormap4": ['#F0F1C5', '#BBD8A3', '#6F826A', '#BF9264', '#735557', '#604652'],
  59. "colormap5": ['#FCFAEE', '#FBF3B9', '#FFDCCC', '#FDB7EA', '#B7B1F2', '#8D77AB'],
  60. "colormap6": ['#15B392', '#73EC8B', '#FFEB55', '#EE66A6', '#D91656', '#640D5F']
  61. },
  62. "default_colormap": "colormap6",
  63. "figure_size": 12,
  64. "dpi": 600
  65. }
  66. # 执行流程配置
  67. _DEFAULT_WORKFLOW_CONFIG = {'run_crop_model': True, 'run_effective_model': False, 'combine_predictions': True, 'generate_raster': True, 'create_visualization': True, 'create_histogram': True}
  68. def get_workflow_config():
  69. """
  70. 获取工作流配置
  71. 优先从环境变量获取,如果没有则使用默认配置
  72. @returns {dict} 工作流配置字典
  73. """
  74. import json
  75. env_config = os.environ.get('CD_WORKFLOW_CONFIG')
  76. if env_config:
  77. try:
  78. return json.loads(env_config)
  79. except json.JSONDecodeError:
  80. pass
  81. return _DEFAULT_WORKFLOW_CONFIG.copy()
  82. # 为了向后兼容,保留WORKFLOW_CONFIG变量(使用默认配置)
  83. WORKFLOW_CONFIG = _DEFAULT_WORKFLOW_CONFIG.copy()
  84. def ensure_directories():
  85. """确保所有必要的目录存在"""
  86. directories = [
  87. MODELS_DIR,
  88. DATA_PATHS["coordinates_dir"],
  89. DATA_PATHS["predictions_dir"],
  90. DATA_PATHS["final_dir"],
  91. OUTPUT_PATHS["raster_dir"],
  92. OUTPUT_PATHS["figures_dir"],
  93. OUTPUT_PATHS["reports_dir"]
  94. ]
  95. for directory in directories:
  96. os.makedirs(directory, exist_ok=True)