123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- """
- 作物Cd模型预测模块
- Crop Cd Model Prediction Module
- 基于原始model_out.py改进,使用面向对象设计
- """
- import os
- import sys
- import logging
- import numpy as np
- import pandas as pd
- import torch
- import torch.nn as nn
- from sklearn.preprocessing import StandardScaler
- # 添加项目根目录到路径
- sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
- import config
class CropCdPredictor:
    """
    Crop Cd model predictor.

    Loads a pickled PyTorch model plus saved standardization parameters from
    ``config.CROP_CD_MODEL["model_dir"]``, runs it over an input CSV and
    writes the predictions to ``config.DATA_PATHS["predictions_dir"]``.
    """

    def __init__(self):
        """
        Initialize the predictor.

        Only reads configuration; the model and scaler are loaded lazily by
        ``load_model`` (called from ``predict``).
        """
        self.logger = logging.getLogger(__name__)
        self.model_config = config.CROP_CD_MODEL
        self.model = None   # torch model, set by load_model()
        self.scaler = None  # StandardScaler rebuilt from saved mean/scale, set by load_model()

    def _model_file_path(self, *parts):
        """
        Build a path inside ``<model_dir>/model_files``.

        @param parts: optional path components appended after ``model_files``
        @return: the joined path (the ``model_files`` directory itself when no parts are given)
        """
        return os.path.join(self.model_config["model_dir"], "model_files", *parts)

    def _import_model_dependencies(self):
        """
        Import the custom module that defines the model architecture.

        Adds ``<model_dir>/model_files`` to ``sys.path`` and registers
        ``constrained_nn6.py`` in ``sys.modules`` under the name
        ``constrained_nn6`` -- presumably so that unpickling the saved model in
        ``load_model`` can resolve the classes it references by module name.

        @return: the executed ``constrained_nn6`` module object
        @raise Exception: any import/exec error is logged and re-raised
        """
        import importlib.util

        module_name = "constrained_nn6"
        try:
            model_files_dir = self._model_file_path()

            # Make sibling modules in model_files importable as top-level modules.
            if model_files_dir not in sys.path:
                sys.path.insert(0, model_files_dir)

            module_path = os.path.join(model_files_dir, module_name + ".py")
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"cannot create an import spec for {module_path}")

            module = importlib.util.module_from_spec(spec)
            # Register before executing (standard importlib recipe) so the module
            # is resolvable by name during and after its own execution.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except Exception:
                # Don't leave a half-initialised module registered in sys.modules.
                sys.modules.pop(module_name, None)
                raise

            self.logger.info("模型依赖模块导入成功")
            return module

        except Exception as e:
            self.logger.error("模型依赖模块导入失败: %s", e)
            raise

    @staticmethod
    def _load_full_model(path):
        """
        Load a fully pickled ``torch.nn.Module`` onto the CPU.

        PyTorch >= 2.6 defaults ``torch.load(weights_only=True)``, which refuses
        to unpickle arbitrary module objects; older releases do not accept the
        keyword at all. Pass ``weights_only=False`` only when it is supported.

        SECURITY: this unpickles arbitrary Python objects -- only load trusted
        model files.

        @param path: path to the saved model file
        @return: the unpickled object (expected to be an ``nn.Module``)
        """
        import inspect

        kwargs = {"map_location": "cpu"}
        if "weights_only" in inspect.signature(torch.load).parameters:
            kwargs["weights_only"] = False
        return torch.load(path, **kwargs)

    def load_model(self):
        """
        Load the trained model and the standardization parameters.

        Sets ``self.model`` (switched to eval mode) and ``self.scaler`` (a
        ``StandardScaler`` whose ``mean_``/``scale_`` come from the saved
        ``.npy`` files named in the config).

        @raise Exception: any loading error is logged and re-raised
        """
        try:
            # The module defining the model classes must be importable before unpickling.
            self._import_model_dependencies()

            model_path = self._model_file_path(self.model_config["model_file"])
            self.model = self._load_full_model(model_path)
            self.model.eval()
            self.logger.info(f"模型加载成功: {model_path}")

            mean_path = self._model_file_path(self.model_config["mean_file"])
            scale_path = self._model_file_path(self.model_config["scale_file"])

            scaler = StandardScaler()
            scaler.mean_ = np.load(mean_path)
            scaler.scale_ = np.load(scale_path)
            # Record the expected feature count so sklearn's input validation in
            # transform() rejects a wrong number of columns with a clear error
            # instead of failing (or silently broadcasting) during arithmetic.
            scaler.n_features_in_ = int(np.size(scaler.mean_))
            self.scaler = scaler
            self.logger.info("标准化参数加载成功")

        except Exception as e:
            self.logger.error("模型加载失败: %s", e)
            raise

    def load_data(self):
        """
        Load the input data CSV from ``<model_dir>/data/<input_data>``.

        @return: input data DataFrame
        @raise Exception: any read error is logged and re-raised
        """
        try:
            data_path = os.path.join(
                self.model_config["model_dir"],
                "data",
                self.model_config["input_data"]
            )

            data = pd.read_csv(data_path)
            self.logger.info(f"数据加载成功: {data_path}, 数据形状: {data.shape}")
            return data

        except Exception as e:
            self.logger.error("数据加载失败: %s", e)
            raise

    def preprocess_data(self, data):
        """
        Standardize the input features and convert them to a float32 tensor.

        Every column of ``data`` is passed to the scaler, so the frame is
        assumed to contain exactly the model's feature columns in the order
        used when the mean/scale files were produced (TODO confirm against
        the training pipeline).

        @param data: raw input DataFrame
        @return: tuple ``(X_tensor, special_feature_index)`` -- the standardized
                 features and the column position of ``model_config["special_feature"]``
        @raise Exception: any preprocessing error is logged and re-raised
        """
        try:
            special_feature = self.model_config["special_feature"]
            special_feature_index = data.columns.get_loc(special_feature)
            # get_loc returns a slice or boolean mask instead of an int when the
            # column name is duplicated; the model needs a single position.
            if not isinstance(special_feature_index, (int, np.integer)):
                raise ValueError(
                    f"column {special_feature!r} is not unique in the input data"
                )

            X_scaled = self.scaler.transform(data.to_numpy())
            X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

            self.logger.info("数据预处理完成")
            return X_tensor, special_feature_index

        except Exception as e:
            self.logger.error("数据预处理失败: %s", e)
            raise

    @staticmethod
    def _to_prediction_array(raw, n_samples):
        """
        Reshape raw model output into one prediction row per input sample.

        Unlike ``ndarray.squeeze()``, this never collapses the sample axis: a
        single-row ``(1, 1)`` output becomes ``(1,)`` instead of a 0-d scalar
        (which ``pd.DataFrame`` and ``len`` reject). A lone output column is
        dropped, so single-output models yield ``(n_samples,)`` and
        multi-output models ``(n_samples, k)``.

        @param raw: model output as a numpy array
        @param n_samples: number of input rows fed to the model
        @return: numpy array of predictions
        """
        preds = np.asarray(raw)
        if preds.size == 0:
            return preds.reshape(-1)
        preds = preds.reshape(n_samples, -1)
        return preds[:, 0] if preds.shape[1] == 1 else preds

    def predict(self):
        """
        Run the full pipeline: load model and scaler, load and preprocess the
        input data, run inference, and save the predictions as CSV.

        @return: path of the written predictions CSV
        @raise Exception: any failure is logged with its traceback and re-raised
        """
        try:
            self.logger.info("开始作物Cd模型预测...")

            self.load_model()
            data = self.load_data()
            X_tensor, special_feature_index = self.preprocess_data(data)

            self.model.eval()
            with torch.no_grad():
                # The model's forward takes the standardized features twice plus
                # the special feature's column index; the meaning of each argument
                # is defined in constrained_nn6 (not visible here).
                raw_output = self.model(
                    X_tensor,
                    X_tensor,
                    special_feature_index
                )
            predictions = self._to_prediction_array(
                raw_output.detach().cpu().numpy(), X_tensor.shape[0]
            )

            predictions_df = pd.DataFrame(predictions)
            output_path = os.path.join(
                config.DATA_PATHS["predictions_dir"],
                self.model_config["output_file"]
            )

            # Ensure the output directory exists; an empty dirname means the
            # current directory, for which os.makedirs("") would raise.
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            predictions_df.to_csv(output_path, index=False)

            self.logger.info(f"作物Cd模型预测完成,结果保存至: {output_path}")
            self.logger.info(f"预测结果统计 - 数量: {len(predictions)}, 均值: {np.mean(predictions):.4f}, 标准差: {np.std(predictions):.4f}")

            return output_path

        except Exception as e:
            # Top-level boundary: include the traceback once here.
            self.logger.exception("预测过程失败: %s", e)
            raise
if __name__ == "__main__":
    # Manual smoke run: execute the whole prediction pipeline once and
    # report where the prediction CSV was written.
    result_path = CropCdPredictor().predict()
    print(f"预测完成,输出文件: {result_path}")
|