12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876 |
- # -*- coding: utf-8 -*-
- """
- 模型分析工具
- 此模块提供了综合的模型分析功能,包括:
- - 模型性能详细分析
- - 特征重要性分析
- - 预测行为分析
- - 参数敏感性分析
- - 模型诊断和可视化
- """
- import os
- import sys
- import pickle
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- import seaborn as sns
- from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
- from sklearn.model_selection import cross_val_score
- from sklearn.inspection import permutation_importance
- import sqlite3
- from datetime import datetime
- import warnings
- warnings.filterwarnings('ignore')
- # 设置中文字体
- plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
- plt.rcParams['axes.unicode_minus'] = False
- class ModelAnalyzer:
- """
- 模型分析器类
-
- 提供全面的模型分析功能,支持模型性能评估、特征分析、预测分析等
- """
-
- def __init__(self, db_path='SoilAcidification.db'):
- """
- 初始化模型分析器
-
- @param {str} db_path - 数据库文件路径
- """
- self.db_path = db_path
- self.model_info = None
- self.model = None
- self.test_data = None
- self.predictions = None
- self.analysis_results = {}
-
- # 创建输出目录
- self.output_dir = 'model_optimize/analysis_results'
- if not os.path.exists(self.output_dir):
- os.makedirs(self.output_dir)
-
- def load_model_by_id(self, model_id):
- """
- 根据模型ID加载模型信息和模型对象
-
- @param {int} model_id - 模型ID
- @return {bool} - 是否成功加载
- """
- try:
- # 连接数据库
- conn = sqlite3.connect(self.db_path)
- cursor = conn.cursor()
-
- # 查询模型信息
- cursor.execute("""
- SELECT ModelID, Model_name, Model_type, ModelFilePath,
- Data_type, Performance_score, MAE, RMSE, CV_score,
- Description, Created_at
- FROM Models
- WHERE ModelID = ?
- """, (model_id,))
-
- result = cursor.fetchone()
- if not result:
- print(f"未找到ID为 {model_id} 的模型")
- return False
-
- # 存储模型信息
- self.model_info = {
- 'ModelID': result[0],
- 'Model_name': result[1],
- 'Model_type': result[2],
- 'ModelFilePath': result[3],
- 'Data_type': result[4],
- 'Performance_score': result[5],
- 'MAE': result[6],
- 'RMSE': result[7],
- 'CV_score': result[8],
- 'Description': result[9],
- 'Created_at': result[10]
- }
-
- conn.close()
-
- # 加载模型文件
- if os.path.exists(self.model_info['ModelFilePath']):
- with open(self.model_info['ModelFilePath'], 'rb') as f:
- self.model = pickle.load(f)
- print(f"成功加载模型: {self.model_info['Model_name']}")
- return True
- else:
- print(f"模型文件不存在: {self.model_info['ModelFilePath']}")
- return False
-
- except Exception as e:
- print(f"加载模型时出错: {str(e)}")
- return False
-
- def load_test_data(self):
- """
- 根据模型数据类型加载相应的测试数据
-
- @return {bool} - 是否成功加载测试数据
- """
- try:
- data_type = self.model_info['Data_type']
-
- if data_type == 'reflux':
- X_test_path = 'uploads/data/X_test_reflux.csv'
- Y_test_path = 'uploads/data/Y_test_reflux.csv'
- elif data_type == 'reduce':
- X_test_path = 'uploads/data/X_test_reduce.csv'
- Y_test_path = 'uploads/data/Y_test_reduce.csv'
- else:
- print(f"不支持的数据类型: {data_type}")
- return False
-
- if os.path.exists(X_test_path) and os.path.exists(Y_test_path):
- self.X_test = pd.read_csv(X_test_path)
- self.Y_test = pd.read_csv(Y_test_path)
- print(f"成功加载 {data_type} 类型的测试数据")
- print(f"测试集大小: X_test {self.X_test.shape}, Y_test {self.Y_test.shape}")
- return True
- else:
- print(f"测试数据文件不存在")
- return False
-
- except Exception as e:
- print(f"加载测试数据时出错: {str(e)}")
- return False
-
- def load_training_data(self):
- """
- 加载训练数据以获取参数的真实分布范围
-
- @return {bool} - 是否成功加载训练数据
- """
- try:
- data_type = self.model_info['Data_type']
-
- if data_type == 'reflux':
- # 加载current_reflux表的数据
- conn = sqlite3.connect(self.db_path)
- training_data = pd.read_sql_query("SELECT * FROM current_reflux", conn)
- conn.close()
-
- # 分离特征和目标变量
- if 'Delta_pH' in training_data.columns:
- self.X_train = training_data.drop(['id', 'Delta_pH'], axis=1, errors='ignore')
- self.y_train = training_data['Delta_pH']
- else:
- print("未找到目标变量 Delta_pH")
- return False
-
- elif data_type == 'reduce':
- # 加载current_reduce表的数据
- conn = sqlite3.connect(self.db_path)
- training_data = pd.read_sql_query("SELECT * FROM current_reduce", conn)
- conn.close()
-
- # 分离特征和目标变量
- if 'Q_over_b' in training_data.columns:
- self.X_train = training_data.drop(['id', 'Q_over_b'], axis=1, errors='ignore')
- self.y_train = training_data['Q_over_b']
- else:
- print("未找到目标变量 Q_over_b")
- return False
- else:
- print(f"不支持的数据类型: {data_type}")
- return False
-
- print(f"成功加载训练数据: {self.X_train.shape}")
- print(f"训练数据特征统计:")
- print(self.X_train.describe())
- return True
-
- except Exception as e:
- print(f"加载训练数据时出错: {str(e)}")
- return False
-
- def calculate_detailed_metrics(self):
- """
- 计算详细的模型性能指标
-
- @return {dict} - 包含各种性能指标的字典
- """
- try:
- if self.model is None or self.X_test is None:
- print("模型或测试数据未加载")
- return {}
-
- # 获取预测值
- y_pred = self.model.predict(self.X_test)
- y_true = self.Y_test.iloc[:, 0] if len(self.Y_test.columns) == 1 else self.Y_test
-
- # 计算各种指标
- metrics = {
- 'R2_Score': r2_score(y_true, y_pred),
- 'RMSE': mean_squared_error(y_true, y_pred, squared=False),
- 'MAE': mean_absolute_error(y_true, y_pred),
- 'MSE': mean_squared_error(y_true, y_pred),
- 'MAPE': np.mean(np.abs((y_true - y_pred) / y_true)) * 100,
- 'Max_Error': np.max(np.abs(y_true - y_pred)),
- 'Mean_Residual': np.mean(y_true - y_pred),
- 'Std_Residual': np.std(y_true - y_pred)
- }
-
- self.predictions = {'y_true': y_true, 'y_pred': y_pred}
- self.analysis_results['metrics'] = metrics
-
- # 中文指标名称映射
- metric_names_zh = {
- 'R2_Score': 'R²得分',
- 'RMSE': '均方根误差',
- 'MAE': '平均绝对误差',
- 'MSE': '均方误差',
- 'MAPE': '平均绝对百分比误差',
- 'Max_Error': '最大误差',
- 'Mean_Residual': '平均残差',
- 'Std_Residual': '残差标准差'
- }
-
- print("=== 模型性能指标 ===")
- for key, value in metrics.items():
- zh_name = metric_names_zh.get(key, key)
- print(f"{key} ({zh_name}): {value:.4f}")
-
- return metrics
-
- except Exception as e:
- print(f"计算性能指标时出错: {str(e)}")
- return {}
-
- def analyze_feature_importance(self):
- """
- 分析特征重要性
-
- @return {dict} - 特征重要性分析结果
- """
- try:
- if self.model is None:
- print("模型未加载")
- return {}
-
- importance_data = {}
-
- # 1. 模型内置特征重要性(如果支持)
- if hasattr(self.model, 'feature_importances_'):
- importance_data['model_importance'] = {
- 'features': self.X_test.columns.tolist(),
- 'importance': self.model.feature_importances_.tolist()
- }
-
- # 2. 排列特征重要性
- if self.X_test is not None and self.Y_test is not None:
- y_true = self.Y_test.iloc[:, 0] if len(self.Y_test.columns) == 1 else self.Y_test
- perm_importance = permutation_importance(
- self.model, self.X_test, y_true,
- n_repeats=10, random_state=42
- )
- importance_data['permutation_importance'] = {
- 'features': self.X_test.columns.tolist(),
- 'importance_mean': perm_importance.importances_mean.tolist(),
- 'importance_std': perm_importance.importances_std.tolist()
- }
-
- self.analysis_results['feature_importance'] = importance_data
-
- print("=== 特征重要性分析完成 ===")
- return importance_data
-
- except Exception as e:
- print(f"分析特征重要性时出错: {str(e)}")
- return {}
-
- def analyze_prediction_behavior(self, sample_params=None):
- """
- 分析模型的预测行为
-
- @param {dict} sample_params - 示例参数,如用户提供的参数
- @return {dict} - 预测行为分析结果
- """
- try:
- if self.model is None:
- print("模型未加载")
- return {}
-
- behavior_analysis = {}
-
- # 1. 分析用户提供的示例参数
- if sample_params:
- # 转换参数为DataFrame
- if self.model_info['Data_type'] == 'reflux':
- # 为reflux类型,移除target_pH
- sample_params_clean = sample_params.copy()
- sample_params_clean.pop('target_pH', None)
- sample_df = pd.DataFrame([sample_params_clean])
- else:
- sample_df = pd.DataFrame([sample_params])
-
- # 确保列顺序与训练数据一致
- if hasattr(self, 'X_test'):
- sample_df = sample_df.reindex(columns=self.X_test.columns, fill_value=0)
-
- prediction = self.model.predict(sample_df)[0]
-
- behavior_analysis['sample_prediction'] = {
- 'input_params': sample_params,
- 'prediction': float(prediction),
- 'model_type': self.model_info['Data_type']
- }
-
- print(f"=== 示例参数预测结果 ===")
- print(f"输入参数: {sample_params}")
- print(f"预测结果: {prediction:.4f}")
-
- # 2. 预测值分布分析
- if self.predictions:
- y_pred = self.predictions['y_pred']
- behavior_analysis['prediction_distribution'] = {
- 'mean': float(np.mean(y_pred)),
- 'std': float(np.std(y_pred)),
- 'min': float(np.min(y_pred)),
- 'max': float(np.max(y_pred)),
- 'percentiles': {
- '25%': float(np.percentile(y_pred, 25)),
- '50%': float(np.percentile(y_pred, 50)),
- '75%': float(np.percentile(y_pred, 75))
- }
- }
-
- self.analysis_results['prediction_behavior'] = behavior_analysis
- return behavior_analysis
-
- except Exception as e:
- print(f"分析预测行为时出错: {str(e)}")
- return {}
-
    def sensitivity_analysis(self, sample_params=None):
        """
        Parameter sensitivity analysis based on the training-data distribution.

        For every baseline parameter, the value is swept across the key
        quantiles of the training distribution plus mean±σ points, the model
        is re-run per point, and the maximum deviation from the baseline
        prediction is used as the sensitivity score.

        @param {dict} sample_params - Baseline parameters.
        @return {dict} - Per-parameter sensitivity results sorted by score
                         (descending); empty dict when preconditions fail.
        """
        try:
            if self.model is None or sample_params is None:
                print("模型未加载或未提供基准参数")
                return {}

            # Training data provides the realistic variation range for each
            # parameter; load it lazily if the caller has not done so.
            if not hasattr(self, 'X_train') or self.X_train is None:
                print("训练数据未加载,尝试加载...")
                if not self.load_training_data():
                    print("无法加载训练数据,使用默认参数范围")
                    # NOTE(review): sensitivity_analysis_fallback is not defined
                    # anywhere in this file -- confirm it exists elsewhere,
                    # otherwise this path raises AttributeError.
                    return self.sensitivity_analysis_fallback(sample_params)

            sensitivity_results = {}

            # Baseline prediction with the unmodified parameters.
            base_params = sample_params.copy()
            if self.model_info['Data_type'] == 'reflux':
                base_params.pop('target_pH', None)  # goal value, not a feature

            base_df = pd.DataFrame([base_params])
            if hasattr(self, 'X_test'):
                base_df = base_df.reindex(columns=self.X_test.columns, fill_value=0)

            base_prediction = self.model.predict(base_df)[0]
            print(f"基准预测值: {base_prediction:.6f}")

            # Sweep each parameter independently over its training distribution.
            for param_name, base_value in base_params.items():
                if param_name not in base_df.columns:
                    continue

                # Skip parameters the training data knows nothing about.
                if param_name not in self.X_train.columns:
                    print(f"警告: 参数 {param_name} 不在训练数据中,跳过")
                    continue

                param_sensitivity = {
                    'base_value': base_value,
                    'variations': [],
                    'predictions': [],
                    'sensitivity_score': 0,
                    'training_stats': {}
                }

                # Descriptive statistics of this parameter in the training data.
                param_series = self.X_train[param_name]
                param_stats = {
                    'min': float(param_series.min()),
                    'max': float(param_series.max()),
                    'mean': float(param_series.mean()),
                    'std': float(param_series.std()),
                    'q25': float(param_series.quantile(0.25)),
                    'q50': float(param_series.quantile(0.50)),
                    'q75': float(param_series.quantile(0.75)),
                    'q05': float(param_series.quantile(0.05)),
                    'q95': float(param_series.quantile(0.95))
                }
                param_sensitivity['training_stats'] = param_stats

                print(f"\n分析参数: {param_name} (基准值: {base_value})")
                print(f" 训练数据范围: [{param_stats['min']:.3f}, {param_stats['max']:.3f}]")
                print(f" 训练数据均值±标准差: {param_stats['mean']:.3f} ± {param_stats['std']:.3f}")
                print(f" 训练数据分位数 [5%, 25%, 50%, 75%, 95%]: [{param_stats['q05']:.3f}, {param_stats['q25']:.3f}, {param_stats['q50']:.3f}, {param_stats['q75']:.3f}, {param_stats['q95']:.3f}]")

                # Variation strategy based on the training distribution.
                min_val = param_stats['min']
                max_val = param_stats['max']
                mean_val = param_stats['mean']
                std_val = param_stats['std']

                # Candidate variation points for this parameter.
                variations = []

                # 1. Key quantiles of the training data.
                variations.extend([
                    param_stats['min'],   # minimum
                    param_stats['q05'],   # 5th percentile
                    param_stats['q25'],   # 25th percentile
                    param_stats['q50'],   # median
                    param_stats['q75'],   # 75th percentile
                    param_stats['q95'],   # 95th percentile
                    param_stats['max'],   # maximum
                ])

                # 2. Mean ± 1σ / 2σ, clamped to the observed range.
                variations.extend([
                    max(min_val, mean_val - 2*std_val),  # mean - 2σ
                    max(min_val, mean_val - std_val),    # mean - σ
                    mean_val,                            # mean
                    min(max_val, mean_val + std_val),    # mean + σ
                    min(max_val, mean_val + 2*std_val),  # mean + 2σ
                ])

                # 3. The baseline value itself.
                variations.append(base_value)

                # 4. Small perturbations around the baseline, only when the
                #    baseline lies inside the training range.
                if min_val <= base_value <= max_val:
                    # 5% of the observed range, clamped so the perturbed
                    # values stay inside the training range.
                    range_size = max_val - min_val
                    small_change = range_size * 0.05  # 5% of range

                    variations.extend([
                        max(min_val, base_value - small_change),
                        min(max_val, base_value + small_change),
                    ])

                # Deduplicate and sort the sweep points.
                variations = sorted(list(set(variations)))

                print(f" 使用参数变化点: {len(variations)} 个")

                for i, variation in enumerate(variations):
                    temp_params = base_params.copy()
                    temp_params[param_name] = variation
                    temp_df = pd.DataFrame([temp_params])
                    temp_df = temp_df.reindex(columns=self.X_test.columns, fill_value=0)

                    try:
                        pred = self.model.predict(temp_df)[0]
                        param_sensitivity['variations'].append(float(variation))
                        param_sensitivity['predictions'].append(float(pred))

                        # Position of this value within the training range (%).
                        if max_val > min_val:
                            relative_pos = (variation - min_val) / (max_val - min_val) * 100
                        else:
                            relative_pos = 50.0

                        pred_change = pred - base_prediction
                        print(f" {param_name}={variation:.3f} (训练数据{relative_pos:.0f}位) → 预测={pred:.6f} (变化={pred_change:+.6f})")

                    except Exception as e:
                        # A single failing point is skipped, not fatal.
                        print(f" 警告: {param_name}={variation} 预测失败: {str(e)}")
                        continue

                # Sensitivity metrics over the successful sweep points.
                if len(param_sensitivity['predictions']) > 1:
                    predictions = np.array(param_sensitivity['predictions'])

                    # Metric 1: span of the predictions.
                    pred_range = np.max(predictions) - np.min(predictions)

                    # Metric 2: standard deviation of the predictions.
                    pred_std = np.std(predictions)

                    # Metric 3: maximum deviation from the baseline prediction.
                    max_deviation = np.max(np.abs(predictions - base_prediction))

                    # Metric 4: span normalized by the parameter's own range.
                    param_range = max_val - min_val if max_val > min_val else 1.0
                    normalized_sensitivity = pred_range / param_range if param_range > 0 else 0

                    # Max deviation is the primary sensitivity score.
                    param_sensitivity['sensitivity_score'] = float(max_deviation)
                    param_sensitivity['pred_range'] = float(pred_range)
                    param_sensitivity['pred_std'] = float(pred_std)
                    param_sensitivity['normalized_sensitivity'] = float(normalized_sensitivity)

                    print(f" {param_name} 敏感性得分: {max_deviation:.6f}")
                    print(f" 预测范围: {pred_range:.6f}, 标准差: {pred_std:.6f}")
                    print(f" 标准化敏感性: {normalized_sensitivity:.6f}")
                else:
                    param_sensitivity['sensitivity_score'] = 0.0

                sensitivity_results[param_name] = param_sensitivity

            # Sort parameters by descending sensitivity score.
            sorted_sensitivity = dict(sorted(
                sensitivity_results.items(),
                key=lambda x: x[1]['sensitivity_score'],
                reverse=True
            ))

            self.analysis_results['sensitivity_analysis'] = sorted_sensitivity

            print("\n=== 基于训练数据的参数敏感性分析结果 ===")
            for param, result in sorted_sensitivity.items():
                score = result['sensitivity_score']
                normalized_score = result.get('normalized_sensitivity', 0)
                if score > 0.001:
                    sensitivity_level = "高" if score > 0.1 else "中" if score > 0.01 else "低"
                else:
                    sensitivity_level = "极低"
                print(f"{param:12}: 敏感性得分 = {score:.6f} ({sensitivity_level})")
                print(f" 标准化敏感性 = {normalized_score:.6f}")

            return sorted_sensitivity

        except Exception as e:
            print(f"敏感性分析时出错: {str(e)}")
            import traceback
            print(traceback.format_exc())
            # On failure, fall back to the default-range method.
            # NOTE(review): see fallback note above -- method not visible here.
            return self.sensitivity_analysis_fallback(sample_params)
