|
@@ -3,6 +3,7 @@
|
|
|
Main execution script for Cd Prediction Integrated System
|
|
|
|
|
|
Description: 集成运行作物Cd模型、有效态Cd模型和数据分析可视化的完整流程
|
|
|
+架构简化版本:直接调用通用绘图模块,删除中间封装层
|
|
|
"""
|
|
|
|
|
|
import os
|
|
@@ -12,15 +13,18 @@ from datetime import datetime
|
|
|
|
|
|
# 添加项目根目录到Python路径
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
+# 添加上级目录到路径以访问通用绘图模块
|
|
|
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
|
import config
|
|
|
from models.crop_cd_model.predict import CropCdPredictor
|
|
|
from models.effective_cd_model.predict import EffectiveCdPredictor
|
|
|
from analysis.data_processing import DataProcessor
|
|
|
-from analysis.mapping import RasterMapper
|
|
|
-from analysis.visualization import Visualizer
|
|
|
from utils.common import setup_logging, validate_files
|
|
|
|
|
|
+# 直接导入通用绘图模块
|
|
|
+from app.utils.mapping_utils import csv_to_raster_workflow, MappingUtils
|
|
|
+
|
|
|
def main():
|
|
|
"""
|
|
|
主执行函数
|
|
@@ -33,7 +37,7 @@ def main():
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
logger.info("=" * 60)
|
|
|
- logger.info("开始执行Cd预测集成系统")
|
|
|
+ logger.info("开始执行Cd预测集成系统 (架构简化版本)")
|
|
|
logger.info("=" * 60)
|
|
|
|
|
|
try:
|
|
@@ -62,6 +66,9 @@ def main():
|
|
|
final_data_files = data_processor.combine_predictions_with_coordinates()
|
|
|
logger.info(f"数据整合完成,生成文件: {list(final_data_files.keys())}")
|
|
|
|
|
|
+ # 初始化通用绘图工具(用于可视化)
|
|
|
+ mapping_utils = MappingUtils()
|
|
|
+
|
|
|
# 步骤4: 为每个模型生成栅格文件和可视化图
|
|
|
model_outputs = {}
|
|
|
|
|
@@ -92,32 +99,65 @@ def main():
|
|
|
|
|
|
model_outputs[model_name] = {}
|
|
|
|
|
|
- # 步骤4a: 生成栅格文件
|
|
|
+ # 步骤4a: 生成栅格文件(直接调用通用绘图模块)
|
|
|
if config.WORKFLOW_CONFIG["generate_raster"]:
|
|
|
logger.info(f"步骤4a: 为{display_name}模型生成栅格文件...")
|
|
|
- raster_mapper = RasterMapper()
|
|
|
|
|
|
- # 为每个模型创建独立的输出路径
|
|
|
- output_raster_path = os.path.join(
|
|
|
- config.OUTPUT_PATHS["raster_dir"],
|
|
|
- f"output_{model_name}.tif"
|
|
|
- )
|
|
|
-
|
|
|
- output_raster = raster_mapper.csv_to_raster(
|
|
|
- final_data_file,
|
|
|
- output_raster=output_raster_path,
|
|
|
- output_shp=os.path.join(
|
|
|
+ # 使用csv_to_raster_workflow直接生成栅格
|
|
|
+ try:
|
|
|
+ workflow_result = csv_to_raster_workflow(
|
|
|
+ csv_file=final_data_file,
|
|
|
+ template_tif=config.ANALYSIS_CONFIG["template_tif"],
|
|
|
+ output_dir=config.OUTPUT_PATHS["raster_dir"],
|
|
|
+ boundary_shp=config.ANALYSIS_CONFIG.get("boundary_shp"),
|
|
|
+ resolution_factor=1.0, # 高分辨率栅格生成
|
|
|
+ interpolation_method='nearest',
|
|
|
+ field_name='Prediction',
|
|
|
+ lon_col=0,
|
|
|
+ lat_col=1,
|
|
|
+ value_col=2,
|
|
|
+ enable_interpolation=False # 禁用空间插值
|
|
|
+ )
|
|
|
+
|
|
|
+ output_raster = workflow_result['raster']
|
|
|
+
|
|
|
+ # 重命名为特定模型的文件
|
|
|
+ model_raster_path = os.path.join(
|
|
|
config.OUTPUT_PATHS["raster_dir"],
|
|
|
- f"points_{model_name}.shp"
|
|
|
+ f"output_{model_name}.tif"
|
|
|
)
|
|
|
- )
|
|
|
- model_outputs[model_name]['raster'] = output_raster
|
|
|
- logger.info(f"{display_name}模型栅格文件生成完成: {output_raster}")
|
|
|
+
|
|
|
+ if os.path.exists(output_raster) and output_raster != model_raster_path:
|
|
|
+ import shutil
|
|
|
+ shutil.move(output_raster, model_raster_path)
|
|
|
+ output_raster = model_raster_path
|
|
|
+
|
|
|
+ model_outputs[model_name]['raster'] = output_raster
|
|
|
+ logger.info(f"{display_name}模型栅格文件生成完成: {output_raster}")
|
|
|
+
|
|
|
+ # 清理中间shapefile文件
|
|
|
+ try:
|
|
|
+ shapefile_path = workflow_result.get('shapefile')
|
|
|
+ if shapefile_path and os.path.exists(shapefile_path):
|
|
|
+ # 删除shapefile及其相关文件
|
|
|
+ import glob
|
|
|
+ base_path = os.path.splitext(shapefile_path)[0]
|
|
|
+ for ext in ['.shp', '.shx', '.dbf', '.prj', '.cpg']:
|
|
|
+ file_to_delete = base_path + ext
|
|
|
+ if os.path.exists(file_to_delete):
|
|
|
+ os.remove(file_to_delete)
|
|
|
+ logger.debug(f"已删除中间文件: {file_to_delete}")
|
|
|
+ logger.info(f"已清理中间shapefile文件: {shapefile_path}")
|
|
|
+ except Exception as cleanup_error:
|
|
|
+ logger.warning(f"清理中间文件失败: {str(cleanup_error)}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"栅格生成失败: {str(e)}")
|
|
|
+ raise
|
|
|
|
|
|
- # 步骤4b: 创建可视化地图
|
|
|
+ # 步骤4b: 创建可视化地图(直接调用通用绘图模块)
|
|
|
if config.WORKFLOW_CONFIG["create_visualization"]:
|
|
|
logger.info(f"步骤4b: 为{display_name}模型创建可视化地图...")
|
|
|
- visualizer = Visualizer()
|
|
|
|
|
|
# 为每个模型创建独立的地图输出路径
|
|
|
map_output_path = os.path.join(
|
|
@@ -125,15 +165,29 @@ def main():
|
|
|
f"Prediction_results_{model_name}"
|
|
|
)
|
|
|
|
|
|
- map_output = visualizer.create_raster_map(
|
|
|
- tif_path=output_raster,
|
|
|
- title_name=title_name,
|
|
|
- output_path=map_output_path
|
|
|
- )
|
|
|
- model_outputs[model_name]['map'] = map_output
|
|
|
- logger.info(f"{display_name}模型地图可视化完成: {map_output}")
|
|
|
+ try:
|
|
|
+ # 直接调用通用绘图模块的create_raster_map
|
|
|
+ map_output = mapping_utils.create_raster_map(
|
|
|
+ shp_path=config.ANALYSIS_CONFIG.get("boundary_shp"),
|
|
|
+ tif_path=output_raster,
|
|
|
+ output_path=map_output_path,
|
|
|
+ title=title_name,
|
|
|
+ colormap=config.VISUALIZATION_CONFIG["color_maps"][config.VISUALIZATION_CONFIG["default_colormap"]],
|
|
|
+ output_size=config.VISUALIZATION_CONFIG["figure_size"],
|
|
|
+ dpi=config.VISUALIZATION_CONFIG["dpi"],
|
|
|
+ resolution_factor=1.0, # 保持栅格原始分辨率
|
|
|
+ enable_interpolation=False, # 不在可视化阶段插值
|
|
|
+ interpolation_method='nearest'
|
|
|
+ )
|
|
|
+
|
|
|
+ model_outputs[model_name]['map'] = map_output
|
|
|
+ logger.info(f"{display_name}模型地图可视化完成: {map_output}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"地图可视化失败: {str(e)}")
|
|
|
+ raise
|
|
|
|
|
|
- # 步骤4c: 创建直方图
|
|
|
+ # 步骤4c: 创建直方图(直接调用通用绘图模块)
|
|
|
if config.WORKFLOW_CONFIG["create_histogram"]:
|
|
|
logger.info(f"步骤4c: 为{display_name}模型创建预测值分布直方图...")
|
|
|
|
|
@@ -154,14 +208,24 @@ def main():
|
|
|
hist_title = f'{model_name} Prediction Frequency'
|
|
|
hist_xlabel = f'{model_name} Content'
|
|
|
|
|
|
- histogram_output = visualizer.create_histogram(
|
|
|
- file_path=output_raster,
|
|
|
- title=hist_title,
|
|
|
- xlabel=hist_xlabel,
|
|
|
- save_path=histogram_output_path
|
|
|
- )
|
|
|
- model_outputs[model_name]['histogram'] = histogram_output
|
|
|
- logger.info(f"{display_name}模型直方图创建完成: {histogram_output}")
|
|
|
+ try:
|
|
|
+ # 直接调用通用绘图模块的create_histogram
|
|
|
+ histogram_output = mapping_utils.create_histogram(
|
|
|
+ file_path=output_raster,
|
|
|
+ save_path=histogram_output_path,
|
|
|
+ figsize=(6, 6),
|
|
|
+ xlabel=hist_xlabel,
|
|
|
+ ylabel='Frequency',
|
|
|
+ title=hist_title,
|
|
|
+ dpi=config.VISUALIZATION_CONFIG["dpi"]
|
|
|
+ )
|
|
|
+
|
|
|
+ model_outputs[model_name]['histogram'] = histogram_output
|
|
|
+ logger.info(f"{display_name}模型直方图创建完成: {histogram_output}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"直方图生成失败: {str(e)}")
|
|
|
+ raise
|
|
|
|
|
|
logger.info("=" * 60)
|
|
|
logger.info("所有流程执行完成!")
|
|
@@ -178,7 +242,7 @@ def main():
|
|
|
def print_summary(model_outputs):
|
|
|
"""打印执行结果摘要"""
|
|
|
print("\n" + "=" * 60)
|
|
|
- print("[SUCCESS] Cd预测集成系统执行完成!")
|
|
|
+ print("[SUCCESS] Cd预测集成系统执行完成! (架构简化版本)")
|
|
|
print("=" * 60)
|
|
|
print(f"[INFO] 总输出目录: {config.OUTPUT_PATHS}")
|
|
|
print()
|
|
@@ -206,4 +270,4 @@ def print_summary(model_outputs):
|
|
|
print("=" * 60)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- main()
|
|
|
+ main()
|