model_analyzer.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876
  1. # -*- coding: utf-8 -*-
  2. """
  3. 模型分析工具
  4. 此模块提供了综合的模型分析功能,包括:
  5. - 模型性能详细分析
  6. - 特征重要性分析
  7. - 预测行为分析
  8. - 参数敏感性分析
  9. - 模型诊断和可视化
  10. """
  11. import os
  12. import sys
  13. import pickle
  14. import pandas as pd
  15. import numpy as np
  16. import matplotlib.pyplot as plt
  17. import seaborn as sns
  18. from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
  19. from sklearn.model_selection import cross_val_score
  20. from sklearn.inspection import permutation_importance
  21. import sqlite3
  22. from datetime import datetime
  23. import warnings
# NOTE(review): suppresses ALL warnings process-wide — this also hides
# sklearn/pandas deprecation notices that would flag API breakage early.
warnings.filterwarnings('ignore')
# Configure matplotlib for Chinese (CJK) labels: SimHei first, DejaVu Sans as
# fallback; keep the minus sign rendering correctly alongside CJK fonts.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
  28. class ModelAnalyzer:
  29. """
  30. 模型分析器类
  31. 提供全面的模型分析功能,支持模型性能评估、特征分析、预测分析等
  32. """
  33. def __init__(self, db_path='SoilAcidification.db'):
  34. """
  35. 初始化模型分析器
  36. @param {str} db_path - 数据库文件路径
  37. """
  38. self.db_path = db_path
  39. self.model_info = None
  40. self.model = None
  41. self.test_data = None
  42. self.predictions = None
  43. self.analysis_results = {}
  44. # 创建输出目录
  45. self.output_dir = 'model_optimize/analysis_results'
  46. if not os.path.exists(self.output_dir):
  47. os.makedirs(self.output_dir)
  48. def load_model_by_id(self, model_id):
  49. """
  50. 根据模型ID加载模型信息和模型对象
  51. @param {int} model_id - 模型ID
  52. @return {bool} - 是否成功加载
  53. """
  54. try:
  55. # 连接数据库
  56. conn = sqlite3.connect(self.db_path)
  57. cursor = conn.cursor()
  58. # 查询模型信息
  59. cursor.execute("""
  60. SELECT ModelID, Model_name, Model_type, ModelFilePath,
  61. Data_type, Performance_score, MAE, RMSE, CV_score,
  62. Description, Created_at
  63. FROM Models
  64. WHERE ModelID = ?
  65. """, (model_id,))
  66. result = cursor.fetchone()
  67. if not result:
  68. print(f"未找到ID为 {model_id} 的模型")
  69. return False
  70. # 存储模型信息
  71. self.model_info = {
  72. 'ModelID': result[0],
  73. 'Model_name': result[1],
  74. 'Model_type': result[2],
  75. 'ModelFilePath': result[3],
  76. 'Data_type': result[4],
  77. 'Performance_score': result[5],
  78. 'MAE': result[6],
  79. 'RMSE': result[7],
  80. 'CV_score': result[8],
  81. 'Description': result[9],
  82. 'Created_at': result[10]
  83. }
  84. conn.close()
  85. # 加载模型文件
  86. if os.path.exists(self.model_info['ModelFilePath']):
  87. with open(self.model_info['ModelFilePath'], 'rb') as f:
  88. self.model = pickle.load(f)
  89. print(f"成功加载模型: {self.model_info['Model_name']}")
  90. return True
  91. else:
  92. print(f"模型文件不存在: {self.model_info['ModelFilePath']}")
  93. return False
  94. except Exception as e:
  95. print(f"加载模型时出错: {str(e)}")
  96. return False
  97. def load_test_data(self):
  98. """
  99. 根据模型数据类型加载相应的测试数据
  100. @return {bool} - 是否成功加载测试数据
  101. """
  102. try:
  103. data_type = self.model_info['Data_type']
  104. if data_type == 'reflux':
  105. X_test_path = 'uploads/data/X_test_reflux.csv'
  106. Y_test_path = 'uploads/data/Y_test_reflux.csv'
  107. elif data_type == 'reduce':
  108. X_test_path = 'uploads/data/X_test_reduce.csv'
  109. Y_test_path = 'uploads/data/Y_test_reduce.csv'
  110. else:
  111. print(f"不支持的数据类型: {data_type}")
  112. return False
  113. if os.path.exists(X_test_path) and os.path.exists(Y_test_path):
  114. self.X_test = pd.read_csv(X_test_path)
  115. self.Y_test = pd.read_csv(Y_test_path)
  116. print(f"成功加载 {data_type} 类型的测试数据")
  117. print(f"测试集大小: X_test {self.X_test.shape}, Y_test {self.Y_test.shape}")
  118. return True
  119. else:
  120. print(f"测试数据文件不存在")
  121. return False
  122. except Exception as e:
  123. print(f"加载测试数据时出错: {str(e)}")
  124. return False
  125. def load_training_data(self):
  126. """
  127. 加载训练数据以获取参数的真实分布范围
  128. @return {bool} - 是否成功加载训练数据
  129. """
  130. try:
  131. data_type = self.model_info['Data_type']
  132. if data_type == 'reflux':
  133. # 加载current_reflux表的数据
  134. conn = sqlite3.connect(self.db_path)
  135. training_data = pd.read_sql_query("SELECT * FROM current_reflux", conn)
  136. conn.close()
  137. # 分离特征和目标变量
  138. if 'Delta_pH' in training_data.columns:
  139. self.X_train = training_data.drop(['id', 'Delta_pH'], axis=1, errors='ignore')
  140. self.y_train = training_data['Delta_pH']
  141. else:
  142. print("未找到目标变量 Delta_pH")
  143. return False
  144. elif data_type == 'reduce':
  145. # 加载current_reduce表的数据
  146. conn = sqlite3.connect(self.db_path)
  147. training_data = pd.read_sql_query("SELECT * FROM current_reduce", conn)
  148. conn.close()
  149. # 分离特征和目标变量
  150. if 'Q_over_b' in training_data.columns:
  151. self.X_train = training_data.drop(['id', 'Q_over_b'], axis=1, errors='ignore')
  152. self.y_train = training_data['Q_over_b']
  153. else:
  154. print("未找到目标变量 Q_over_b")
  155. return False
  156. else:
  157. print(f"不支持的数据类型: {data_type}")
  158. return False
  159. print(f"成功加载训练数据: {self.X_train.shape}")
  160. print(f"训练数据特征统计:")
  161. print(self.X_train.describe())
  162. return True
  163. except Exception as e:
  164. print(f"加载训练数据时出错: {str(e)}")
  165. return False
  166. def calculate_detailed_metrics(self):
  167. """
  168. 计算详细的模型性能指标
  169. @return {dict} - 包含各种性能指标的字典
  170. """
  171. try:
  172. if self.model is None or self.X_test is None:
  173. print("模型或测试数据未加载")
  174. return {}
  175. # 获取预测值
  176. y_pred = self.model.predict(self.X_test)
  177. y_true = self.Y_test.iloc[:, 0] if len(self.Y_test.columns) == 1 else self.Y_test
  178. # 计算各种指标
  179. metrics = {
  180. 'R2_Score': r2_score(y_true, y_pred),
  181. 'RMSE': mean_squared_error(y_true, y_pred, squared=False),
  182. 'MAE': mean_absolute_error(y_true, y_pred),
  183. 'MSE': mean_squared_error(y_true, y_pred),
  184. 'MAPE': np.mean(np.abs((y_true - y_pred) / y_true)) * 100,
  185. 'Max_Error': np.max(np.abs(y_true - y_pred)),
  186. 'Mean_Residual': np.mean(y_true - y_pred),
  187. 'Std_Residual': np.std(y_true - y_pred)
  188. }
  189. self.predictions = {'y_true': y_true, 'y_pred': y_pred}
  190. self.analysis_results['metrics'] = metrics
  191. # 中文指标名称映射
  192. metric_names_zh = {
  193. 'R2_Score': 'R²得分',
  194. 'RMSE': '均方根误差',
  195. 'MAE': '平均绝对误差',
  196. 'MSE': '均方误差',
  197. 'MAPE': '平均绝对百分比误差',
  198. 'Max_Error': '最大误差',
  199. 'Mean_Residual': '平均残差',
  200. 'Std_Residual': '残差标准差'
  201. }
  202. print("=== 模型性能指标 ===")
  203. for key, value in metrics.items():
  204. zh_name = metric_names_zh.get(key, key)
  205. print(f"{key} ({zh_name}): {value:.4f}")
  206. return metrics
  207. except Exception as e:
  208. print(f"计算性能指标时出错: {str(e)}")
  209. return {}
  210. def analyze_feature_importance(self):
  211. """
  212. 分析特征重要性
  213. @return {dict} - 特征重要性分析结果
  214. """
  215. try:
  216. if self.model is None:
  217. print("模型未加载")
  218. return {}
  219. importance_data = {}
  220. # 1. 模型内置特征重要性(如果支持)
  221. if hasattr(self.model, 'feature_importances_'):
  222. importance_data['model_importance'] = {
  223. 'features': self.X_test.columns.tolist(),
  224. 'importance': self.model.feature_importances_.tolist()
  225. }
  226. # 2. 排列特征重要性
  227. if self.X_test is not None and self.Y_test is not None:
  228. y_true = self.Y_test.iloc[:, 0] if len(self.Y_test.columns) == 1 else self.Y_test
  229. perm_importance = permutation_importance(
  230. self.model, self.X_test, y_true,
  231. n_repeats=10, random_state=42
  232. )
  233. importance_data['permutation_importance'] = {
  234. 'features': self.X_test.columns.tolist(),
  235. 'importance_mean': perm_importance.importances_mean.tolist(),
  236. 'importance_std': perm_importance.importances_std.tolist()
  237. }
  238. self.analysis_results['feature_importance'] = importance_data
  239. print("=== 特征重要性分析完成 ===")
  240. return importance_data
  241. except Exception as e:
  242. print(f"分析特征重要性时出错: {str(e)}")
  243. return {}
  244. def analyze_prediction_behavior(self, sample_params=None):
  245. """
  246. 分析模型的预测行为
  247. @param {dict} sample_params - 示例参数,如用户提供的参数
  248. @return {dict} - 预测行为分析结果
  249. """
  250. try:
  251. if self.model is None:
  252. print("模型未加载")
  253. return {}
  254. behavior_analysis = {}
  255. # 1. 分析用户提供的示例参数
  256. if sample_params:
  257. # 转换参数为DataFrame
  258. if self.model_info['Data_type'] == 'reflux':
  259. # 为reflux类型,移除target_pH
  260. sample_params_clean = sample_params.copy()
  261. sample_params_clean.pop('target_pH', None)
  262. sample_df = pd.DataFrame([sample_params_clean])
  263. else:
  264. sample_df = pd.DataFrame([sample_params])
  265. # 确保列顺序与训练数据一致
  266. if hasattr(self, 'X_test'):
  267. sample_df = sample_df.reindex(columns=self.X_test.columns, fill_value=0)
  268. prediction = self.model.predict(sample_df)[0]
  269. behavior_analysis['sample_prediction'] = {
  270. 'input_params': sample_params,
  271. 'prediction': float(prediction),
  272. 'model_type': self.model_info['Data_type']
  273. }
  274. print(f"=== 示例参数预测结果 ===")
  275. print(f"输入参数: {sample_params}")
  276. print(f"预测结果: {prediction:.4f}")
  277. # 2. 预测值分布分析
  278. if self.predictions:
  279. y_pred = self.predictions['y_pred']
  280. behavior_analysis['prediction_distribution'] = {
  281. 'mean': float(np.mean(y_pred)),
  282. 'std': float(np.std(y_pred)),
  283. 'min': float(np.min(y_pred)),
  284. 'max': float(np.max(y_pred)),
  285. 'percentiles': {
  286. '25%': float(np.percentile(y_pred, 25)),
  287. '50%': float(np.percentile(y_pred, 50)),
  288. '75%': float(np.percentile(y_pred, 75))
  289. }
  290. }
  291. self.analysis_results['prediction_behavior'] = behavior_analysis
  292. return behavior_analysis
  293. except Exception as e:
  294. print(f"分析预测行为时出错: {str(e)}")
  295. return {}
  296. def sensitivity_analysis(self, sample_params=None):
  297. """
  298. 基于训练数据分布的参数敏感性分析
  299. @param {dict} sample_params - 基准参数
  300. @return {dict} - 敏感性分析结果
  301. """
  302. try:
  303. if self.model is None or sample_params is None:
  304. print("模型未加载或未提供基准参数")
  305. return {}
  306. # 确保训练数据已加载
  307. if not hasattr(self, 'X_train') or self.X_train is None:
  308. print("训练数据未加载,尝试加载...")
  309. if not self.load_training_data():
  310. print("无法加载训练数据,使用默认参数范围")
  311. return self.sensitivity_analysis_fallback(sample_params)
  312. sensitivity_results = {}
  313. # 准备基准数据
  314. base_params = sample_params.copy()
  315. if self.model_info['Data_type'] == 'reflux':
  316. base_params.pop('target_pH', None)
  317. base_df = pd.DataFrame([base_params])
  318. if hasattr(self, 'X_test'):
  319. base_df = base_df.reindex(columns=self.X_test.columns, fill_value=0)
  320. base_prediction = self.model.predict(base_df)[0]
  321. print(f"基准预测值: {base_prediction:.6f}")
  322. # 对每个参数进行基于训练数据的敏感性分析
  323. for param_name, base_value in base_params.items():
  324. if param_name not in base_df.columns:
  325. continue
  326. # 检查参数是否在训练数据中存在
  327. if param_name not in self.X_train.columns:
  328. print(f"警告: 参数 {param_name} 不在训练数据中,跳过")
  329. continue
  330. param_sensitivity = {
  331. 'base_value': base_value,
  332. 'variations': [],
  333. 'predictions': [],
  334. 'sensitivity_score': 0,
  335. 'training_stats': {}
  336. }
  337. # 获取训练数据中该参数的统计信息
  338. param_series = self.X_train[param_name]
  339. param_stats = {
  340. 'min': float(param_series.min()),
  341. 'max': float(param_series.max()),
  342. 'mean': float(param_series.mean()),
  343. 'std': float(param_series.std()),
  344. 'q25': float(param_series.quantile(0.25)),
  345. 'q50': float(param_series.quantile(0.50)),
  346. 'q75': float(param_series.quantile(0.75)),
  347. 'q05': float(param_series.quantile(0.05)),
  348. 'q95': float(param_series.quantile(0.95))
  349. }
  350. param_sensitivity['training_stats'] = param_stats
  351. print(f"\n分析参数: {param_name} (基准值: {base_value})")
  352. print(f" 训练数据范围: [{param_stats['min']:.3f}, {param_stats['max']:.3f}]")
  353. print(f" 训练数据均值±标准差: {param_stats['mean']:.3f} ± {param_stats['std']:.3f}")
  354. print(f" 训练数据分位数 [5%, 25%, 50%, 75%, 95%]: [{param_stats['q05']:.3f}, {param_stats['q25']:.3f}, {param_stats['q50']:.3f}, {param_stats['q75']:.3f}, {param_stats['q95']:.3f}]")
  355. # 基于训练数据分布设计参数变化策略
  356. # 策略1: 使用训练数据的实际范围
  357. min_val = param_stats['min']
  358. max_val = param_stats['max']
  359. mean_val = param_stats['mean']
  360. std_val = param_stats['std']
  361. # 创建合理的参数变化点
  362. variations = []
  363. # 1. 训练数据的关键分位数点
  364. variations.extend([
  365. param_stats['min'], # 最小值
  366. param_stats['q05'], # 5%分位数
  367. param_stats['q25'], # 25%分位数
  368. param_stats['q50'], # 中位数
  369. param_stats['q75'], # 75%分位数
  370. param_stats['q95'], # 95%分位数
  371. param_stats['max'], # 最大值
  372. ])
  373. # 2. 围绕均值的标准差变化
  374. variations.extend([
  375. max(min_val, mean_val - 2*std_val), # 均值-2σ
  376. max(min_val, mean_val - std_val), # 均值-σ
  377. mean_val, # 均值
  378. min(max_val, mean_val + std_val), # 均值+σ
  379. min(max_val, mean_val + 2*std_val), # 均值+2σ
  380. ])
  381. # 3. 基准值本身
  382. variations.append(base_value)
  383. # 4. 如果基准值在合理范围内,添加围绕基准值的小幅变化
  384. if min_val <= base_value <= max_val:
  385. # 计算基准值的相对变化(不超出训练数据范围)
  386. range_size = max_val - min_val
  387. small_change = range_size * 0.05 # 5%的范围变化
  388. variations.extend([
  389. max(min_val, base_value - small_change),
  390. min(max_val, base_value + small_change),
  391. ])
  392. # 去重并排序
  393. variations = sorted(list(set(variations)))
  394. print(f" 使用参数变化点: {len(variations)} 个")
  395. for i, variation in enumerate(variations):
  396. temp_params = base_params.copy()
  397. temp_params[param_name] = variation
  398. temp_df = pd.DataFrame([temp_params])
  399. temp_df = temp_df.reindex(columns=self.X_test.columns, fill_value=0)
  400. try:
  401. pred = self.model.predict(temp_df)[0]
  402. param_sensitivity['variations'].append(float(variation))
  403. param_sensitivity['predictions'].append(float(pred))
  404. # 计算相对于训练数据范围的位置
  405. if max_val > min_val:
  406. relative_pos = (variation - min_val) / (max_val - min_val) * 100
  407. else:
  408. relative_pos = 50.0
  409. pred_change = pred - base_prediction
  410. print(f" {param_name}={variation:.3f} (训练数据{relative_pos:.0f}位) → 预测={pred:.6f} (变化={pred_change:+.6f})")
  411. except Exception as e:
  412. print(f" 警告: {param_name}={variation} 预测失败: {str(e)}")
  413. continue
  414. # 计算敏感性指标
  415. if len(param_sensitivity['predictions']) > 1:
  416. predictions = np.array(param_sensitivity['predictions'])
  417. # 指标1: 预测值变化范围
  418. pred_range = np.max(predictions) - np.min(predictions)
  419. # 指标2: 预测值标准差
  420. pred_std = np.std(predictions)
  421. # 指标3: 与基准值的最大偏差
  422. max_deviation = np.max(np.abs(predictions - base_prediction))
  423. # 指标4: 标准化敏感性(考虑参数的变化范围)
  424. param_range = max_val - min_val if max_val > min_val else 1.0
  425. normalized_sensitivity = pred_range / param_range if param_range > 0 else 0
  426. # 使用最大偏差作为主要敏感性指标
  427. param_sensitivity['sensitivity_score'] = float(max_deviation)
  428. param_sensitivity['pred_range'] = float(pred_range)
  429. param_sensitivity['pred_std'] = float(pred_std)
  430. param_sensitivity['normalized_sensitivity'] = float(normalized_sensitivity)
  431. print(f" {param_name} 敏感性得分: {max_deviation:.6f}")
  432. print(f" 预测范围: {pred_range:.6f}, 标准差: {pred_std:.6f}")
  433. print(f" 标准化敏感性: {normalized_sensitivity:.6f}")
  434. else:
  435. param_sensitivity['sensitivity_score'] = 0.0
  436. sensitivity_results[param_name] = param_sensitivity
  437. # 按敏感性得分排序
  438. sorted_sensitivity = dict(sorted(
  439. sensitivity_results.items(),
  440. key=lambda x: x[1]['sensitivity_score'],
  441. reverse=True
  442. ))
  443. self.analysis_results['sensitivity_analysis'] = sorted_sensitivity
  444. print("\n=== 基于训练数据的参数敏感性分析结果 ===")
  445. for param, result in sorted_sensitivity.items():
  446. score = result['sensitivity_score']
  447. normalized_score = result.get('normalized_sensitivity', 0)
  448. if score > 0.001:
  449. sensitivity_level = "高" if score > 0.1 else "中" if score > 0.01 else "低"
  450. else:
  451. sensitivity_level = "极低"
  452. print(f"{param:12}: 敏感性得分 = {score:.6f} ({sensitivity_level})")
  453. print(f" 标准化敏感性 = {normalized_score:.6f}")
  454. return sorted_sensitivity
  455. except Exception as e:
  456. print(f"敏感性分析时出错: {str(e)}")
  457. import traceback
  458. print(traceback.format_exc())
  459. # 如果分析失败,回退到原来的方法
  460. return self.sensitivity_analysis_fallback(sample_params)
  461. def generate_visualizations(self):
  462. """
  463. 生成分析可视化图表
  464. """
  465. try:
  466. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  467. model_id = self.model_info['ModelID']
  468. # 1. 预测值 vs 真实值散点图
  469. if self.predictions:
  470. plt.figure(figsize=(10, 8))
  471. y_true = self.predictions['y_true']
  472. y_pred = self.predictions['y_pred']
  473. plt.scatter(y_true, y_pred, alpha=0.6, s=50)
  474. # 绘制理想线 (y=x)
  475. min_val = min(min(y_true), min(y_pred))
  476. max_val = max(max(y_true), max(y_pred))
  477. plt.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction')
  478. plt.xlabel('真实值 (True Values)')
  479. plt.ylabel('预测值 (Predicted Values)')
  480. plt.title(f'模型 {model_id} 预测性能\nR^2 = {self.analysis_results["metrics"]["R2_Score"]:.4f}')
  481. plt.legend()
  482. plt.grid(True, alpha=0.3)
  483. plt.tight_layout()
  484. plt.savefig(f'{self.output_dir}/model_{model_id}_scatter_{timestamp}.png', dpi=300)
  485. plt.show()
  486. # 2. 残差图
  487. plt.figure(figsize=(12, 5))
  488. # 确保数据类型一致,转换为numpy数组
  489. y_true_array = np.array(y_true).flatten()
  490. y_pred_array = np.array(y_pred).flatten()
  491. residuals = y_true_array - y_pred_array
  492. # 移除NaN值
  493. valid_indices = ~(np.isnan(residuals) | np.isnan(y_pred_array))
  494. residuals_clean = residuals[valid_indices]
  495. y_pred_clean = y_pred_array[valid_indices]
  496. # 残差散点图
  497. plt.subplot(1, 2, 1)
  498. plt.scatter(y_pred_clean, residuals_clean, alpha=0.6)
  499. plt.axhline(y=0, color='r', linestyle='--', linewidth=2)
  500. plt.xlabel('预测值')
  501. plt.ylabel('残差 (真实值 - 预测值)')
  502. plt.title('残差散点图')
  503. plt.grid(True, alpha=0.3)
  504. # 残差直方图
  505. plt.subplot(1, 2, 2)
  506. # 根据数据量自动调整bins数量
  507. n_samples = len(residuals_clean)
  508. n_bins = min(30, max(5, int(np.sqrt(n_samples)))) # Sturges' rule的变体
  509. n, bins, patches = plt.hist(residuals_clean, bins=n_bins, alpha=0.7, edgecolor='black', density=False)
  510. plt.xlabel('残差')
  511. plt.ylabel('频数')
  512. plt.title(f'残差分布 (n={n_samples}, bins={n_bins})')
  513. plt.grid(True, alpha=0.3)
  514. # 添加统计信息
  515. mean_residual = np.mean(residuals_clean)
  516. std_residual = np.std(residuals_clean)
  517. plt.axvline(mean_residual, color='red', linestyle='--', alpha=0.8,
  518. label=f'均值: {mean_residual:.4f}')
  519. plt.axvline(mean_residual + std_residual, color='orange', linestyle=':', alpha=0.8,
  520. label=f'±1σ: ±{std_residual:.4f}')
  521. plt.axvline(mean_residual - std_residual, color='orange', linestyle=':', alpha=0.8)
  522. plt.legend(fontsize=8)
  523. plt.tight_layout()
  524. plt.savefig(f'{self.output_dir}/model_{model_id}_residuals_{timestamp}.png', dpi=300)
  525. plt.show()
  526. # 3. 特征重要性图
  527. if 'feature_importance' in self.analysis_results:
  528. importance_data = self.analysis_results['feature_importance']
  529. if 'model_importance' in importance_data:
  530. plt.figure(figsize=(10, 6))
  531. features = importance_data['model_importance']['features']
  532. importance = importance_data['model_importance']['importance']
  533. indices = np.argsort(importance)[::-1]
  534. plt.bar(range(len(features)), [importance[i] for i in indices])
  535. plt.xticks(range(len(features)), [features[i] for i in indices], rotation=45)
  536. plt.xlabel('特征')
  537. plt.ylabel('重要性')
  538. plt.title(f'模型 {model_id} 特征重要性')
  539. plt.tight_layout()
  540. plt.savefig(f'{self.output_dir}/model_{model_id}_feature_importance_{timestamp}.png', dpi=300)
  541. plt.show()
  542. # 4. 敏感性分析图
  543. if 'sensitivity_analysis' in self.analysis_results:
  544. sensitivity_data = self.analysis_results['sensitivity_analysis']
  545. # 敏感性得分条形图
  546. plt.figure(figsize=(10, 6))
  547. params = list(sensitivity_data.keys())
  548. scores = [sensitivity_data[param]['sensitivity_score'] for param in params]
  549. plt.bar(params, scores)
  550. plt.xlabel('参数')
  551. plt.ylabel('敏感性得分')
  552. plt.title(f'模型 {model_id} 参数敏感性分析')
  553. plt.xticks(rotation=45)
  554. plt.tight_layout()
  555. plt.savefig(f'{self.output_dir}/model_{model_id}_sensitivity_{timestamp}.png', dpi=300)
  556. plt.show()
  557. # 详细敏感性曲线
  558. n_params = len(params)
  559. fig, axes = plt.subplots(2, 3, figsize=(15, 10))
  560. axes = axes.flatten()
  561. for i, param in enumerate(params[:6]): # 显示前6个参数
  562. if i < len(axes):
  563. data = sensitivity_data[param]
  564. axes[i].plot(data['variations'], data['predictions'], 'b-o', markersize=4)
  565. axes[i].set_xlabel(param)
  566. axes[i].set_ylabel('预测值')
  567. axes[i].set_title(f'{param} 敏感性')
  568. axes[i].grid(True, alpha=0.3)
  569. # 隐藏多余的子图
  570. for i in range(len(params), len(axes)):
  571. axes[i].set_visible(False)
  572. plt.tight_layout()
  573. plt.savefig(f'{self.output_dir}/model_{model_id}_sensitivity_curves_{timestamp}.png', dpi=300)
  574. plt.show()
  575. print(f"=== 可视化图表已保存到 {self.output_dir} ===")
  576. except Exception as e:
  577. print(f"生成可视化时出错: {str(e)}")
  578. def generate_report(self):
  579. """
  580. 生成综合分析报告
  581. @return {str} - 报告内容
  582. """
  583. try:
  584. timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  585. model_id = self.model_info['ModelID']
  586. report = f"""
  587. 模型分析报告
  588. {'='*60}
  589. 生成时间: {timestamp}
  590. 分析模型: {self.model_info['Model_name']} (ID: {model_id})
  591. 模型类型: {self.model_info['Model_type']}
  592. 数据类型: {self.model_info['Data_type']}
  593. 创建时间: {self.model_info['Created_at']}
  594. 模型描述:
  595. {self.model_info['Description'] or '无描述'}
  596. {'='*60}
  597. 性能指标分析
  598. {'='*60}
  599. """
  600. if 'metrics' in self.analysis_results:
  601. metrics = self.analysis_results['metrics']
  602. for key, value in metrics.items():
  603. report += f"{key:15}: {value:8.4f}\n"
  604. report += f"\n{'='*60}\n特征重要性分析\n{'='*60}\n"
  605. if 'feature_importance' in self.analysis_results:
  606. importance_data = self.analysis_results['feature_importance']
  607. if 'model_importance' in importance_data:
  608. features = importance_data['model_importance']['features']
  609. importance = importance_data['model_importance']['importance']
  610. # 按重要性排序
  611. sorted_indices = np.argsort(importance)[::-1]
  612. for i in sorted_indices:
  613. report += f"{features[i]:15}: {importance[i]:8.4f}\n"
  614. report += f"\n{'='*60}\n敏感性分析\n{'='*60}\n"
  615. if 'sensitivity_analysis' in self.analysis_results:
  616. sensitivity_data = self.analysis_results['sensitivity_analysis']
  617. for param, data in sensitivity_data.items():
  618. report += f"{param:15}: {data['sensitivity_score']:.8f}\n"
  619. if 'prediction_behavior' in self.analysis_results:
  620. behavior = self.analysis_results['prediction_behavior']
  621. if 'sample_prediction' in behavior:
  622. sample = behavior['sample_prediction']
  623. report += f"\n{'='*60}\n示例预测分析\n{'='*60}\n"
  624. report += f"输入参数: {sample['input_params']}\n"
  625. report += f"预测结果: {sample['prediction']:.4f}\n"
  626. report += f"\n{'='*60}\n分析建议\n{'='*60}\n"
  627. # 生成分析建议
  628. if 'metrics' in self.analysis_results:
  629. r2 = self.analysis_results['metrics']['R2_Score']
  630. if r2 > 0.9:
  631. report += "• 模型表现优秀,R²得分大于0.9\n"
  632. elif r2 > 0.8:
  633. report += "• 模型表现良好,R²得分在0.8-0.9之间\n"
  634. elif r2 > 0.7:
  635. report += "• 模型表现中等,R²得分在0.7-0.8之间,建议优化\n"
  636. else:
  637. report += "• 模型表现较差,R²得分低于0.7,需要重新训练\n"
  638. if 'sensitivity_analysis' in self.analysis_results:
  639. sensitivity_data = self.analysis_results['sensitivity_analysis']
  640. most_sensitive = max(sensitivity_data.items(), key=lambda x: x[1]['sensitivity_score'])
  641. report += f"• 最敏感参数: {most_sensitive[0]},在调参时需要特别注意\n"
  642. least_sensitive = min(sensitivity_data.items(), key=lambda x: x[1]['sensitivity_score'])
  643. report += f"• 最不敏感参数: {least_sensitive[0]},对预测结果影响较小\n"
  644. report += f"\n{'='*60}\n报告结束\n{'='*60}\n"
  645. # 保存报告
  646. report_path = f'{self.output_dir}/model_{model_id}_analysis_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt'
  647. with open(report_path, 'w', encoding='utf-8') as f:
  648. f.write(report)
  649. print(f"分析报告已保存到: {report_path}")
  650. return report
  651. except Exception as e:
  652. print(f"生成报告时出错: {str(e)}")
  653. return ""
  654. def run_full_analysis(self, model_id, sample_params=None):
  655. """
  656. 运行完整的模型分析流程
  657. @param {int} model_id - 模型ID
  658. @param {dict} sample_params - 示例参数
  659. @return {dict} - 完整的分析结果
  660. """
  661. print(f"开始分析模型 {model_id}...")
  662. # 1. 加载模型
  663. if not self.load_model_by_id(model_id):
  664. return {}
  665. # 2. 加载测试数据
  666. if not self.load_test_data():
  667. print("警告: 无法加载测试数据,部分分析功能将不可用")
  668. # 3. 加载训练数据(用于敏感性分析)
  669. if not self.load_training_data():
  670. print("警告: 无法加载训练数据,敏感性分析将使用默认方法")
  671. # 4. 计算性能指标
  672. self.calculate_detailed_metrics()
  673. # 5. 特征重要性分析
  674. self.analyze_feature_importance()
  675. # 6. 预测行为分析
  676. self.analyze_prediction_behavior(sample_params)
  677. # 7. 敏感性分析
  678. if sample_params:
  679. self.sensitivity_analysis(sample_params)
  680. # 8. 生成可视化
  681. self.generate_visualizations()
  682. # 9. 生成报告
  683. self.generate_report()
  684. print("分析完成!")
  685. return self.analysis_results
  686. def main():
  687. """
  688. 主函数 - 分析模型24的示例
  689. """
  690. # 创建分析器实例
  691. analyzer = ModelAnalyzer()
  692. # 用户提供的示例参数
  693. sample_params = {
  694. "OM": 19.913,
  695. "CL": 373.600,
  696. "CEC": 7.958,
  697. "H_plus": 0.774,
  698. "N": 0.068,
  699. "Al3_plus": 3.611,
  700. "target_pH": 7.0 # 这个参数在reflux模型中会被移除
  701. }
  702. # 运行完整分析
  703. results = analyzer.run_full_analysis(model_id=24, sample_params=sample_params)
  704. if results:
  705. print("\n=== 分析完成 ===")
  706. print("所有分析结果和可视化图表已保存到 model_optimize/analysis_results 目录")
  707. else:
  708. print("分析失败,请检查模型ID和数据文件")
  709. if __name__ == "__main__":
  710. main()