|
@@ -1,5 +1,6 @@
|
|
|
import sqlite3
|
|
|
from io import BytesIO
|
|
|
+import pickle
|
|
|
|
|
|
from flask import Blueprint, request, jsonify, current_app, send_file
|
|
|
from werkzeug.security import check_password_hash, generate_password_hash
|
|
@@ -19,6 +20,7 @@ import logging
|
|
|
from sqlalchemy import text, select, MetaData, Table, func
|
|
|
from .tasks import train_model_task
|
|
|
from datetime import datetime
|
|
|
+from sklearn.metrics import r2_score
|
|
|
|
|
|
# 配置日志
|
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
@@ -897,16 +899,19 @@ def delete_model(model_id, delete_dataset=False):
|
|
|
session.commit()
|
|
|
|
|
|
# 2. 删除模型文件
|
|
|
- model_file = f"rf_model_{model_id}.pkl"
|
|
|
- model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
|
|
|
- if os.path.exists(model_path):
|
|
|
- try:
|
|
|
+ model_path = model.ModelFilePath
|
|
|
+ try:
|
|
|
+ if os.path.exists(model_path):
|
|
|
os.remove(model_path)
|
|
|
- except OSError as e:
|
|
|
+ else:
|
|
|
# 如果删除文件失败,回滚数据库操作
|
|
|
- session.rollback()
|
|
|
- logger.error(f'删除模型文件失败: {str(e)}')
|
|
|
- return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
|
|
|
+ session.rollback()
|
|
|
+ logger.warning(f'模型文件不存在: {model_path}')
|
|
|
+ except OSError as e:
|
|
|
+ # 如果删除文件失败,回滚数据库操作
|
|
|
+ session.rollback()
|
|
|
+ logger.error(f'删除模型文件失败: {str(e)}')
|
|
|
+ return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
|
|
|
|
|
|
# 3. 如果需要删除关联的数据集
|
|
|
if delete_dataset and dataset_id:
|
|
@@ -926,7 +931,7 @@ def delete_model(model_id, delete_dataset=False):
|
|
|
|
|
|
response_data = {
|
|
|
'message': '模型删除成功',
|
|
|
- 'deleted_files': [model_file]
|
|
|
+ 'deleted_files': [model_path]
|
|
|
}
|
|
|
|
|
|
if delete_dataset:
|
|
@@ -1660,4 +1665,66 @@ def kriging_interpolation():
|
|
|
)
|
|
|
return jsonify(result)
|
|
|
except Exception as e:
|
|
|
- return jsonify({"error": str(e)}), 500
|
|
|
+ return jsonify({"error": str(e)}), 500
|
|
|
+
|
|
|
+@bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
|
|
|
+def get_model_scatter_data(model_id):
|
|
|
+ """
|
|
|
+ 获取指定模型的散点图数据(真实值vs预测值)
|
|
|
+
|
|
|
+ @param model_id: 模型ID
|
|
|
+ @return: JSON响应,包含散点图数据
|
|
|
+ """
|
|
|
+ Session = sessionmaker(bind=db.engine)
|
|
|
+ session = Session()
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 查询模型信息
|
|
|
+ model = session.query(Models).filter_by(ModelID=model_id).first()
|
|
|
+ if not model:
|
|
|
+ return jsonify({'error': '未找到指定模型'}), 404
|
|
|
+
|
|
|
+ # 加载模型
|
|
|
+ with open(model.ModelFilePath, 'rb') as f:
|
|
|
+ ML_model = pickle.load(f)
|
|
|
+
|
|
|
+ # 根据数据类型加载测试数据
|
|
|
+ if model.Data_type == 'reflux':
|
|
|
+ X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
|
|
|
+ Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
|
|
|
+ elif model.Data_type == 'reduce':
|
|
|
+ X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
|
|
|
+ Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
|
|
|
+ else:
|
|
|
+ return jsonify({'error': '不支持的数据类型'}), 400
|
|
|
+
|
|
|
+ # 获取预测值
|
|
|
+ y_pred = ML_model.predict(X_test)
|
|
|
+
|
|
|
+ # 生成散点图数据
|
|
|
+ scatter_data = [
|
|
|
+ [float(true), float(pred)]
|
|
|
+ for true, pred in zip(Y_test.iloc[:, 0], y_pred)
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 计算R²分数
|
|
|
+ r2 = r2_score(Y_test, y_pred)
|
|
|
+
|
|
|
+ # 获取数据范围,用于绘制对角线
|
|
|
+ y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
|
|
|
+ y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))
|
|
|
+
|
|
|
+ return jsonify({
|
|
|
+ 'scatter_data': scatter_data,
|
|
|
+ 'r2_score': float(r2),
|
|
|
+ 'y_range': [float(y_min), float(y_max)],
|
|
|
+ 'model_name': model.Model_name,
|
|
|
+ 'model_type': model.Model_type
|
|
|
+ }), 200
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True)
|
|
|
+ return jsonify({'error': f'获取数据失败: {str(e)}'}), 500
|
|
|
+
|
|
|
+ finally:
|
|
|
+ session.close()
|