-
    def generate_visualizations(self):
        """
        Generate and save the analysis charts (prediction scatter, residuals,
        feature importance, sensitivity) as PNGs under self.output_dir.

        NOTE(review): plt.show() blocks under interactive backends and the
        figures are never released with plt.close() -- confirm this is
        acceptable for batch use.
        """
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_id = self.model_info['ModelID']

            # 1. Predicted vs. true scatter plot
            if self.predictions:
                plt.figure(figsize=(10, 8))

                y_true = self.predictions['y_true']
                y_pred = self.predictions['y_pred']

                plt.scatter(y_true, y_pred, alpha=0.6, s=50)

                # Ideal y = x reference line
                min_val = min(min(y_true), min(y_pred))
                max_val = max(max(y_true), max(y_pred))
                plt.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction')

                plt.xlabel('真实值 (True Values)')
                plt.ylabel('预测值 (Predicted Values)')
                plt.title(f'模型 {model_id} 预测性能\nR^2 = {self.analysis_results["metrics"]["R2_Score"]:.4f}')
                plt.legend()
                plt.grid(True, alpha=0.3)

                plt.tight_layout()
                plt.savefig(f'{self.output_dir}/model_{model_id}_scatter_{timestamp}.png', dpi=300)
                plt.show()

                # 2. Residual plots
                plt.figure(figsize=(12, 5))

                # Force a flat float array shape before subtracting so mixed
                # Series/ndarray inputs behave consistently.
                y_true_array = np.array(y_true).flatten()
                y_pred_array = np.array(y_pred).flatten()
                residuals = y_true_array - y_pred_array

                # Drop NaN pairs so the scatter/histogram do not fail.
                valid_indices = ~(np.isnan(residuals) | np.isnan(y_pred_array))
                residuals_clean = residuals[valid_indices]
                y_pred_clean = y_pred_array[valid_indices]

                # Residual scatter
                plt.subplot(1, 2, 1)
                plt.scatter(y_pred_clean, residuals_clean, alpha=0.6)
                plt.axhline(y=0, color='r', linestyle='--', linewidth=2)
                plt.xlabel('预测值')
                plt.ylabel('残差 (真实值 - 预测值)')
                plt.title('残差散点图')
                plt.grid(True, alpha=0.3)

                # Residual histogram
                plt.subplot(1, 2, 2)
                # Bin count adapts to the sample size (sqrt rule, clamped to
                # [5, 30] -- a variant of Sturges' rule).
                n_samples = len(residuals_clean)
                n_bins = min(30, max(5, int(np.sqrt(n_samples))))

                n, bins, patches = plt.hist(residuals_clean, bins=n_bins, alpha=0.7, edgecolor='black', density=False)
                plt.xlabel('残差')
                plt.ylabel('频数')
                plt.title(f'残差分布 (n={n_samples}, bins={n_bins})')
                plt.grid(True, alpha=0.3)

                # Overlay mean and ±1σ markers on the histogram.
                mean_residual = np.mean(residuals_clean)
                std_residual = np.std(residuals_clean)
                plt.axvline(mean_residual, color='red', linestyle='--', alpha=0.8,
                            label=f'均值: {mean_residual:.4f}')
                plt.axvline(mean_residual + std_residual, color='orange', linestyle=':', alpha=0.8,
                            label=f'±1σ: ±{std_residual:.4f}')
                plt.axvline(mean_residual - std_residual, color='orange', linestyle=':', alpha=0.8)
                plt.legend(fontsize=8)

                plt.tight_layout()
                plt.savefig(f'{self.output_dir}/model_{model_id}_residuals_{timestamp}.png', dpi=300)
                plt.show()

            # 3. Feature importance bar chart
            if 'feature_importance' in self.analysis_results:
                importance_data = self.analysis_results['feature_importance']

                if 'model_importance' in importance_data:
                    plt.figure(figsize=(10, 6))
                    features = importance_data['model_importance']['features']
                    importance = importance_data['model_importance']['importance']

                    # Bars sorted by descending importance.
                    indices = np.argsort(importance)[::-1]

                    plt.bar(range(len(features)), [importance[i] for i in indices])
                    plt.xticks(range(len(features)), [features[i] for i in indices], rotation=45)
                    plt.xlabel('特征')
                    plt.ylabel('重要性')
                    plt.title(f'模型 {model_id} 特征重要性')
                    plt.tight_layout()
                    plt.savefig(f'{self.output_dir}/model_{model_id}_feature_importance_{timestamp}.png', dpi=300)
                    plt.show()

            # 4. Sensitivity analysis charts
            if 'sensitivity_analysis' in self.analysis_results:
                sensitivity_data = self.analysis_results['sensitivity_analysis']

                # Sensitivity score bar chart
                plt.figure(figsize=(10, 6))
                params = list(sensitivity_data.keys())
                scores = [sensitivity_data[param]['sensitivity_score'] for param in params]

                plt.bar(params, scores)
                plt.xlabel('参数')
                plt.ylabel('敏感性得分')
                plt.title(f'模型 {model_id} 参数敏感性分析')
                plt.xticks(rotation=45)
                plt.tight_layout()
                plt.savefig(f'{self.output_dir}/model_{model_id}_sensitivity_{timestamp}.png', dpi=300)
                plt.show()

                # Per-parameter sensitivity curves (first 6 parameters only,
                # laid out on a fixed 2x3 grid).
                n_params = len(params)
                fig, axes = plt.subplots(2, 3, figsize=(15, 10))
                axes = axes.flatten()

                for i, param in enumerate(params[:6]):  # show first 6 params
                    if i < len(axes):
                        data = sensitivity_data[param]
                        axes[i].plot(data['variations'], data['predictions'], 'b-o', markersize=4)
                        axes[i].set_xlabel(param)
                        axes[i].set_ylabel('预测值')
                        axes[i].set_title(f'{param} 敏感性')
                        axes[i].grid(True, alpha=0.3)

                # Hide any unused subplots in the 2x3 grid.
                for i in range(len(params), len(axes)):
                    axes[i].set_visible(False)

                plt.tight_layout()
                plt.savefig(f'{self.output_dir}/model_{model_id}_sensitivity_curves_{timestamp}.png', dpi=300)
                plt.show()

            print(f"=== 可视化图表已保存到 {self.output_dir} ===")

        except Exception as e:
            print(f"生成可视化时出错: {str(e)}")
