""" 作物Cd模型预测模块 Crop Cd Model Prediction Module 基于原始model_out.py改进,使用面向对象设计 """ import os import sys import logging import numpy as np import pandas as pd import torch import torch.nn as nn from sklearn.preprocessing import StandardScaler # 添加项目根目录到路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) import config class CropCdPredictor: """ 作物Cd模型预测器 """ def __init__(self): """ 初始化预测器 """ self.logger = logging.getLogger(__name__) self.model_config = config.CROP_CD_MODEL self.model = None self.scaler = None def _import_model_dependencies(self): """ 导入模型依赖的自定义模块 """ try: model_files_dir = os.path.join(self.model_config["model_dir"], "model_files") # 将模型文件目录添加到Python路径 if model_files_dir not in sys.path: sys.path.insert(0, model_files_dir) # 动态导入constrained_nn6模块 import importlib.util constraint_module_path = os.path.join(model_files_dir, "constrained_nn6.py") spec = importlib.util.spec_from_file_location("constrained_nn6", constraint_module_path) constrained_nn6 = importlib.util.module_from_spec(spec) sys.modules['constrained_nn6'] = constrained_nn6 spec.loader.exec_module(constrained_nn6) self.logger.info("模型依赖模块导入成功") return constrained_nn6 except Exception as e: self.logger.error(f"模型依赖模块导入失败: {str(e)}") raise def load_model(self): """ 加载训练好的模型和标准化参数 """ try: # 首先导入模型依赖的模块 self._import_model_dependencies() # 构建模型文件路径 model_path = os.path.join( self.model_config["model_dir"], "model_files", self.model_config["model_file"] ) # 加载模型 self.model = torch.load(model_path, map_location='cpu') self.model.eval() self.logger.info(f"模型加载成功: {model_path}") # 加载标准化参数 mean_path = os.path.join( self.model_config["model_dir"], "model_files", self.model_config["mean_file"] ) scale_path = os.path.join( self.model_config["model_dir"], "model_files", self.model_config["scale_file"] ) self.scaler = StandardScaler() self.scaler.mean_ = np.load(mean_path) self.scaler.scale_ = np.load(scale_path) self.logger.info("标准化参数加载成功") except Exception as e: self.logger.error(f"模型加载失败: {str(e)}") raise def load_data(self): """ 加载输入数据 @return: 输入数据DataFrame """ try: data_path = os.path.join( self.model_config["model_dir"], "data", self.model_config["input_data"] ) data = pd.read_csv(data_path) self.logger.info(f"数据加载成功: {data_path}, 数据形状: {data.shape}") return data except Exception as e: self.logger.error(f"数据加载失败: {str(e)}") raise def preprocess_data(self, data): """ 数据预处理 @param data: 原始数据DataFrame @return: 预处理后的数据和特征索引 """ try: # 获取特殊特征的索引 special_feature_index = data.columns.get_loc(self.model_config["special_feature"]) # 转换为numpy数组 X = data.values # 标准化 X_scaled = self.scaler.transform(X) # 转换为张量 X_tensor = torch.tensor(X_scaled, dtype=torch.float32) self.logger.info("数据预处理完成") return X_tensor, special_feature_index except Exception as e: self.logger.error(f"数据预处理失败: {str(e)}") raise def predict(self): """ 执行预测 @return: 输出文件路径 """ try: self.logger.info("开始作物Cd模型预测...") # 加载模型 self.load_model() # 加载数据 data = self.load_data() # 预处理数据 X_tensor, special_feature_index = self.preprocess_data(data) # 模型预测 self.model.eval() with torch.no_grad(): predictions = self.model( X_tensor, X_tensor, special_feature_index ).squeeze().numpy() predictions_exp = np.exp(predictions) # 保存预测结果 predictions_df = pd.DataFrame(predictions_exp) output_path = os.path.join( config.DATA_PATHS["predictions_dir"], self.model_config["output_file"] ) # 确保输出目录存在 os.makedirs(os.path.dirname(output_path), exist_ok=True) predictions_df.to_csv(output_path, index=False) self.logger.info(f"作物Cd模型预测完成,结果保存至: {output_path}") self.logger.info(f"预测结果统计 - 数量: {len(predictions)}, 均值: {np.mean(predictions):.4f}, 标准差: {np.std(predictions):.4f}") return output_path except Exception as e: self.logger.error(f"预测过程失败: {str(e)}") raise if __name__ == "__main__": # 测试代码 predictor = CropCdPredictor() output_file = predictor.predict() print(f"预测完成,输出文件: {output_file}")