# -*- coding: utf-8 -*- """ 模型分析工具 此模块提供了综合的模型分析功能,包括: - 模型性能详细分析 - 特征重要性分析 - 预测行为分析 - 参数敏感性分析 - 模型诊断和可视化 """ import os import sys import pickle import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error from sklearn.model_selection import cross_val_score from sklearn.inspection import permutation_importance import sqlite3 from datetime import datetime import warnings warnings.filterwarnings('ignore') # 设置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False class ModelAnalyzer: """ 模型分析器类 提供全面的模型分析功能,支持模型性能评估、特征分析、预测分析等 """ def __init__(self, db_path='SoilAcidification.db'): """ 初始化模型分析器 @param {str} db_path - 数据库文件路径 """ self.db_path = db_path self.model_info = None self.model = None self.test_data = None self.predictions = None self.analysis_results = {} # 创建输出目录 self.output_dir = 'model_optimize/analysis_results' if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) def load_model_by_id(self, model_id): """ 根据模型ID加载模型信息和模型对象 @param {int} model_id - 模型ID @return {bool} - 是否成功加载 """ try: # 连接数据库 conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # 查询模型信息 cursor.execute(""" SELECT ModelID, Model_name, Model_type, ModelFilePath, Data_type, Performance_score, MAE, RMSE, CV_score, Description, Created_at FROM Models WHERE ModelID = ? """, (model_id,)) result = cursor.fetchone() if not result: print(f"未找到ID为 {model_id} 的模型") return False # 存储模型信息 self.model_info = { 'ModelID': result[0], 'Model_name': result[1], 'Model_type': result[2], 'ModelFilePath': result[3], 'Data_type': result[4], 'Performance_score': result[5], 'MAE': result[6], 'RMSE': result[7], 'CV_score': result[8], 'Description': result[9], 'Created_at': result[10] } conn.close() # 加载模型文件 if os.path.exists(self.model_info['ModelFilePath']): with open(self.model_info['ModelFilePath'], 'rb') as f: self.model = pickle.load(f) print(f"成功加载模型: {self.model_info['Model_name']}") return True else: print(f"模型文件不存在: {self.model_info['ModelFilePath']}") return False except Exception as e: print(f"加载模型时出错: {str(e)}") return False def load_test_data(self): """ 根据模型数据类型加载相应的测试数据 @return {bool} - 是否成功加载测试数据 """ try: data_type = self.model_info['Data_type'] if data_type == 'reflux': X_test_path = 'uploads/data/X_test_reflux.csv' Y_test_path = 'uploads/data/Y_test_reflux.csv' elif data_type == 'reduce': X_test_path = 'uploads/data/X_test_reduce.csv' Y_test_path = 'uploads/data/Y_test_reduce.csv' else: print(f"不支持的数据类型: {data_type}") return False if os.path.exists(X_test_path) and os.path.exists(Y_test_path): self.X_test = pd.read_csv(X_test_path) self.Y_test = pd.read_csv(Y_test_path) print(f"成功加载 {data_type} 类型的测试数据") print(f"测试集大小: X_test {self.X_test.shape}, Y_test {self.Y_test.shape}") return True else: print(f"测试数据文件不存在") return False except Exception as e: print(f"加载测试数据时出错: {str(e)}") return False def load_training_data(self): """ 加载训练数据以获取参数的真实分布范围 @return {bool} - 是否成功加载训练数据 """ try: data_type = self.model_info['Data_type'] if data_type == 'reflux': # 加载current_reflux表的数据 conn = sqlite3.connect(self.db_path) training_data = pd.read_sql_query("SELECT * FROM current_reflux", conn) conn.close() # 分离特征和目标变量 if 'Delta_pH' in training_data.columns: self.X_train = training_data.drop(['id', 'Delta_pH'], axis=1, errors='ignore') self.y_train = training_data['Delta_pH'] else: print("未找到目标变量 Delta_pH") return False elif data_type == 'reduce': # 加载current_reduce表的数据 conn = sqlite3.connect(self.db_path) training_data = pd.read_sql_query("SELECT * FROM current_reduce", conn) conn.close() # 分离特征和目标变量 if 'Q_over_b' in training_data.columns: self.X_train = training_data.drop(['id', 'Q_over_b'], axis=1, errors='ignore') self.y_train = training_data['Q_over_b'] else: print("未找到目标变量 Q_over_b") return False else: print(f"不支持的数据类型: {data_type}") return False print(f"成功加载训练数据: {self.X_train.shape}") print(f"训练数据特征统计:") print(self.X_train.describe()) return True except Exception as e: print(f"加载训练数据时出错: {str(e)}") return False def calculate_detailed_metrics(self): """ 计算详细的模型性能指标 @return {dict} - 包含各种性能指标的字典 """ try: if self.model is None or self.X_test is None: print("模型或测试数据未加载") return {} # 获取预测值 y_pred = self.model.predict(self.X_test) y_true = self.Y_test.iloc[:, 0] if len(self.Y_test.columns) == 1 else self.Y_test # 计算各种指标 metrics = { 'R2_Score': r2_score(y_true, y_pred), 'RMSE': mean_squared_error(y_true, y_pred, squared=False), 'MAE': mean_absolute_error(y_true, y_pred), 'MSE': mean_squared_error(y_true, y_pred), 'MAPE': np.mean(np.abs((y_true - y_pred) / y_true)) * 100, 'Max_Error': np.max(np.abs(y_true - y_pred)), 'Mean_Residual': np.mean(y_true - y_pred), 'Std_Residual': np.std(y_true - y_pred) } self.predictions = {'y_true': y_true, 'y_pred': y_pred} self.analysis_results['metrics'] = metrics # 中文指标名称映射 metric_names_zh = { 'R2_Score': 'R²得分', 'RMSE': '均方根误差', 'MAE': '平均绝对误差', 'MSE': '均方误差', 'MAPE': '平均绝对百分比误差', 'Max_Error': '最大误差', 'Mean_Residual': '平均残差', 'Std_Residual': '残差标准差' } print("=== 模型性能指标 ===") for key, value in metrics.items(): zh_name = metric_names_zh.get(key, key) print(f"{key} ({zh_name}): {value:.4f}") return metrics except Exception as e: print(f"计算性能指标时出错: {str(e)}") return {} def analyze_feature_importance(self): """ 分析特征重要性 @return {dict} - 特征重要性分析结果 """ try: if self.model is None: print("模型未加载") return {} importance_data = {} # 1. 模型内置特征重要性(如果支持) if hasattr(self.model, 'feature_importances_'): importance_data['model_importance'] = { 'features': self.X_test.columns.tolist(), 'importance': self.model.feature_importances_.tolist() } # 2. 排列特征重要性 if self.X_test is not None and self.Y_test is not None: y_true = self.Y_test.iloc[:, 0] if len(self.Y_test.columns) == 1 else self.Y_test perm_importance = permutation_importance( self.model, self.X_test, y_true, n_repeats=10, random_state=42 ) importance_data['permutation_importance'] = { 'features': self.X_test.columns.tolist(), 'importance_mean': perm_importance.importances_mean.tolist(), 'importance_std': perm_importance.importances_std.tolist() } self.analysis_results['feature_importance'] = importance_data print("=== 特征重要性分析完成 ===") return importance_data except Exception as e: print(f"分析特征重要性时出错: {str(e)}") return {} def analyze_prediction_behavior(self, sample_params=None): """ 分析模型的预测行为 @param {dict} sample_params - 示例参数,如用户提供的参数 @return {dict} - 预测行为分析结果 """ try: if self.model is None: print("模型未加载") return {} behavior_analysis = {} # 1. 分析用户提供的示例参数 if sample_params: # 转换参数为DataFrame if self.model_info['Data_type'] == 'reflux': # 为reflux类型,移除target_pH sample_params_clean = sample_params.copy() sample_params_clean.pop('target_pH', None) sample_df = pd.DataFrame([sample_params_clean]) else: sample_df = pd.DataFrame([sample_params]) # 确保列顺序与训练数据一致 if hasattr(self, 'X_test'): sample_df = sample_df.reindex(columns=self.X_test.columns, fill_value=0) prediction = self.model.predict(sample_df)[0] behavior_analysis['sample_prediction'] = { 'input_params': sample_params, 'prediction': float(prediction), 'model_type': self.model_info['Data_type'] } print(f"=== 示例参数预测结果 ===") print(f"输入参数: {sample_params}") print(f"预测结果: {prediction:.4f}") # 2. 预测值分布分析 if self.predictions: y_pred = self.predictions['y_pred'] behavior_analysis['prediction_distribution'] = { 'mean': float(np.mean(y_pred)), 'std': float(np.std(y_pred)), 'min': float(np.min(y_pred)), 'max': float(np.max(y_pred)), 'percentiles': { '25%': float(np.percentile(y_pred, 25)), '50%': float(np.percentile(y_pred, 50)), '75%': float(np.percentile(y_pred, 75)) } } self.analysis_results['prediction_behavior'] = behavior_analysis return behavior_analysis except Exception as e: print(f"分析预测行为时出错: {str(e)}") return {} def sensitivity_analysis(self, sample_params=None): """ 基于训练数据分布的参数敏感性分析 @param {dict} sample_params - 基准参数 @return {dict} - 敏感性分析结果 """ try: if self.model is None or sample_params is None: print("模型未加载或未提供基准参数") return {} # 确保训练数据已加载 if not hasattr(self, 'X_train') or self.X_train is None: print("训练数据未加载,尝试加载...") if not self.load_training_data(): print("无法加载训练数据,使用默认参数范围") return self.sensitivity_analysis_fallback(sample_params) sensitivity_results = {} # 准备基准数据 base_params = sample_params.copy() if self.model_info['Data_type'] == 'reflux': base_params.pop('target_pH', None) base_df = pd.DataFrame([base_params]) if hasattr(self, 'X_test'): base_df = base_df.reindex(columns=self.X_test.columns, fill_value=0) base_prediction = self.model.predict(base_df)[0] print(f"基准预测值: {base_prediction:.6f}") # 对每个参数进行基于训练数据的敏感性分析 for param_name, base_value in base_params.items(): if param_name not in base_df.columns: continue # 检查参数是否在训练数据中存在 if param_name not in self.X_train.columns: print(f"警告: 参数 {param_name} 不在训练数据中,跳过") continue param_sensitivity = { 'base_value': base_value, 'variations': [], 'predictions': [], 'sensitivity_score': 0, 'training_stats': {} } # 获取训练数据中该参数的统计信息 param_series = self.X_train[param_name] param_stats = { 'min': float(param_series.min()), 'max': float(param_series.max()), 'mean': float(param_series.mean()), 'std': float(param_series.std()), 'q25': float(param_series.quantile(0.25)), 'q50': float(param_series.quantile(0.50)), 'q75': float(param_series.quantile(0.75)), 'q05': float(param_series.quantile(0.05)), 'q95': float(param_series.quantile(0.95)) } param_sensitivity['training_stats'] = param_stats print(f"\n分析参数: {param_name} (基准值: {base_value})") print(f" 训练数据范围: [{param_stats['min']:.3f}, {param_stats['max']:.3f}]") print(f" 训练数据均值±标准差: {param_stats['mean']:.3f} ± {param_stats['std']:.3f}") print(f" 训练数据分位数 [5%, 25%, 50%, 75%, 95%]: [{param_stats['q05']:.3f}, {param_stats['q25']:.3f}, {param_stats['q50']:.3f}, {param_stats['q75']:.3f}, {param_stats['q95']:.3f}]") # 基于训练数据分布设计参数变化策略 # 策略1: 使用训练数据的实际范围 min_val = param_stats['min'] max_val = param_stats['max'] mean_val = param_stats['mean'] std_val = param_stats['std'] # 创建合理的参数变化点 variations = [] # 1. 训练数据的关键分位数点 variations.extend([ param_stats['min'], # 最小值 param_stats['q05'], # 5%分位数 param_stats['q25'], # 25%分位数 param_stats['q50'], # 中位数 param_stats['q75'], # 75%分位数 param_stats['q95'], # 95%分位数 param_stats['max'], # 最大值 ]) # 2. 围绕均值的标准差变化 variations.extend([ max(min_val, mean_val - 2*std_val), # 均值-2σ max(min_val, mean_val - std_val), # 均值-σ mean_val, # 均值 min(max_val, mean_val + std_val), # 均值+σ min(max_val, mean_val + 2*std_val), # 均值+2σ ]) # 3. 基准值本身 variations.append(base_value) # 4. 如果基准值在合理范围内,添加围绕基准值的小幅变化 if min_val <= base_value <= max_val: # 计算基准值的相对变化(不超出训练数据范围) range_size = max_val - min_val small_change = range_size * 0.05 # 5%的范围变化 variations.extend([ max(min_val, base_value - small_change), min(max_val, base_value + small_change), ]) # 去重并排序 variations = sorted(list(set(variations))) print(f" 使用参数变化点: {len(variations)} 个") for i, variation in enumerate(variations): temp_params = base_params.copy() temp_params[param_name] = variation temp_df = pd.DataFrame([temp_params]) temp_df = temp_df.reindex(columns=self.X_test.columns, fill_value=0) try: pred = self.model.predict(temp_df)[0] param_sensitivity['variations'].append(float(variation)) param_sensitivity['predictions'].append(float(pred)) # 计算相对于训练数据范围的位置 if max_val > min_val: relative_pos = (variation - min_val) / (max_val - min_val) * 100 else: relative_pos = 50.0 pred_change = pred - base_prediction print(f" {param_name}={variation:.3f} (训练数据{relative_pos:.0f}位) → 预测={pred:.6f} (变化={pred_change:+.6f})") except Exception as e: print(f" 警告: {param_name}={variation} 预测失败: {str(e)}") continue # 计算敏感性指标 if len(param_sensitivity['predictions']) > 1: predictions = np.array(param_sensitivity['predictions']) # 指标1: 预测值变化范围 pred_range = np.max(predictions) - np.min(predictions) # 指标2: 预测值标准差 pred_std = np.std(predictions) # 指标3: 与基准值的最大偏差 max_deviation = np.max(np.abs(predictions - base_prediction)) # 指标4: 标准化敏感性(考虑参数的变化范围) param_range = max_val - min_val if max_val > min_val else 1.0 normalized_sensitivity = pred_range / param_range if param_range > 0 else 0 # 使用最大偏差作为主要敏感性指标 param_sensitivity['sensitivity_score'] = float(max_deviation) param_sensitivity['pred_range'] = float(pred_range) param_sensitivity['pred_std'] = float(pred_std) param_sensitivity['normalized_sensitivity'] = float(normalized_sensitivity) print(f" {param_name} 敏感性得分: {max_deviation:.6f}") print(f" 预测范围: {pred_range:.6f}, 标准差: {pred_std:.6f}") print(f" 标准化敏感性: {normalized_sensitivity:.6f}") else: param_sensitivity['sensitivity_score'] = 0.0 sensitivity_results[param_name] = param_sensitivity # 按敏感性得分排序 sorted_sensitivity = dict(sorted( sensitivity_results.items(), key=lambda x: x[1]['sensitivity_score'], reverse=True )) self.analysis_results['sensitivity_analysis'] = sorted_sensitivity print("\n=== 基于训练数据的参数敏感性分析结果 ===") for param, result in sorted_sensitivity.items(): score = result['sensitivity_score'] normalized_score = result.get('normalized_sensitivity', 0) if score > 0.001: sensitivity_level = "高" if score > 0.1 else "中" if score > 0.01 else "低" else: sensitivity_level = "极低" print(f"{param:12}: 敏感性得分 = {score:.6f} ({sensitivity_level})") print(f" 标准化敏感性 = {normalized_score:.6f}") return sorted_sensitivity except Exception as e: print(f"敏感性分析时出错: {str(e)}") import traceback print(traceback.format_exc()) # 如果分析失败,回退到原来的方法 return self.sensitivity_analysis_fallback(sample_params) def generate_visualizations(self): """ 生成分析可视化图表 """ try: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") model_id = self.model_info['ModelID'] # 1. 预测值 vs 真实值散点图 if self.predictions: plt.figure(figsize=(10, 8)) y_true = self.predictions['y_true'] y_pred = self.predictions['y_pred'] plt.scatter(y_true, y_pred, alpha=0.6, s=50) # 绘制理想线 (y=x) min_val = min(min(y_true), min(y_pred)) max_val = max(max(y_true), max(y_pred)) plt.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction') plt.xlabel('真实值 (True Values)') plt.ylabel('预测值 (Predicted Values)') plt.title(f'模型 {model_id} 预测性能\nR^2 = {self.analysis_results["metrics"]["R2_Score"]:.4f}') plt.legend() plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(f'{self.output_dir}/model_{model_id}_scatter_{timestamp}.png', dpi=300) plt.show() # 2. 残差图 plt.figure(figsize=(12, 5)) # 确保数据类型一致,转换为numpy数组 y_true_array = np.array(y_true).flatten() y_pred_array = np.array(y_pred).flatten() residuals = y_true_array - y_pred_array # 移除NaN值 valid_indices = ~(np.isnan(residuals) | np.isnan(y_pred_array)) residuals_clean = residuals[valid_indices] y_pred_clean = y_pred_array[valid_indices] # 残差散点图 plt.subplot(1, 2, 1) plt.scatter(y_pred_clean, residuals_clean, alpha=0.6) plt.axhline(y=0, color='r', linestyle='--', linewidth=2) plt.xlabel('预测值') plt.ylabel('残差 (真实值 - 预测值)') plt.title('残差散点图') plt.grid(True, alpha=0.3) # 残差直方图 plt.subplot(1, 2, 2) # 根据数据量自动调整bins数量 n_samples = len(residuals_clean) n_bins = min(30, max(5, int(np.sqrt(n_samples)))) # Sturges' rule的变体 n, bins, patches = plt.hist(residuals_clean, bins=n_bins, alpha=0.7, edgecolor='black', density=False) plt.xlabel('残差') plt.ylabel('频数') plt.title(f'残差分布 (n={n_samples}, bins={n_bins})') plt.grid(True, alpha=0.3) # 添加统计信息 mean_residual = np.mean(residuals_clean) std_residual = np.std(residuals_clean) plt.axvline(mean_residual, color='red', linestyle='--', alpha=0.8, label=f'均值: {mean_residual:.4f}') plt.axvline(mean_residual + std_residual, color='orange', linestyle=':', alpha=0.8, label=f'±1σ: ±{std_residual:.4f}') plt.axvline(mean_residual - std_residual, color='orange', linestyle=':', alpha=0.8) plt.legend(fontsize=8) plt.tight_layout() plt.savefig(f'{self.output_dir}/model_{model_id}_residuals_{timestamp}.png', dpi=300) plt.show() # 3. 特征重要性图 if 'feature_importance' in self.analysis_results: importance_data = self.analysis_results['feature_importance'] if 'model_importance' in importance_data: plt.figure(figsize=(10, 6)) features = importance_data['model_importance']['features'] importance = importance_data['model_importance']['importance'] indices = np.argsort(importance)[::-1] plt.bar(range(len(features)), [importance[i] for i in indices]) plt.xticks(range(len(features)), [features[i] for i in indices], rotation=45) plt.xlabel('特征') plt.ylabel('重要性') plt.title(f'模型 {model_id} 特征重要性') plt.tight_layout() plt.savefig(f'{self.output_dir}/model_{model_id}_feature_importance_{timestamp}.png', dpi=300) plt.show() # 4. 敏感性分析图 if 'sensitivity_analysis' in self.analysis_results: sensitivity_data = self.analysis_results['sensitivity_analysis'] # 敏感性得分条形图 plt.figure(figsize=(10, 6)) params = list(sensitivity_data.keys()) scores = [sensitivity_data[param]['sensitivity_score'] for param in params] plt.bar(params, scores) plt.xlabel('参数') plt.ylabel('敏感性得分') plt.title(f'模型 {model_id} 参数敏感性分析') plt.xticks(rotation=45) plt.tight_layout() plt.savefig(f'{self.output_dir}/model_{model_id}_sensitivity_{timestamp}.png', dpi=300) plt.show() # 详细敏感性曲线 n_params = len(params) fig, axes = plt.subplots(2, 3, figsize=(15, 10)) axes = axes.flatten() for i, param in enumerate(params[:6]): # 显示前6个参数 if i < len(axes): data = sensitivity_data[param] axes[i].plot(data['variations'], data['predictions'], 'b-o', markersize=4) axes[i].set_xlabel(param) axes[i].set_ylabel('预测值') axes[i].set_title(f'{param} 敏感性') axes[i].grid(True, alpha=0.3) # 隐藏多余的子图 for i in range(len(params), len(axes)): axes[i].set_visible(False) plt.tight_layout() plt.savefig(f'{self.output_dir}/model_{model_id}_sensitivity_curves_{timestamp}.png', dpi=300) plt.show() print(f"=== 可视化图表已保存到 {self.output_dir} ===") except Exception as e: print(f"生成可视化时出错: {str(e)}") def generate_report(self): """ 生成综合分析报告 @return {str} - 报告内容 """ try: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") model_id = self.model_info['ModelID'] report = f""" 模型分析报告 {'='*60} 生成时间: {timestamp} 分析模型: {self.model_info['Model_name']} (ID: {model_id}) 模型类型: {self.model_info['Model_type']} 数据类型: {self.model_info['Data_type']} 创建时间: {self.model_info['Created_at']} 模型描述: {self.model_info['Description'] or '无描述'} {'='*60} 性能指标分析 {'='*60} """ if 'metrics' in self.analysis_results: metrics = self.analysis_results['metrics'] for key, value in metrics.items(): report += f"{key:15}: {value:8.4f}\n" report += f"\n{'='*60}\n特征重要性分析\n{'='*60}\n" if 'feature_importance' in self.analysis_results: importance_data = self.analysis_results['feature_importance'] if 'model_importance' in importance_data: features = importance_data['model_importance']['features'] importance = importance_data['model_importance']['importance'] # 按重要性排序 sorted_indices = np.argsort(importance)[::-1] for i in sorted_indices: report += f"{features[i]:15}: {importance[i]:8.4f}\n" report += f"\n{'='*60}\n敏感性分析\n{'='*60}\n" if 'sensitivity_analysis' in self.analysis_results: sensitivity_data = self.analysis_results['sensitivity_analysis'] for param, data in sensitivity_data.items(): report += f"{param:15}: {data['sensitivity_score']:.8f}\n" if 'prediction_behavior' in self.analysis_results: behavior = self.analysis_results['prediction_behavior'] if 'sample_prediction' in behavior: sample = behavior['sample_prediction'] report += f"\n{'='*60}\n示例预测分析\n{'='*60}\n" report += f"输入参数: {sample['input_params']}\n" report += f"预测结果: {sample['prediction']:.4f}\n" report += f"\n{'='*60}\n分析建议\n{'='*60}\n" # 生成分析建议 if 'metrics' in self.analysis_results: r2 = self.analysis_results['metrics']['R2_Score'] if r2 > 0.9: report += "• 模型表现优秀,R²得分大于0.9\n" elif r2 > 0.8: report += "• 模型表现良好,R²得分在0.8-0.9之间\n" elif r2 > 0.7: report += "• 模型表现中等,R²得分在0.7-0.8之间,建议优化\n" else: report += "• 模型表现较差,R²得分低于0.7,需要重新训练\n" if 'sensitivity_analysis' in self.analysis_results: sensitivity_data = self.analysis_results['sensitivity_analysis'] most_sensitive = max(sensitivity_data.items(), key=lambda x: x[1]['sensitivity_score']) report += f"• 最敏感参数: {most_sensitive[0]},在调参时需要特别注意\n" least_sensitive = min(sensitivity_data.items(), key=lambda x: x[1]['sensitivity_score']) report += f"• 最不敏感参数: {least_sensitive[0]},对预测结果影响较小\n" report += f"\n{'='*60}\n报告结束\n{'='*60}\n" # 保存报告 report_path = f'{self.output_dir}/model_{model_id}_analysis_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt' with open(report_path, 'w', encoding='utf-8') as f: f.write(report) print(f"分析报告已保存到: {report_path}") return report except Exception as e: print(f"生成报告时出错: {str(e)}") return "" def run_full_analysis(self, model_id, sample_params=None): """ 运行完整的模型分析流程 @param {int} model_id - 模型ID @param {dict} sample_params - 示例参数 @return {dict} - 完整的分析结果 """ print(f"开始分析模型 {model_id}...") # 1. 加载模型 if not self.load_model_by_id(model_id): return {} # 2. 加载测试数据 if not self.load_test_data(): print("警告: 无法加载测试数据,部分分析功能将不可用") # 3. 加载训练数据(用于敏感性分析) if not self.load_training_data(): print("警告: 无法加载训练数据,敏感性分析将使用默认方法") # 4. 计算性能指标 self.calculate_detailed_metrics() # 5. 特征重要性分析 self.analyze_feature_importance() # 6. 预测行为分析 self.analyze_prediction_behavior(sample_params) # 7. 敏感性分析 if sample_params: self.sensitivity_analysis(sample_params) # 8. 生成可视化 self.generate_visualizations() # 9. 生成报告 self.generate_report() print("分析完成!") return self.analysis_results def main(): """ 主函数 - 分析模型24的示例 """ # 创建分析器实例 analyzer = ModelAnalyzer() # 用户提供的示例参数 sample_params = { "OM": 19.913, "CL": 373.600, "CEC": 7.958, "H_plus": 0.774, "N": 0.068, "Al3_plus": 3.611, "target_pH": 7.0 # 这个参数在reflux模型中会被移除 } # 运行完整分析 results = analyzer.run_full_analysis(model_id=24, sample_params=sample_params) if results: print("\n=== 分析完成 ===") print("所有分析结果和可视化图表已保存到 model_optimize/analysis_results 目录") else: print("分析失败,请检查模型ID和数据文件") if __name__ == "__main__": main()