predictors.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. """
  2. Cd预测器模块
  3. @description: 自包含的预测器,直接加载和运行模型
  4. @version: 3.0.0
  5. """
  6. import os
  7. import sys
  8. import logging
  9. import numpy as np
  10. import pandas as pd
  11. import torch
  12. import torch.nn as nn
  13. from sklearn.preprocessing import StandardScaler
  14. from typing import Dict, Any, Optional, Tuple
  15. import importlib.util
  16. from .config import get_model_config, validate_model_files
  17. class CropCdPredictor:
  18. """
  19. 作物Cd模型预测器
  20. """
  21. def __init__(self):
  22. """
  23. 初始化预测器
  24. """
  25. self.logger = logging.getLogger(__name__)
  26. self.model_config = get_model_config("crop_cd")
  27. self.model = None
  28. self.scaler = None
  29. self._validate_files()
  30. def _validate_files(self):
  31. """
  32. 验证模型文件是否存在
  33. """
  34. validation_result = validate_model_files("crop_cd")
  35. missing_files = [key for key, exists in validation_result.items() if not exists]
  36. if missing_files:
  37. raise FileNotFoundError(f"作物Cd模型文件缺失: {missing_files}")
  38. self.logger.info("作物Cd模型文件验证通过")
  39. def _load_constraint_module(self):
  40. """
  41. 动态加载约束模块
  42. """
  43. try:
  44. constraint_module_path = self.model_config["constraint_module"]
  45. # 动态导入约束模块
  46. spec = importlib.util.spec_from_file_location("constrained_nn6", constraint_module_path)
  47. constrained_nn6 = importlib.util.module_from_spec(spec)
  48. sys.modules['constrained_nn6'] = constrained_nn6
  49. spec.loader.exec_module(constrained_nn6)
  50. self.logger.info("作物Cd约束模块加载成功")
  51. return constrained_nn6
  52. except Exception as e:
  53. self.logger.error(f"作物Cd约束模块加载失败: {str(e)}")
  54. raise
  55. def load_model(self):
  56. """
  57. 加载训练好的模型和标准化参数
  58. """
  59. if self.model is not None:
  60. return # 已经加载过了
  61. try:
  62. # 加载约束模块
  63. self._load_constraint_module()
  64. # 加载模型
  65. model_path = self.model_config["model_file"]
  66. self.model = torch.load(model_path, map_location='cpu')
  67. self.model.eval()
  68. self.logger.info(f"作物Cd模型加载成功: {model_path}")
  69. # 加载标准化参数
  70. mean_path = self.model_config["mean_file"]
  71. scale_path = self.model_config["scale_file"]
  72. self.scaler = StandardScaler()
  73. self.scaler.mean_ = np.load(mean_path)
  74. self.scaler.scale_ = np.load(scale_path)
  75. self.logger.info("作物Cd标准化参数加载成功")
  76. except Exception as e:
  77. self.logger.error(f"作物Cd模型加载失败: {str(e)}")
  78. raise
  79. def predict(self, data: pd.DataFrame) -> np.ndarray:
  80. """
  81. 执行预测
  82. @param {pd.DataFrame} data - 输入数据(环境因子)
  83. @returns {np.ndarray} 预测结果
  84. """
  85. try:
  86. self.logger.info("开始作物Cd模型预测...")
  87. # 确保模型已加载
  88. self.load_model()
  89. # 获取特殊特征索引
  90. special_feature_index = self._get_special_feature_index(data)
  91. # 数据预处理
  92. X = data.values
  93. X_scaled = self.scaler.transform(X)
  94. X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
  95. # 模型预测
  96. with torch.no_grad():
  97. predictions = self.model(
  98. X_tensor,
  99. X_tensor,
  100. special_feature_index
  101. ).squeeze().numpy()
  102. # 指数变换
  103. predictions_exp = np.exp(predictions)
  104. self.logger.info(f"作物Cd预测完成,数据点数: {len(predictions_exp)}")
  105. self.logger.info(f"预测统计 - 均值: {np.mean(predictions_exp):.4f}, 标准差: {np.std(predictions_exp):.4f}")
  106. return predictions_exp
  107. except Exception as e:
  108. self.logger.error(f"作物Cd预测失败: {str(e)}")
  109. raise
  110. def _get_special_feature_index(self, data: pd.DataFrame) -> int:
  111. """
  112. 获取特殊特征的索引
  113. @param {pd.DataFrame} data - 输入数据
  114. @returns {int} 特殊特征索引
  115. """
  116. special_feature_name = self.model_config["special_feature_name"]
  117. # 首先尝试通过列名查找
  118. if special_feature_name in data.columns:
  119. index = data.columns.get_loc(special_feature_name)
  120. self.logger.info(f"找到特殊特征列 '{special_feature_name}',索引: {index}")
  121. return index
  122. # 如果找不到,使用默认索引
  123. default_index = self.model_config["special_feature_index"]
  124. if default_index < data.shape[1]:
  125. self.logger.warning(f"未找到特殊特征列 '{special_feature_name}',使用默认索引: {default_index}")
  126. return default_index
  127. # 如果默认索引也超出范围,使用最后一列
  128. last_index = data.shape[1] - 1
  129. self.logger.warning(f"默认索引超出范围,使用最后一列索引: {last_index}")
  130. return last_index
  131. class EffectiveCdPredictor:
  132. """
  133. 有效态Cd模型预测器
  134. """
  135. def __init__(self):
  136. """
  137. 初始化预测器
  138. """
  139. self.logger = logging.getLogger(__name__)
  140. self.model_config = get_model_config("effective_cd")
  141. self.model = None
  142. self.scaler = None
  143. self._validate_files()
  144. def _validate_files(self):
  145. """
  146. 验证模型文件是否存在
  147. """
  148. validation_result = validate_model_files("effective_cd")
  149. missing_files = [key for key, exists in validation_result.items() if not exists]
  150. if missing_files:
  151. raise FileNotFoundError(f"有效态Cd模型文件缺失: {missing_files}")
  152. self.logger.info("有效态Cd模型文件验证通过")
  153. def _load_constraint_module(self):
  154. """
  155. 动态加载约束模块
  156. """
  157. try:
  158. constraint_module_path = self.model_config["constraint_module"]
  159. # 动态导入约束模块
  160. spec = importlib.util.spec_from_file_location("constrained_nn6C", constraint_module_path)
  161. constrained_nn6C = importlib.util.module_from_spec(spec)
  162. sys.modules['constrained_nn6C'] = constrained_nn6C
  163. spec.loader.exec_module(constrained_nn6C)
  164. self.logger.info("有效态Cd约束模块加载成功")
  165. return constrained_nn6C
  166. except Exception as e:
  167. self.logger.error(f"有效态Cd约束模块加载失败: {str(e)}")
  168. raise
  169. def load_model(self):
  170. """
  171. 加载训练好的模型和标准化参数
  172. """
  173. if self.model is not None:
  174. return # 已经加载过了
  175. try:
  176. # 加载约束模块
  177. self._load_constraint_module()
  178. # 加载模型
  179. model_path = self.model_config["model_file"]
  180. self.model = torch.load(model_path, map_location='cpu')
  181. self.model.eval()
  182. self.logger.info(f"有效态Cd模型加载成功: {model_path}")
  183. # 加载标准化参数
  184. mean_path = self.model_config["mean_file"]
  185. scale_path = self.model_config["scale_file"]
  186. self.scaler = StandardScaler()
  187. self.scaler.mean_ = np.load(mean_path)
  188. self.scaler.scale_ = np.load(scale_path)
  189. self.logger.info("有效态Cd标准化参数加载成功")
  190. except Exception as e:
  191. self.logger.error(f"有效态Cd模型加载失败: {str(e)}")
  192. raise
  193. def predict(self, data: pd.DataFrame) -> np.ndarray:
  194. """
  195. 执行预测
  196. @param {pd.DataFrame} data - 输入数据(环境因子)
  197. @returns {np.ndarray} 预测结果
  198. """
  199. try:
  200. self.logger.info("开始有效态Cd模型预测...")
  201. # 确保模型已加载
  202. self.load_model()
  203. # 检查数据列数(有效态Cd模型通常需要21列特征)
  204. if data.shape[1] < 21:
  205. self.logger.warning(f"有效态Cd模型期望21列特征,当前只有{data.shape[1]}列")
  206. elif data.shape[1] > 21:
  207. self.logger.info(f"输入数据有{data.shape[1]}列,取前21列用于有效态Cd预测")
  208. data = data.iloc[:, :21]
  209. # 获取特殊特征索引
  210. special_feature_index = self._get_special_feature_index(data)
  211. # 数据预处理
  212. X = data.values
  213. X_scaled = self.scaler.transform(X)
  214. X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
  215. # 模型预测
  216. with torch.no_grad():
  217. predictions = self.model(
  218. X_tensor,
  219. X_tensor,
  220. special_feature_index
  221. ).squeeze().numpy()
  222. # 指数变换
  223. predictions_exp = np.exp(predictions)
  224. self.logger.info(f"有效态Cd预测完成,数据点数: {len(predictions_exp)}")
  225. self.logger.info(f"预测统计 - 均值: {np.mean(predictions_exp):.4f}, 标准差: {np.std(predictions_exp):.4f}")
  226. return predictions_exp
  227. except Exception as e:
  228. self.logger.error(f"有效态Cd预测失败: {str(e)}")
  229. raise
  230. def _get_special_feature_index(self, data: pd.DataFrame) -> int:
  231. """
  232. 获取特殊特征的索引
  233. @param {pd.DataFrame} data - 输入数据
  234. @returns {int} 特殊特征索引
  235. """
  236. special_feature_name = self.model_config["special_feature_name"]
  237. # 首先尝试通过列名查找
  238. if special_feature_name in data.columns:
  239. index = data.columns.get_loc(special_feature_name)
  240. self.logger.info(f"找到特殊特征列 '{special_feature_name}',索引: {index}")
  241. return index
  242. # 如果找不到,使用默认索引
  243. default_index = self.model_config["special_feature_index"]
  244. if default_index < data.shape[1]:
  245. self.logger.warning(f"未找到特殊特征列 '{special_feature_name}',使用默认索引: {default_index}")
  246. return default_index
  247. # 如果默认索引也超出范围,使用最后一列
  248. last_index = data.shape[1] - 1
  249. self.logger.warning(f"默认索引超出范围,使用最后一列索引: {last_index}")
  250. return last_index
  251. class DataProcessor:
  252. """
  253. 数据处理器 - 处理预测结果与坐标的合并
  254. """
  255. def __init__(self):
  256. """
  257. 初始化数据处理器
  258. """
  259. self.logger = logging.getLogger(__name__)
  260. def combine_predictions_with_coordinates(self, coordinates: pd.DataFrame,
  261. predictions: np.ndarray) -> pd.DataFrame:
  262. """
  263. 将预测结果与坐标数据合并
  264. @param {pd.DataFrame} coordinates - 坐标数据(包含longitude, latitude列)
  265. @param {np.ndarray} predictions - 预测结果
  266. @returns {pd.DataFrame} 合并后的数据
  267. """
  268. try:
  269. if len(coordinates) != len(predictions):
  270. raise ValueError(f"坐标数据点数({len(coordinates)})与预测结果数量({len(predictions)})不匹配")
  271. # 创建结果数据框
  272. result_df = coordinates[['longitude', 'latitude']].copy()
  273. result_df['Prediction'] = predictions
  274. self.logger.info(f"预测结果与坐标合并完成,数据形状: {result_df.shape}")
  275. return result_df
  276. except Exception as e:
  277. self.logger.error(f"数据合并失败: {str(e)}")
  278. raise
  279. def validate_final_data(self, data: pd.DataFrame) -> Dict[str, Any]:
  280. """
  281. 验证最终数据的格式和内容
  282. @param {pd.DataFrame} data - 最终数据
  283. @returns {Dict[str, Any]} 验证结果
  284. """
  285. try:
  286. validation_result = {
  287. "valid": True,
  288. "errors": [],
  289. "warnings": [],
  290. "statistics": {}
  291. }
  292. # 检查必要的列
  293. required_columns = ['longitude', 'latitude', 'Prediction']
  294. missing_columns = [col for col in required_columns if col not in data.columns]
  295. if missing_columns:
  296. validation_result["valid"] = False
  297. validation_result["errors"].append(f"缺少必要的列: {missing_columns}")
  298. # 检查数据完整性
  299. if data.isnull().any().any():
  300. validation_result["warnings"].append("数据中存在空值")
  301. # 检查坐标范围
  302. if 'longitude' in data.columns and 'latitude' in data.columns:
  303. if not (data['longitude'].between(-180, 180).all() and
  304. data['latitude'].between(-90, 90).all()):
  305. validation_result["warnings"].append("坐标值超出合理范围")
  306. # 计算预测值统计信息
  307. if 'Prediction' in data.columns:
  308. pred_series = data['Prediction']
  309. validation_result["statistics"] = {
  310. "count": len(pred_series),
  311. "mean": float(pred_series.mean()),
  312. "std": float(pred_series.std()),
  313. "min": float(pred_series.min()),
  314. "max": float(pred_series.max()),
  315. "median": float(pred_series.median())
  316. }
  317. if validation_result["valid"]:
  318. self.logger.info("数据验证通过")
  319. else:
  320. self.logger.error(f"数据验证失败: {validation_result['errors']}")
  321. return validation_result
  322. except Exception as e:
  323. self.logger.error(f"数据验证失败: {str(e)}")
  324. return {
  325. "valid": False,
  326. "errors": [f"验证过程出错: {str(e)}"],
  327. "warnings": [],
  328. "statistics": {}
  329. }