- """
- Cd predictor module
- @description: Self-contained predictor that loads and runs the models directly
- @version: 3.0.0
- """
- import sys
- import logging
- import importlib.util
- from typing import Dict, Any
-
- import numpy as np
- import pandas as pd
- import torch
- from sklearn.preprocessing import StandardScaler
-
- from .config import get_model_config, validate_model_files
- class CropCdPredictor:
- """
- Crop Cd model predictor
- """
-
- def __init__(self):
- """
- Initialize the predictor
- """
- self.logger = logging.getLogger(__name__)
- self.model_config = get_model_config("crop_cd")
- self.model = None
- self.scaler = None
- self._validate_files()
-
- def _validate_files(self):
- """
- Verify that the required model files exist
- """
- validation_result = validate_model_files("crop_cd")
- missing_files = [key for key, exists in validation_result.items() if not exists]
-
- if missing_files:
- raise FileNotFoundError(f"作物Cd模型文件缺失: {missing_files}")
-
- self.logger.info("作物Cd模型文件验证通过")
-
- def _load_constraint_module(self):
- """
- Dynamically load the constraint module
- """
- try:
- constraint_module_path = self.model_config["constraint_module"]
-
- # Dynamically import the constraint module and register it in sys.modules
- # so that the pickled model class can be resolved when the model is loaded
- spec = importlib.util.spec_from_file_location("constrained_nn6", constraint_module_path)
- constrained_nn6 = importlib.util.module_from_spec(spec)
- sys.modules['constrained_nn6'] = constrained_nn6
- spec.loader.exec_module(constrained_nn6)
-
- self.logger.info("作物Cd约束模块加载成功")
- return constrained_nn6
-
- except Exception as e:
- self.logger.error(f"作物Cd约束模块加载失败: {str(e)}")
- raise
-
- def load_model(self):
- """
- Load the trained model and the standardization parameters
- """
- if self.model is not None:
- return  # already loaded
-
- try:
- # Load the constraint module first so the pickled model class can be resolved
- self._load_constraint_module()
-
- # Load the full serialized model object
- model_path = self.model_config["model_file"]
- self.model = torch.load(model_path, map_location='cpu')
- self.model.eval()
- self.logger.info(f"Crop Cd model loaded: {model_path}")
-
- # Load the standardization parameters and rebuild the fitted scaler
- # (transform only needs mean_ and scale_ for a default StandardScaler)
- mean_path = self.model_config["mean_file"]
- scale_path = self.model_config["scale_file"]
-
- self.scaler = StandardScaler()
- self.scaler.mean_ = np.load(mean_path)
- self.scaler.scale_ = np.load(scale_path)
- self.logger.info("Crop Cd standardization parameters loaded")
-
- except Exception as e:
- self.logger.error(f"Crop Cd model loading failed: {str(e)}")
- raise
-
- def predict(self, data: pd.DataFrame) -> np.ndarray:
- """
- Run prediction
-
- @param {pd.DataFrame} data - Input data (environmental factors)
- @returns {np.ndarray} Predicted values
- """
- try:
- self.logger.info("开始作物Cd模型预测...")
-
- # 确保模型已加载
- self.load_model()
-
- # 获取特殊特征索引
- special_feature_index = self._get_special_feature_index(data)
-
- # 数据预处理
- X = data.values
- X_scaled = self.scaler.transform(X)
- X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
-
- # 模型预测
- with torch.no_grad():
- predictions = self.model(
- X_tensor,
- X_tensor,
- special_feature_index
- ).squeeze().numpy()
-
- # 指数变换
- predictions_exp = np.exp(predictions)
-
- self.logger.info(f"作物Cd预测完成,数据点数: {len(predictions_exp)}")
- self.logger.info(f"预测统计 - 均值: {np.mean(predictions_exp):.4f}, 标准差: {np.std(predictions_exp):.4f}")
-
- return predictions_exp
-
- except Exception as e:
- self.logger.error(f"作物Cd预测失败: {str(e)}")
- raise
-
- def _get_special_feature_index(self, data: pd.DataFrame) -> int:
- """
- Get the index of the special feature column
-
- @param {pd.DataFrame} data - Input data
- @returns {int} Index of the special feature
- """
- special_feature_name = self.model_config["special_feature_name"]
-
- # First try to find the column by name
- if special_feature_name in data.columns:
- index = data.columns.get_loc(special_feature_name)
- self.logger.info(f"Found special feature column '{special_feature_name}' at index {index}")
- return index
-
- # Fall back to the configured default index
- default_index = self.model_config["special_feature_index"]
- if default_index < data.shape[1]:
- self.logger.warning(f"Special feature column '{special_feature_name}' not found, using default index: {default_index}")
- return default_index
-
- # If the default index is also out of range, use the last column
- last_index = data.shape[1] - 1
- self.logger.warning(f"Default index out of range, using last column index: {last_index}")
- return last_index
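- # Illustrative usage sketch (assumption, not part of the original module): the input
- # file name and column layout below are hypothetical; predict() only requires a
- # DataFrame whose columns match the features the model was trained on.
- #
- #     predictor = CropCdPredictor()
- #     features = pd.read_csv("environment_factors.csv")  # hypothetical feature table
- #     cd_values = predictor.predict(features)            # 1-d np.ndarray of Cd predictions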
- class EffectiveCdPredictor:
- """
- Effective Cd model predictor
- """
-
- def __init__(self):
- """
- Initialize the predictor
- """
- self.logger = logging.getLogger(__name__)
- self.model_config = get_model_config("effective_cd")
- self.model = None
- self.scaler = None
- self._validate_files()
-
- def _validate_files(self):
- """
- Verify that the required model files exist
- """
- validation_result = validate_model_files("effective_cd")
- missing_files = [key for key, exists in validation_result.items() if not exists]
-
- if missing_files:
- raise FileNotFoundError(f"有效态Cd模型文件缺失: {missing_files}")
-
- self.logger.info("有效态Cd模型文件验证通过")
-
- def _load_constraint_module(self):
- """
- Dynamically load the constraint module
- """
- try:
- constraint_module_path = self.model_config["constraint_module"]
-
- # Dynamically import the constraint module and register it in sys.modules
- # so that the pickled model class can be resolved when the model is loaded
- spec = importlib.util.spec_from_file_location("constrained_nn6C", constraint_module_path)
- constrained_nn6C = importlib.util.module_from_spec(spec)
- sys.modules['constrained_nn6C'] = constrained_nn6C
- spec.loader.exec_module(constrained_nn6C)
-
- self.logger.info("有效态Cd约束模块加载成功")
- return constrained_nn6C
-
- except Exception as e:
- self.logger.error(f"有效态Cd约束模块加载失败: {str(e)}")
- raise
-
- def load_model(self):
- """
- Load the trained model and the standardization parameters
- """
- if self.model is not None:
- return  # already loaded
-
- try:
- # Load the constraint module first so the pickled model class can be resolved
- self._load_constraint_module()
-
- # Load the full serialized model object
- model_path = self.model_config["model_file"]
- self.model = torch.load(model_path, map_location='cpu')
- self.model.eval()
- self.logger.info(f"Effective Cd model loaded: {model_path}")
-
- # Load the standardization parameters and rebuild the fitted scaler
- # (transform only needs mean_ and scale_ for a default StandardScaler)
- mean_path = self.model_config["mean_file"]
- scale_path = self.model_config["scale_file"]
-
- self.scaler = StandardScaler()
- self.scaler.mean_ = np.load(mean_path)
- self.scaler.scale_ = np.load(scale_path)
- self.logger.info("Effective Cd standardization parameters loaded")
-
- except Exception as e:
- self.logger.error(f"Effective Cd model loading failed: {str(e)}")
- raise
-
- def predict(self, data: pd.DataFrame) -> np.ndarray:
- """
- Run prediction
-
- @param {pd.DataFrame} data - Input data (environmental factors)
- @returns {np.ndarray} Predicted values
- """
- try:
- self.logger.info("开始有效态Cd模型预测...")
-
- # 确保模型已加载
- self.load_model()
-
- # 检查数据列数(有效态Cd模型通常需要21列特征)
- if data.shape[1] < 21:
- self.logger.warning(f"有效态Cd模型期望21列特征,当前只有{data.shape[1]}列")
- elif data.shape[1] > 21:
- self.logger.info(f"输入数据有{data.shape[1]}列,取前21列用于有效态Cd预测")
- data = data.iloc[:, :21]
-
- # 获取特殊特征索引
- special_feature_index = self._get_special_feature_index(data)
-
- # 数据预处理
- X = data.values
- X_scaled = self.scaler.transform(X)
- X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
-
- # 模型预测
- with torch.no_grad():
- predictions = self.model(
- X_tensor,
- X_tensor,
- special_feature_index
- ).squeeze().numpy()
-
- # 指数变换
- predictions_exp = np.exp(predictions)
-
- self.logger.info(f"有效态Cd预测完成,数据点数: {len(predictions_exp)}")
- self.logger.info(f"预测统计 - 均值: {np.mean(predictions_exp):.4f}, 标准差: {np.std(predictions_exp):.4f}")
-
- return predictions_exp
-
- except Exception as e:
- self.logger.error(f"有效态Cd预测失败: {str(e)}")
- raise
-
- def _get_special_feature_index(self, data: pd.DataFrame) -> int:
- """
- Get the index of the special feature column
-
- @param {pd.DataFrame} data - Input data
- @returns {int} Index of the special feature
- """
- special_feature_name = self.model_config["special_feature_name"]
-
- # First try to find the column by name
- if special_feature_name in data.columns:
- index = data.columns.get_loc(special_feature_name)
- self.logger.info(f"Found special feature column '{special_feature_name}' at index {index}")
- return index
-
- # Fall back to the configured default index
- default_index = self.model_config["special_feature_index"]
- if default_index < data.shape[1]:
- self.logger.warning(f"Special feature column '{special_feature_name}' not found, using default index: {default_index}")
- return default_index
-
- # If the default index is also out of range, use the last column
- last_index = data.shape[1] - 1
- self.logger.warning(f"Default index out of range, using last column index: {last_index}")
- return last_index
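- # Illustrative note (not part of the original module): EffectiveCdPredictor is driven
- # the same way as CropCdPredictor above; the only behavioural difference is that inputs
- # with more than 21 feature columns are truncated to the first 21 before standardization.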
- class DataProcessor:
- """
- Data processor - merges prediction results with coordinates
- """
-
- def __init__(self):
- """
- Initialize the data processor
- """
- self.logger = logging.getLogger(__name__)
-
- def combine_predictions_with_coordinates(self, coordinates: pd.DataFrame,
- predictions: np.ndarray) -> pd.DataFrame:
- """
- Merge prediction results with coordinate data
-
- @param {pd.DataFrame} coordinates - Coordinate data (must contain longitude and latitude columns)
- @param {np.ndarray} predictions - Prediction results
- @returns {pd.DataFrame} Merged data
- """
- try:
- if len(coordinates) != len(predictions):
- raise ValueError(f"坐标数据点数({len(coordinates)})与预测结果数量({len(predictions)})不匹配")
-
- # 创建结果数据框
- result_df = coordinates[['longitude', 'latitude']].copy()
- result_df['Prediction'] = predictions
-
- self.logger.info(f"预测结果与坐标合并完成,数据形状: {result_df.shape}")
-
- return result_df
-
- except Exception as e:
- self.logger.error(f"数据合并失败: {str(e)}")
- raise
-
- def validate_final_data(self, data: pd.DataFrame) -> Dict[str, Any]:
- """
- Validate the format and content of the final data
-
- @param {pd.DataFrame} data - Final data
- @returns {Dict[str, Any]} Validation result
- """
- try:
- validation_result = {
- "valid": True,
- "errors": [],
- "warnings": [],
- "statistics": {}
- }
-
- # Check for required columns
- required_columns = ['longitude', 'latitude', 'Prediction']
- missing_columns = [col for col in required_columns if col not in data.columns]
-
- if missing_columns:
- validation_result["valid"] = False
- validation_result["errors"].append(f"Missing required columns: {missing_columns}")
-
- # Check data completeness
- if data.isnull().any().any():
- validation_result["warnings"].append("Data contains null values")
-
- # Check coordinate ranges
- if 'longitude' in data.columns and 'latitude' in data.columns:
- if not (data['longitude'].between(-180, 180).all() and
- data['latitude'].between(-90, 90).all()):
- validation_result["warnings"].append("Coordinate values are outside the valid range")
-
- # Compute prediction statistics
- if 'Prediction' in data.columns:
- pred_series = data['Prediction']
- validation_result["statistics"] = {
- "count": len(pred_series),
- "mean": float(pred_series.mean()),
- "std": float(pred_series.std()),
- "min": float(pred_series.min()),
- "max": float(pred_series.max()),
- "median": float(pred_series.median())
- }
-
- if validation_result["valid"]:
- self.logger.info("数据验证通过")
- else:
- self.logger.error(f"数据验证失败: {validation_result['errors']}")
-
- return validation_result
-
- except Exception as e:
- self.logger.error(f"数据验证失败: {str(e)}")
- return {
- "valid": False,
- "errors": [f"验证过程出错: {str(e)}"],
- "warnings": [],
- "statistics": {}
- }
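-
- # The function below is an illustrative sketch, not part of the original module: it shows
- # how the predictors and DataProcessor are intended to be combined. The CSV paths are
- # hypothetical; the 'longitude'/'latitude' columns are the ones that
- # combine_predictions_with_coordinates expects. Nothing calls this function on import.
- def _example_workflow(feature_csv: str, coordinate_csv: str) -> pd.DataFrame:
- """
- End-to-end sketch: predict crop Cd and attach coordinates (illustrative only)
- """
- features = pd.read_csv(feature_csv)        # environmental factors, one row per point
- coordinates = pd.read_csv(coordinate_csv)  # must contain 'longitude' and 'latitude'
-
- predictor = CropCdPredictor()
- predictions = predictor.predict(features)
-
- processor = DataProcessor()
- result_df = processor.combine_predictions_with_coordinates(coordinates, predictions)
- validation = processor.validate_final_data(result_df)
- if not validation["valid"]:
- raise ValueError(f"Invalid output data: {validation['errors']}")
-
- return result_df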