-
    def generate_report(self):
        """
        Build a plain-text analysis report from self.analysis_results and
        save it under self.output_dir.

        @return {str} - The report text ('' on failure).
        """
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            model_id = self.model_info['ModelID']

            # Header with model metadata.
            report = f"""
模型分析报告
{'='*60}
生成时间: {timestamp}
分析模型: {self.model_info['Model_name']} (ID: {model_id})
模型类型: {self.model_info['Model_type']}
数据类型: {self.model_info['Data_type']}
创建时间: {self.model_info['Created_at']}
模型描述:
{self.model_info['Description'] or '无描述'}
{'='*60}
性能指标分析
{'='*60}
"""

            if 'metrics' in self.analysis_results:
                metrics = self.analysis_results['metrics']
                for key, value in metrics.items():
                    report += f"{key:15}: {value:8.4f}\n"

            report += f"\n{'='*60}\n特征重要性分析\n{'='*60}\n"

            if 'feature_importance' in self.analysis_results:
                importance_data = self.analysis_results['feature_importance']
                if 'model_importance' in importance_data:
                    features = importance_data['model_importance']['features']
                    importance = importance_data['model_importance']['importance']

                    # List features in order of descending importance.
                    sorted_indices = np.argsort(importance)[::-1]

                    for i in sorted_indices:
                        report += f"{features[i]:15}: {importance[i]:8.4f}\n"

            report += f"\n{'='*60}\n敏感性分析\n{'='*60}\n"

            if 'sensitivity_analysis' in self.analysis_results:
                sensitivity_data = self.analysis_results['sensitivity_analysis']
                for param, data in sensitivity_data.items():
                    report += f"{param:15}: {data['sensitivity_score']:.8f}\n"

            if 'prediction_behavior' in self.analysis_results:
                behavior = self.analysis_results['prediction_behavior']
                if 'sample_prediction' in behavior:
                    sample = behavior['sample_prediction']
                    report += f"\n{'='*60}\n示例预测分析\n{'='*60}\n"
                    report += f"输入参数: {sample['input_params']}\n"
                    report += f"预测结果: {sample['prediction']:.4f}\n"

            report += f"\n{'='*60}\n分析建议\n{'='*60}\n"

            # Simple rule-based recommendations derived from R².
            if 'metrics' in self.analysis_results:
                r2 = self.analysis_results['metrics']['R2_Score']
                if r2 > 0.9:
                    report += "• 模型表现优秀,R²得分大于0.9\n"
                elif r2 > 0.8:
                    report += "• 模型表现良好,R²得分在0.8-0.9之间\n"
                elif r2 > 0.7:
                    report += "• 模型表现中等,R²得分在0.7-0.8之间,建议优化\n"
                else:
                    report += "• 模型表现较差,R²得分低于0.7,需要重新训练\n"

            if 'sensitivity_analysis' in self.analysis_results:
                sensitivity_data = self.analysis_results['sensitivity_analysis']
                # NOTE(review): max()/min() raise ValueError on an empty dict;
                # the outer except then returns '' -- confirm intended.
                most_sensitive = max(sensitivity_data.items(), key=lambda x: x[1]['sensitivity_score'])
                report += f"• 最敏感参数: {most_sensitive[0]},在调参时需要特别注意\n"

                least_sensitive = min(sensitivity_data.items(), key=lambda x: x[1]['sensitivity_score'])
                report += f"• 最不敏感参数: {least_sensitive[0]},对预测结果影响较小\n"

            report += f"\n{'='*60}\n报告结束\n{'='*60}\n"

            # Persist the report next to the generated charts.
            report_path = f'{self.output_dir}/model_{model_id}_analysis_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt'
            with open(report_path, 'w', encoding='utf-8') as f:
                f.write(report)

            print(f"分析报告已保存到: {report_path}")
            return report

        except Exception as e:
            print(f"生成报告时出错: {str(e)}")
            return ""
