|
- """
- Cd预测引擎 v3.0
- @description: 完全自包含的预测引擎,不依赖外部集成系统
- @version: 3.0.0
- """
- import os
- import logging
- import tempfile
- import shutil
- from datetime import datetime
- from typing import Dict, Any, Optional, Tuple
- import pandas as pd
- import numpy as np
- from .predictors import CropCdPredictor, EffectiveCdPredictor, DataProcessor
- from .config import get_raster_config, get_template_tif_path, VISUALIZATION_CONFIG, ensure_directories
- from ...utils.mapping_utils import dataframe_to_raster_workflow, MappingUtils
class CdPredictionEngine:
    """
    Cd prediction engine v3.0 - fully self-contained version.

    Combines the crop-Cd and effective-Cd predictors, the data processor and
    the mapping utilities, and writes CSV / raster / figure outputs beneath a
    single output base directory. The collaborators are created lazily, on
    first access of the corresponding property.
    """

    def __init__(self, output_base_dir: str):
        """
        Initialise the prediction engine.

        @param {str} output_base_dir - Base directory for all outputs
        """
        self.output_base_dir = output_base_dir
        self.logger = logging.getLogger(__name__)

        # Make sure the output directories exist (delegated to the config module).
        ensure_directories(output_base_dir)

        # Output locations, all relative to the base directory.
        self.output_paths = {
            "figures": os.path.join(output_base_dir, "figures"),
            "raster": os.path.join(output_base_dir, "raster"),
            "data": os.path.join(output_base_dir, "data"),
            "temp": os.path.join(output_base_dir, "data", "temp"),
            "final": os.path.join(output_base_dir, "data", "final")
        }

        # Collaborators are instantiated lazily by the properties below.
        self._crop_predictor = None
        self._effective_predictor = None
        self._data_processor = None
        self._mapping_utils = None

        self.logger.info(f"Cd预测引擎v3.0初始化完成,输出目录: {output_base_dir}")

    @property
    def crop_predictor(self) -> "CropCdPredictor":
        """Return the crop-Cd predictor, creating it on first access."""
        if self._crop_predictor is None:
            self._crop_predictor = CropCdPredictor()
        return self._crop_predictor

    @property
    def effective_predictor(self) -> "EffectiveCdPredictor":
        """Return the effective-Cd predictor, creating it on first access."""
        if self._effective_predictor is None:
            self._effective_predictor = EffectiveCdPredictor()
        return self._effective_predictor

    @property
    def data_processor(self) -> "DataProcessor":
        """Return the data processor, creating it on first access."""
        if self._data_processor is None:
            self._data_processor = DataProcessor()
        return self._data_processor

    @property
    def mapping_utils(self) -> "MappingUtils":
        """Return the mapping utilities, creating them on first access."""
        if self._mapping_utils is None:
            self._mapping_utils = MappingUtils()
        return self._mapping_utils

    def _run_prediction(self, predictor_attr: str, label: str,
                        environmental_data: pd.DataFrame) -> Tuple[np.ndarray, Any]:
        """
        Shared implementation of the two public prediction methods.

        @param {str} predictor_attr - Name of the lazy predictor property to use
        @param {str} label - Model label interpolated into the log messages
        @param {pd.DataFrame} environmental_data - Environmental factor data
        @returns {Tuple} Predictions and whatever validate_final_data returns
        """
        try:
            self.logger.info(f"开始{label}预测...")

            # Resolve the predictor inside the try block so that a failure
            # while lazily constructing it is logged like any other failure.
            predictor = getattr(self, predictor_attr)
            predictions = predictor.predict(environmental_data)

            # Real coordinates are not available at this point; zero-filled
            # placeholder columns are passed to the validator instead.
            placeholder = [0] * len(predictions)
            temp_df = pd.DataFrame({
                'longitude': placeholder,
                'latitude': list(placeholder),
                'Prediction': predictions
            })
            validation_result = self.data_processor.validate_final_data(temp_df)

            self.logger.info(f"{label}预测完成")
            return predictions, validation_result

        except Exception as e:
            self.logger.error(f"{label}预测失败: {str(e)}")
            raise

    def predict_crop_cd(self, environmental_data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
        """
        Run the crop-Cd prediction.

        @param {pd.DataFrame} environmental_data - Environmental factor data
        @returns {Tuple[np.ndarray, pd.DataFrame]} Predictions and validation info
        """
        return self._run_prediction("crop_predictor", "作物Cd", environmental_data)

    def predict_effective_cd(self, environmental_data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
        """
        Run the effective-Cd prediction.

        @param {pd.DataFrame} environmental_data - Environmental factor data
        @returns {Tuple[np.ndarray, pd.DataFrame]} Predictions and validation info
        """
        return self._run_prediction("effective_predictor", "有效态Cd", environmental_data)

    def _save_final_data(self, final_data: pd.DataFrame, model_type: str) -> str:
        """
        Write the final data to a timestamped CSV in the "final" directory.

        @param {pd.DataFrame} final_data - Coordinates merged with predictions
        @param {str} model_type - Model type, embedded in the file name
        @returns {str} Path of the CSV file that was written
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"Final_predictions_{model_type}_{timestamp}.csv"
        final_path = os.path.join(self.output_paths["final"], filename)
        final_data.to_csv(final_path, index=False, encoding='utf-8-sig')
        return final_path

    def create_final_dataset(self, coordinates: pd.DataFrame, predictions: np.ndarray,
                             model_type: str) -> str:
        """
        Merge coordinates with predictions and save the result as CSV.

        @param {pd.DataFrame} coordinates - Coordinate data
        @param {np.ndarray} predictions - Prediction results
        @param {str} model_type - Model type
        @returns {str} Path of the final data file
        """
        try:
            final_data = self.data_processor.combine_predictions_with_coordinates(
                coordinates, predictions
            )
            final_path = self._save_final_data(final_data, model_type)

            self.logger.info(f"最终数据集已保存: {final_path}")
            return final_path

        except Exception as e:
            self.logger.error(f"创建最终数据集失败: {str(e)}")
            raise

    def _generate_raster(self, final_data_df: pd.DataFrame, output_dir: str,
                         boundary_gdf, raster_config: Dict[str, Any]) -> str:
        """
        Run the DataFrame-to-raster workflow and return the raster path.

        @param {pd.DataFrame} final_data_df - Final data DataFrame
        @param {str} output_dir - Directory the workflow writes into
        @param boundary_gdf - Boundary GeoDataFrame (may be None)
        @param {Dict[str, Any]} raster_config - Resolved raster configuration
        @returns {str} The 'raster' entry of the workflow result
        """
        workflow_result = dataframe_to_raster_workflow(
            df=final_data_df,
            template_tif=get_template_tif_path(),
            output_dir=output_dir,
            boundary_gdf=boundary_gdf,
            resolution_factor=raster_config['resolution_factor'],
            interpolation_method=raster_config['interpolation_method'],
            field_name=raster_config['field_name'],
            lon_col=raster_config['coordinate_columns']['longitude'],
            lat_col=raster_config['coordinate_columns']['latitude'],
            value_col=raster_config['coordinate_columns']['value'],
            enable_interpolation=raster_config['enable_interpolation']
        )
        return workflow_result['raster']

    def create_visualization(self, final_data_df: pd.DataFrame, model_type: str,
                             county_name: str, boundary_gdf=None,
                             raster_config_override: Optional[Dict[str, Any]] = None,
                             save_raster: bool = False) -> Dict[str, str]:
        """
        Create the map and histogram figures for a set of predictions.

        When save_raster is False the intermediate raster is written to a
        temporary directory that is always removed before this method
        returns or raises.

        @param {pd.DataFrame} final_data_df - Final data DataFrame
        @param {str} model_type - Model type
        @param {str} county_name - County / city name
        @param boundary_gdf - Boundary GeoDataFrame
        @param {Optional[Dict[str, Any]]} raster_config_override - Raster config override
        @param {bool} save_raster - Whether to keep the raster file, default False
        @returns {Dict[str, str]} Keys 'raster' (None unless save_raster), 'map', 'histogram'
        """
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

            # Resolve the raster configuration (defaults plus optional override).
            raster_config = get_raster_config(raster_config_override)

            # Set only on the temporary-raster path; drives cleanup in `finally`.
            temp_dir = None
            try:
                if save_raster:
                    self.logger.info("开始生成栅格文件(保存到磁盘)...")
                    raster_path = self._generate_raster(
                        final_data_df, self.output_paths["raster"], boundary_gdf, raster_config
                    )

                    # Rename the raster to the model/county/timestamp scheme.
                    final_raster_name = f"prediction_{model_type}_{county_name}_{timestamp}.tif"
                    final_raster_path = os.path.join(self.output_paths["raster"], final_raster_name)
                    if raster_path != final_raster_path:
                        shutil.move(raster_path, final_raster_path)
                else:
                    self.logger.info("生成临时栅格数据(仅用于可视化,不保存文件)...")
                    # The raster is only needed to render the figures.
                    temp_dir = tempfile.mkdtemp()
                    final_raster_path = self._generate_raster(
                        final_data_df, temp_dir, boundary_gdf, raster_config
                    )

                # Render the map.
                self.logger.info("开始生成地图可视化...")
                map_title = self._get_map_title(model_type)
                # No file extension here; it is not part of the name passed on.
                map_filename = f"prediction_map_{model_type}_{county_name}_{timestamp}"
                map_path = os.path.join(self.output_paths["figures"], map_filename)

                map_result = self.mapping_utils.create_raster_map(
                    shp_path=None,  # no shapefile path is used
                    tif_path=final_raster_path,
                    output_path=map_path,
                    title=map_title,
                    colormap=VISUALIZATION_CONFIG['default_colormap'],
                    figsize=VISUALIZATION_CONFIG['figure_size'],
                    dpi=VISUALIZATION_CONFIG['dpi'],
                    resolution_factor=1.0,
                    enable_interpolation=False,
                    interpolation_method='nearest',
                    boundary_gdf=boundary_gdf  # boundary supplied as a GeoDataFrame
                )

                # Render the histogram.
                self.logger.info("开始生成直方图...")
                hist_title, hist_xlabel = self._get_histogram_labels(model_type)
                hist_filename = f"prediction_histogram_{model_type}_{county_name}_{timestamp}.jpg"
                hist_path = os.path.join(self.output_paths["figures"], hist_filename)

                hist_result = self.mapping_utils.create_histogram(
                    file_path=final_raster_path,
                    save_path=hist_path,
                    figsize=(6, 6),
                    xlabel=hist_xlabel,
                    ylabel='Frequency',
                    title=hist_title,
                    dpi=VISUALIZATION_CONFIG['dpi']
                )
            finally:
                # Remove the temporary raster on success AND on failure, so a
                # failing map/histogram step cannot leave the directory behind.
                if temp_dir is not None:
                    shutil.rmtree(temp_dir, ignore_errors=True)
                    self.logger.info("临时栅格文件已清理")

            result = {
                # A temporary raster no longer exists, so its path is not returned.
                'raster': final_raster_path if save_raster else None,
                'map': map_result,
                'histogram': hist_result
            }

            self.logger.info("可视化创建完成")
            return result

        except Exception as e:
            self.logger.error(f"创建可视化失败: {str(e)}")
            raise

    def predict_and_visualize(self, input_data: pd.DataFrame, model_type: str,
                              county_name: str, boundary_gdf=None,
                              raster_config_override: Optional[Dict[str, Any]] = None,
                              save_raster: bool = False) -> Dict[str, Any]:
        """
        Run the complete prediction and visualization pipeline.

        @param {pd.DataFrame} input_data - Input data (first two columns are
            longitude/latitude, remaining columns are environmental factors)
        @param {str} model_type - Model type ("crop_cd" or "effective_cd")
        @param {str} county_name - County / city name
        @param boundary_gdf - Boundary GeoDataFrame (optional)
        @param {Optional[Dict[str, Any]]} raster_config_override - Raster config override
        @param {bool} save_raster - Whether to keep the raster file, default False
            (only the map and histogram are produced)
        @returns {Dict[str, Any]} Complete result
        @raises {ValueError} If model_type is not a supported model
        """
        try:
            self.logger.info(f"开始{model_type}模型的完整预测流程(使用统一绘图接口)...")

            # Split the input positionally into coordinates and factors.
            coordinates = input_data.iloc[:, :2].copy()
            coordinates.columns = ['longitude', 'latitude']
            environmental_data = input_data.iloc[:, 2:].copy()

            # Run the requested model.
            if model_type == "crop_cd":
                predictions, validation = self.predict_crop_cd(environmental_data)
            elif model_type == "effective_cd":
                predictions, validation = self.predict_effective_cd(environmental_data)
            else:
                raise ValueError(f"不支持的模型类型: {model_type}")

            # Merge coordinates and predictions into the final DataFrame.
            final_data_df = self.data_processor.combine_predictions_with_coordinates(
                coordinates, predictions
            )

            # Persist the final data as CSV (kept for compatibility).
            final_data_file = self._save_final_data(final_data_df, model_type)

            # Build the figures straight from the DataFrame.
            visualization_result = self.create_visualization(
                final_data_df, model_type, county_name, boundary_gdf, raster_config_override, save_raster
            )

            result = {
                'model_type': model_type,
                'county_name': county_name,
                'final_data_file': final_data_file,
                'final_data_df': final_data_df,
                'raster_path': visualization_result['raster'],
                'map_path': visualization_result['map'],
                'histogram_path': visualization_result['histogram'],
                'validation': validation,
                'timestamp': datetime.now().isoformat()
            }

            self.logger.info(f"{model_type}模型完整预测流程完成")
            return result

        except Exception as e:
            self.logger.error(f"{model_type}模型完整预测流程失败: {str(e)}")
            raise

    def _get_map_title(self, model_type: str) -> str:
        """Return the map title for a model type, with a generic fallback."""
        titles = {
            "crop_cd": "Crop Cd Prediction",
            "effective_cd": "Effective Cd Prediction"
        }
        return titles.get(model_type, f"{model_type} Prediction")

    def _get_histogram_labels(self, model_type: str) -> Tuple[str, str]:
        """Return (title, x-axis label) for the histogram, with a generic fallback."""
        labels = {
            "crop_cd": ("Crop Cd Prediction Frequency", "Crop Cd Content (mg/kg)"),
            "effective_cd": ("Effective Cd Prediction Frequency", "Effective Cd Content (mg/kg)")
        }
        return labels.get(model_type, (f"{model_type} Prediction Frequency", f"{model_type} Content"))

    def cleanup_temp_files(self):
        """
        Delete the regular files directly inside the "temp" output directory.

        Sub-directories are left in place. This is best-effort: any failure
        is logged as a warning and not propagated.
        """
        try:
            temp_dir = self.output_paths["temp"]
            if os.path.exists(temp_dir):
                for entry_name in os.listdir(temp_dir):
                    entry_path = os.path.join(temp_dir, entry_name)
                    if os.path.isfile(entry_path):
                        os.remove(entry_path)
                        self.logger.debug(f"已删除临时文件: {entry_name}")

            self.logger.info("临时文件清理完成")

        except Exception as e:
            self.logger.warning(f"清理临时文件失败: {str(e)}")

    def get_model_info(self) -> Dict[str, Any]:
        """
        Collect version, output path and model/template file information.

        @returns {Dict[str, Any]} Model information
        """
        from .config import validate_model_files

        info = {
            "version": "3.0.0",
            "output_base_dir": self.output_base_dir,
            "output_paths": self.output_paths,
            "crop_cd_files": validate_model_files("crop_cd"),
            "effective_cd_files": validate_model_files("effective_cd"),
        }

        # Resolve the template path once and reuse it for the existence check.
        template_tif = get_template_tif_path()
        info["template_tif"] = template_tif
        info["template_tif_exists"] = os.path.exists(template_tif)
        return info
|