""" Cd预测器模块 @description: 自包含的预测器,直接加载和运行模型 @version: 3.0.0 """ import os import sys import logging import numpy as np import pandas as pd import torch import torch.nn as nn from sklearn.preprocessing import StandardScaler from typing import Dict, Any, Optional, Tuple import importlib.util from .config import get_model_config, validate_model_files class CropCdPredictor: """ 作物Cd模型预测器 """ def __init__(self): """ 初始化预测器 """ self.logger = logging.getLogger(__name__) self.model_config = get_model_config("crop_cd") self.model = None self.scaler = None self._validate_files() def _validate_files(self): """ 验证模型文件是否存在 """ validation_result = validate_model_files("crop_cd") missing_files = [key for key, exists in validation_result.items() if not exists] if missing_files: raise FileNotFoundError(f"作物Cd模型文件缺失: {missing_files}") self.logger.info("作物Cd模型文件验证通过") def _load_constraint_module(self): """ 动态加载约束模块 """ try: constraint_module_path = self.model_config["constraint_module"] # 动态导入约束模块 spec = importlib.util.spec_from_file_location("constrained_nn6", constraint_module_path) constrained_nn6 = importlib.util.module_from_spec(spec) sys.modules['constrained_nn6'] = constrained_nn6 spec.loader.exec_module(constrained_nn6) self.logger.info("作物Cd约束模块加载成功") return constrained_nn6 except Exception as e: self.logger.error(f"作物Cd约束模块加载失败: {str(e)}") raise def load_model(self): """ 加载训练好的模型和标准化参数 """ if self.model is not None: return # 已经加载过了 try: # 加载约束模块 self._load_constraint_module() # 加载模型 model_path = self.model_config["model_file"] self.model = torch.load(model_path, map_location='cpu') self.model.eval() self.logger.info(f"作物Cd模型加载成功: {model_path}") # 加载标准化参数 mean_path = self.model_config["mean_file"] scale_path = self.model_config["scale_file"] self.scaler = StandardScaler() self.scaler.mean_ = np.load(mean_path) self.scaler.scale_ = np.load(scale_path) self.logger.info("作物Cd标准化参数加载成功") except Exception as e: self.logger.error(f"作物Cd模型加载失败: {str(e)}") raise def predict(self, data: pd.DataFrame) -> np.ndarray: """ 执行预测 @param {pd.DataFrame} data - 输入数据(环境因子) @returns {np.ndarray} 预测结果 """ try: self.logger.info("开始作物Cd模型预测...") # 确保模型已加载 self.load_model() # 获取特殊特征索引 special_feature_index = self._get_special_feature_index(data) # 数据预处理 X = data.values X_scaled = self.scaler.transform(X) X_tensor = torch.tensor(X_scaled, dtype=torch.float32) # 模型预测 with torch.no_grad(): predictions = self.model( X_tensor, X_tensor, special_feature_index ).squeeze().numpy() # 指数变换 predictions_exp = np.exp(predictions) self.logger.info(f"作物Cd预测完成,数据点数: {len(predictions_exp)}") self.logger.info(f"预测统计 - 均值: {np.mean(predictions_exp):.4f}, 标准差: {np.std(predictions_exp):.4f}") return predictions_exp except Exception as e: self.logger.error(f"作物Cd预测失败: {str(e)}") raise def _get_special_feature_index(self, data: pd.DataFrame) -> int: """ 获取特殊特征的索引 @param {pd.DataFrame} data - 输入数据 @returns {int} 特殊特征索引 """ special_feature_name = self.model_config["special_feature_name"] # 首先尝试通过列名查找 if special_feature_name in data.columns: index = data.columns.get_loc(special_feature_name) self.logger.info(f"找到特殊特征列 '{special_feature_name}',索引: {index}") return index # 如果找不到,使用默认索引 default_index = self.model_config["special_feature_index"] if default_index < data.shape[1]: self.logger.warning(f"未找到特殊特征列 '{special_feature_name}',使用默认索引: {default_index}") return default_index # 如果默认索引也超出范围,使用最后一列 last_index = data.shape[1] - 1 self.logger.warning(f"默认索引超出范围,使用最后一列索引: {last_index}") return last_index class EffectiveCdPredictor: """ 有效态Cd模型预测器 """ def __init__(self): """ 初始化预测器 """ self.logger = logging.getLogger(__name__) self.model_config = get_model_config("effective_cd") self.model = None self.scaler = None self._validate_files() def _validate_files(self): """ 验证模型文件是否存在 """ validation_result = validate_model_files("effective_cd") missing_files = [key for key, exists in validation_result.items() if not exists] if missing_files: raise FileNotFoundError(f"有效态Cd模型文件缺失: {missing_files}") self.logger.info("有效态Cd模型文件验证通过") def _load_constraint_module(self): """ 动态加载约束模块 """ try: constraint_module_path = self.model_config["constraint_module"] # 动态导入约束模块 spec = importlib.util.spec_from_file_location("constrained_nn6C", constraint_module_path) constrained_nn6C = importlib.util.module_from_spec(spec) sys.modules['constrained_nn6C'] = constrained_nn6C spec.loader.exec_module(constrained_nn6C) self.logger.info("有效态Cd约束模块加载成功") return constrained_nn6C except Exception as e: self.logger.error(f"有效态Cd约束模块加载失败: {str(e)}") raise def load_model(self): """ 加载训练好的模型和标准化参数 """ if self.model is not None: return # 已经加载过了 try: # 加载约束模块 self._load_constraint_module() # 加载模型 model_path = self.model_config["model_file"] self.model = torch.load(model_path, map_location='cpu') self.model.eval() self.logger.info(f"有效态Cd模型加载成功: {model_path}") # 加载标准化参数 mean_path = self.model_config["mean_file"] scale_path = self.model_config["scale_file"] self.scaler = StandardScaler() self.scaler.mean_ = np.load(mean_path) self.scaler.scale_ = np.load(scale_path) self.logger.info("有效态Cd标准化参数加载成功") except Exception as e: self.logger.error(f"有效态Cd模型加载失败: {str(e)}") raise def predict(self, data: pd.DataFrame) -> np.ndarray: """ 执行预测 @param {pd.DataFrame} data - 输入数据(环境因子) @returns {np.ndarray} 预测结果 """ try: self.logger.info("开始有效态Cd模型预测...") # 确保模型已加载 self.load_model() # 检查数据列数(有效态Cd模型通常需要21列特征) if data.shape[1] < 21: self.logger.warning(f"有效态Cd模型期望21列特征,当前只有{data.shape[1]}列") elif data.shape[1] > 21: self.logger.info(f"输入数据有{data.shape[1]}列,取前21列用于有效态Cd预测") data = data.iloc[:, :21] # 获取特殊特征索引 special_feature_index = self._get_special_feature_index(data) # 数据预处理 X = data.values X_scaled = self.scaler.transform(X) X_tensor = torch.tensor(X_scaled, dtype=torch.float32) # 模型预测 with torch.no_grad(): predictions = self.model( X_tensor, X_tensor, special_feature_index ).squeeze().numpy() # 指数变换 predictions_exp = np.exp(predictions) self.logger.info(f"有效态Cd预测完成,数据点数: {len(predictions_exp)}") self.logger.info(f"预测统计 - 均值: {np.mean(predictions_exp):.4f}, 标准差: {np.std(predictions_exp):.4f}") return predictions_exp except Exception as e: self.logger.error(f"有效态Cd预测失败: {str(e)}") raise def _get_special_feature_index(self, data: pd.DataFrame) -> int: """ 获取特殊特征的索引 @param {pd.DataFrame} data - 输入数据 @returns {int} 特殊特征索引 """ special_feature_name = self.model_config["special_feature_name"] # 首先尝试通过列名查找 if special_feature_name in data.columns: index = data.columns.get_loc(special_feature_name) self.logger.info(f"找到特殊特征列 '{special_feature_name}',索引: {index}") return index # 如果找不到,使用默认索引 default_index = self.model_config["special_feature_index"] if default_index < data.shape[1]: self.logger.warning(f"未找到特殊特征列 '{special_feature_name}',使用默认索引: {default_index}") return default_index # 如果默认索引也超出范围,使用最后一列 last_index = data.shape[1] - 1 self.logger.warning(f"默认索引超出范围,使用最后一列索引: {last_index}") return last_index class DataProcessor: """ 数据处理器 - 处理预测结果与坐标的合并 """ def __init__(self): """ 初始化数据处理器 """ self.logger = logging.getLogger(__name__) def combine_predictions_with_coordinates(self, coordinates: pd.DataFrame, predictions: np.ndarray) -> pd.DataFrame: """ 将预测结果与坐标数据合并 @param {pd.DataFrame} coordinates - 坐标数据(包含longitude, latitude列) @param {np.ndarray} predictions - 预测结果 @returns {pd.DataFrame} 合并后的数据 """ try: if len(coordinates) != len(predictions): raise ValueError(f"坐标数据点数({len(coordinates)})与预测结果数量({len(predictions)})不匹配") # 创建结果数据框 result_df = coordinates[['longitude', 'latitude']].copy() result_df['Prediction'] = predictions self.logger.info(f"预测结果与坐标合并完成,数据形状: {result_df.shape}") return result_df except Exception as e: self.logger.error(f"数据合并失败: {str(e)}") raise def validate_final_data(self, data: pd.DataFrame) -> Dict[str, Any]: """ 验证最终数据的格式和内容 @param {pd.DataFrame} data - 最终数据 @returns {Dict[str, Any]} 验证结果 """ try: validation_result = { "valid": True, "errors": [], "warnings": [], "statistics": {} } # 检查必要的列 required_columns = ['longitude', 'latitude', 'Prediction'] missing_columns = [col for col in required_columns if col not in data.columns] if missing_columns: validation_result["valid"] = False validation_result["errors"].append(f"缺少必要的列: {missing_columns}") # 检查数据完整性 if data.isnull().any().any(): validation_result["warnings"].append("数据中存在空值") # 检查坐标范围 if 'longitude' in data.columns and 'latitude' in data.columns: if not (data['longitude'].between(-180, 180).all() and data['latitude'].between(-90, 90).all()): validation_result["warnings"].append("坐标值超出合理范围") # 计算预测值统计信息 if 'Prediction' in data.columns: pred_series = data['Prediction'] validation_result["statistics"] = { "count": len(pred_series), "mean": float(pred_series.mean()), "std": float(pred_series.std()), "min": float(pred_series.min()), "max": float(pred_series.max()), "median": float(pred_series.median()) } if validation_result["valid"]: self.logger.info("数据验证通过") else: self.logger.error(f"数据验证失败: {validation_result['errors']}") return validation_result except Exception as e: self.logger.error(f"数据验证失败: {str(e)}") return { "valid": False, "errors": [f"验证过程出错: {str(e)}"], "warnings": [], "statistics": {} }