-
- def run_full_analysis(self, model_id, sample_params=None):
- """
- 运行完整的模型分析流程
-
- @param {int} model_id - 模型ID
- @param {dict} sample_params - 示例参数
- @return {dict} - 完整的分析结果
- """
- print(f"开始分析模型 {model_id}...")
-
- # 1. 加载模型
- if not self.load_model_by_id(model_id):
- return {}
-
- # 2. 加载测试数据
- if not self.load_test_data():
- print("警告: 无法加载测试数据,部分分析功能将不可用")
-
- # 3. 加载训练数据(用于敏感性分析)
- if not self.load_training_data():
- print("警告: 无法加载训练数据,敏感性分析将使用默认方法")
-
- # 4. 计算性能指标
- self.calculate_detailed_metrics()
-
- # 5. 特征重要性分析
- self.analyze_feature_importance()
-
- # 6. 预测行为分析
- self.analyze_prediction_behavior(sample_params)
-
- # 7. 敏感性分析
- if sample_params:
- self.sensitivity_analysis(sample_params)
-
- # 8. 生成可视化
- self.generate_visualizations()
-
- # 9. 生成报告
- self.generate_report()
-
- print("分析完成!")
- return self.analysis_results
def main(model_id=24, sample_params=None):
    """
    Entry point - run a full analysis for one model.

    Generalized from the original hard-coded demo: callers may now pass a
    model ID and their own parameters; calling main() with no arguments
    behaves exactly as before (model 24, built-in demo parameters).

    @param {int} model_id - Model ID to analyze (default 24, as in the
                            original script).
    @param {dict} sample_params - Example parameters; when None, the built-in
                                  demo parameter set is used.
    """
    # Create the analyzer instance.
    analyzer = ModelAnalyzer()

    if sample_params is None:
        # User-supplied example parameters; target_pH is stripped inside the
        # analyzer for reflux models.
        sample_params = {
            "OM": 19.913,
            "CL": 373.600,
            "CEC": 7.958,
            "H_plus": 0.774,
            "N": 0.068,
            "Al3_plus": 3.611,
            "target_pH": 7.0
        }

    # Run the complete analysis pipeline.
    results = analyzer.run_full_analysis(model_id=model_id, sample_params=sample_params)

    if results:
        print("\n=== 分析完成 ===")
        print("所有分析结果和可视化图表已保存到 model_optimize/analysis_results 目录")
    else:
        print("分析失败,请检查模型ID和数据文件")
# Run the demo analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|