@@ -1,11 +1,12 @@
-import pickle
 import sqlite3
+from io import BytesIO
+import pickle
 
-from flask import Blueprint, request, jsonify,current_app
-from werkzeug.security import generate_password_hash
+from flask import Blueprint, request, jsonify, current_app, send_file
+from werkzeug.security import check_password_hash, generate_password_hash
+from werkzeug.utils import secure_filename
 
-from sklearn.metrics import r2_score
-from .model import predict, train_and_save_model, calculate_model_score
+from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
 import pandas as pd
 from . import db  # import the db instance from the app package
 from sqlalchemy.engine.reflection import Inspector
@@ -16,10 +17,10 @@ from .utils import create_dynamic_table, allowed_file, infer_column_types, renam
     predict_to_Q, Q_to_t_ha, create_kriging
 from sqlalchemy.orm import sessionmaker
 import logging
-from sqlalchemy import text, func, MetaData, Table, select
+from sqlalchemy import text, select, MetaData, Table, func
 from .tasks import train_model_task
 from datetime import datetime
-
+from sklearn.metrics import r2_score
 
 # Configure logging
 logging.basicConfig(level=logging.DEBUG)
@@ -27,9 +28,6 @@ logger = logging.getLogger(__name__)
 # Create a Blueprint to keep the routes separate
 bp = Blueprint('routes', __name__)
 
-# Helper that opens a database connection
-def get_db_connection():
-    return sqlite3.connect('software_intro.db')
 
 # Password hashing
 def hash_password(password):
@@ -67,8 +65,11 @@ def check_and_trigger_training(session, dataset_type, dataset_df):
     # Number of records before the new data was added
     previous_count = current_count - new_records
 
-    # Set the threshold
-    THRESHOLD = current_app.config['THRESHOLD']
+    # Pick the threshold based on the dataset type
+    if dataset_type == 'reduce':
+        THRESHOLD = current_app.config['THRESHOLD_REDUCE']
+    else:  # reflux
+        THRESHOLD = current_app.config['THRESHOLD_REFLUX']
 
     # Compute the previous threshold point (based on the record count before the new data)
     last_threshold = previous_count // THRESHOLD * THRESHOLD
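
A toy walk-through of the crossing check above (a sketch: it assumes THRESHOLD_REDUCE / THRESHOLD_REFLUX are plain integers in the Flask config, and that training fires once the record count passes the next multiple of the threshold — the comparison itself sits outside this hunk):

# Hypothetical numbers, not project data
THRESHOLD = 100                                   # e.g. current_app.config['THRESHOLD_REDUCE']
previous_count, current_count = 195, 205          # counts before / after the upload
last_threshold = previous_count // THRESHOLD * THRESHOLD    # 100
crossed = current_count >= last_threshold + THRESHOLD       # 205 >= 200 -> True, training triggers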
@@ -141,6 +142,49 @@ def upload_dataset():
         dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
         insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
 
+        # Drop duplicate rows within the uploaded dataset itself
+        original_count = len(dataset_df)
+        dataset_df = dataset_df.drop_duplicates()
+        duplicates_in_file = original_count - len(dataset_df)
+
+        # Check for duplicates against the existing data
+        duplicates_with_existing = 0
+        if dataset_type in ['reduce', 'reflux']:
+            # Determine the table name
+            table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
+
+            # Load the existing data from that table
+            existing_data = pd.read_sql_table(table_name, session.bind)
+            if 'id' in existing_data.columns:
+                existing_data = existing_data.drop('id', axis=1)
+
+            # Determine the columns to compare on
+            compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
+
+            # Count the duplicate rows
+            original_df_len = len(dataset_df)
+
+            # Use concat and duplicated() to flag rows that already exist
+            all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
+            duplicates_mask = all_data.duplicated(keep='first')
+            duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
+
+            # Keep only the non-duplicate rows
+            dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
+
+            logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
+
+        # Check for overlap with the test set
+        test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)
+
+        # If any rows overlap with the test set, remove them from the dataset
+        if test_overlap_count > 0:
+            # Build a boolean mask marking the rows whose index is not in the overlap
+            mask = ~dataset_df.index.isin(test_overlap_indices)
+            # Apply the mask, keeping only the non-overlapping rows
+            dataset_df = dataset_df[mask]
+            logger.warning(f"移除了 {test_overlap_count} 行与测试集重叠的数据")
+
         # Decide which existing table to insert into based on dataset_type
         if dataset_type == 'reduce':
             insert_data_into_existing_table(session, dataset_df, CurrentReduce)
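
The concat + duplicated(keep='first') pattern above deserves a standalone illustration, since the positional slicing is easy to get wrong. A minimal, self-contained sketch with toy data (not the project's schema):

import pandas as pd

existing = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']})   # rows already in the table
incoming = pd.DataFrame({'a': [2, 3], 'b': ['y', 'z']})   # freshly uploaded rows

# Stack the existing rows first, so duplicated(keep='first') only flags
# incoming rows that repeat something already stored.
all_rows = pd.concat([existing, incoming])
dup_mask = all_rows.duplicated(keep='first')   # [False, False, True, False]

# The tail of the mask lines up positionally with the incoming rows.
incoming_dups = dup_mask[len(existing):]
deduped = incoming[~incoming_dups.values]      # keeps only the (3, 'z') row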
@@ -153,15 +197,30 @@ def upload_dataset():
         training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
 
         response_data = {
-            'message': f'Dataset {dataset_name} uploaded successfully!',
+            'message': f'数据集 {dataset_name} 上传成功!',
             'dataset_id': new_dataset.Dataset_ID,
             'filename': unique_filename,
-            'training_triggered': training_triggered
+            'training_triggered': training_triggered,
+            'data_stats': {
+                'original_count': original_count,
+                'duplicates_in_file': duplicates_in_file,
+                'duplicates_with_existing': duplicates_with_existing,
+                'test_overlap_count': test_overlap_count,
+                'final_count': len(dataset_df)
+            }
         }
 
         if training_triggered:
             response_data['task_id'] = task_id
-            response_data['message'] += ' Auto-training has been triggered.'
+            response_data['message'] += ' 自动训练已触发。'
+
+        # Append the deduplication details to the message
+        if duplicates_with_existing > 0:
+            response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
+
+        # Append the test-set overlap details to the message
+        if test_overlap_count > 0:
+            response_data['message'] += f' 已移除 {test_overlap_count} 个与测试集重叠的项。'
 
         return jsonify(response_data), 201
@@ -200,11 +259,22 @@ def train_and_save_model_endpoint():
         if model_id:
             model_info = session.query(Models).filter(Models.ModelID == model_id).first()
             if model_info:
-                score = calculate_model_score(model_info)
+                # Compute several evaluation metrics
+                score_metrics = calculate_model_score(model_info)
                 # Update the model's score
-                model_info.Performance_score = score
+                model_info.Performance_score = score_metrics['r2']
+                # Store the new metrics on the database record
+                model_info.MAE = score_metrics['mae']
+                model_info.RMSE = score_metrics['rmse']
+                # CV_score is already set in train_and_save_model; it is not updated here
                 session.commit()
-                result = {'model_id': model_id, 'model_score': score}
+                result = {
+                    'model_id': model_id,
+                    'model_score': score_metrics['r2'],
+                    'mae': score_metrics['mae'],
+                    'rmse': score_metrics['rmse'],
+                    'cv_score': result[3]
+                }
 
         # Return a success response
         return jsonify({
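
For reference: this endpoint assumes calculate_model_score now returns a dict keyed 'r2' / 'mae' / 'rmse' rather than a single float, and 'cv_score': result[3] presumably reads the CV score out of the tuple that train_and_save_model returned earlier in the endpoint (the right-hand dict is built before result is rebound). A minimal sketch of the assumed metric computation — illustrative only, the real implementation lives in .model:

from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def calculate_model_score_sketch(y_true, y_pred):
    """Hypothetical stand-in for the metrics dict consumed above."""
    return {
        'r2': r2_score(y_true, y_pred),
        'mae': mean_absolute_error(y_true, y_pred),
        'rmse': mean_squared_error(y_true, y_pred) ** 0.5,  # RMSE = sqrt(MSE)
    }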
@@ -266,7 +336,7 @@ def predict_route():
         return jsonify({'error': str(e)}), 400
 
 
-# Compute the Performance_score for a given model; a model_id must be provided
+# Compute the metric scores for a given model; a model_id must be provided
 @bp.route('/score-model/<int:model_id>', methods=['POST'])
 def score_model(model_id):
     # Create a sessionmaker instance
@@ -278,15 +348,23 @@ def score_model(model_id):
             return jsonify({'error': 'Model not found'}), 404
 
         # Compute the model's metric scores
-        score = calculate_model_score(model_info)
+        score_metrics = calculate_model_score(model_info)
+
+        # Update the scores on the model record (excluding the cross-validation score)
+        model_info.Performance_score = score_metrics['r2']
+        model_info.MAE = score_metrics['mae']
+        model_info.RMSE = score_metrics['rmse']
 
-        # Update the score on the model record
-        model_info.Performance_score = score
         session.commit()
 
-        return jsonify({'message': 'Model scored successfully', 'score': score}), 200
+        return jsonify({
+            'message': 'Model scored successfully',
+            'r2_score': score_metrics['r2'],
+            'mae': score_metrics['mae'],
+            'rmse': score_metrics['rmse'],
+        }), 200
     except Exception as e:
-        logging.error('Failed to process the dataset upload:', exc_info=True)
+        logging.error('Failed to process model scoring:', exc_info=True)
         return jsonify({'error': str(e)}), 400
     finally:
         session.close()
@@ -379,6 +457,7 @@ def get_model(model_id):
             'Description': model.Description,
             'Performance_score': float(model.Performance_score) if model.Performance_score else None,
             'MAE': float(model.MAE) if model.MAE else None,
+            'CV_score': float(model.CV_score) if model.CV_score else None,
            'RMSE': float(model.RMSE) if model.RMSE else None,
            'Data_type': model.Data_type
        })
@@ -393,6 +472,46 @@ def get_model(model_id):
         session.close()
 
 
+@bp.route('/models', methods=['GET'])
+def get_all_models():
+    """
+    API endpoint that returns information about all models
+
+    @return: JSON response
+    """
+    Session = sessionmaker(bind=db.engine)
+    session = Session()
+
+    try:
+        models = session.query(Models).all()
+        if models:
+            result = [
+                {
+                    'ModelID': model.ModelID,
+                    'Model_name': model.Model_name,
+                    'Model_type': model.Model_type,
+                    'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
+                    'Description': model.Description,
+                    'Performance_score': float(model.Performance_score) if model.Performance_score else None,
+                    'MAE': float(model.MAE) if model.MAE else None,
+                    'CV_score': float(model.CV_score) if model.CV_score else None,
+                    'RMSE': float(model.RMSE) if model.RMSE else None,
+                    'Data_type': model.Data_type
+                }
+                for model in models
+            ]
+            return jsonify(result)
+        else:
+            return jsonify({'message': '未找到任何模型'}), 404
+
+    except Exception as e:
+        logger.error(f'获取所有模型信息失败: {str(e)}')
+        return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
+
+    finally:
+        session.close()
+
+
 @bp.route('/model-parameters', methods=['GET'])
 def get_all_model_parameters():
     """
@@ -560,6 +679,7 @@ def delete_model_route(model_id):
     # Call the original function
     return delete_model(model_id, delete_dataset=delete_dataset_param)
 
+
 def delete_model(model_id, delete_dataset=False):
     """
     API endpoint that deletes the specified model
@@ -584,16 +704,19 @@ def delete_model(model_id, delete_dataset=False):
         session.commit()
 
         # 2. Delete the model file
-        model_file = f"rf_model_{model_id}.pkl"
-        model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
-        if os.path.exists(model_path):
-            try:
+        model_path = model.ModelFilePath
+        try:
+            if os.path.exists(model_path):
                 os.remove(model_path)
-            except OSError as e:
+            else:
                 # If deleting the file fails, roll back the database operation
-                session.rollback()
-                logger.error(f'删除模型文件失败: {str(e)}')
-                return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
+                session.rollback()
+                logger.warning(f'模型文件不存在: {model_path}')
+        except OSError as e:
+            # If deleting the file fails, roll back the database operation
+            session.rollback()
+            logger.error(f'删除模型文件失败: {str(e)}')
+            return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
 
         # 3. Delete the associated dataset if requested
         if delete_dataset and dataset_id:
@@ -613,7 +736,7 @@ def delete_model(model_id, delete_dataset=False):
 
         response_data = {
             'message': '模型删除成功',
-            'deleted_files': [model_file]
+            'deleted_files': [model_path]
         }
 
         if delete_dataset:
@@ -679,11 +802,13 @@ def update_threshold():
     API endpoint that updates the training threshold
 
     @body_param threshold: the new threshold value (an integer)
+    @body_param data_type: the data type ('reduce' or 'reflux')
     @return: JSON response
     """
     try:
         data = request.get_json()
         new_threshold = data.get('threshold')
+        data_type = data.get('data_type')
 
         # Validate the new threshold
         if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
@@ -691,12 +816,22 @@ def update_threshold():
                 'error': '无效的阈值值,必须为正数'
             }), 400
 
+        # Validate the data type
+        if data_type not in ['reduce', 'reflux']:
+            return jsonify({
+                'error': '无效的数据类型,必须为 "reduce" 或 "reflux"'
+            }), 400
+
         # Update the app's threshold configuration
-        current_app.config['THRESHOLD'] = int(new_threshold)
+        if data_type == 'reduce':
+            current_app.config['THRESHOLD_REDUCE'] = int(new_threshold)
+        else:  # reflux
+            current_app.config['THRESHOLD_REFLUX'] = int(new_threshold)
 
         return jsonify({
             'success': True,
-            'message': f'阈值已更新为 {new_threshold}',
+            'message': f'{data_type} 阈值已更新为 {new_threshold}',
+            'data_type': data_type,
             'new_threshold': new_threshold
         })
@@ -712,16 +847,32 @@ def get_threshold():
     """
     API endpoint that returns the current training threshold
 
+    @query_param data_type: optional; the data type ('reduce' or 'reflux')
     @return: JSON response
     """
     try:
-        current_threshold = current_app.config['THRESHOLD']
-        default_threshold = current_app.config['DEFAULT_THRESHOLD']
+        data_type = request.args.get('data_type')
 
-        return jsonify({
-            'current_threshold': current_threshold,
-            'default_threshold': default_threshold
-        })
+        if data_type and data_type not in ['reduce', 'reflux']:
+            return jsonify({
+                'error': '无效的数据类型,必须为 "reduce" 或 "reflux"'
+            }), 400
+
+        response = {}
+
+        if data_type == 'reduce' or data_type is None:
+            response['reduce'] = {
+                'current_threshold': current_app.config['THRESHOLD_REDUCE'],
+                'default_threshold': current_app.config['DEFAULT_THRESHOLD_REDUCE']
+            }
+
+        if data_type == 'reflux' or data_type is None:
+            response['reflux'] = {
+                'current_threshold': current_app.config['THRESHOLD_REFLUX'],
+                'default_threshold': current_app.config['DEFAULT_THRESHOLD_REFLUX']
+            }
+
+        return jsonify(response)
 
     except Exception as e:
         logging.error(f"获取阈值失败: {str(e)}")
@@ -729,6 +880,7 @@ def get_threshold():
             'error': f'获取阈值失败: {str(e)}'
         }), 500
 
+
 @bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
 def set_current_dataset(data_type, dataset_id):
     """
@@ -790,6 +942,7 @@ def set_current_dataset(data_type, dataset_id):
     finally:
         session.close()
 
+
 @bp.route('/get-model-history/<string:data_type>', methods=['GET'])
 def get_model_history(data_type):
     """
@@ -850,6 +1003,7 @@ def get_model_history(data_type):
     finally:
         session.close()
 
+
 @bp.route('/batch-delete-datasets', methods=['POST'])
 def batch_delete_datasets():
     """
@@ -914,6 +1068,7 @@ def batch_delete_datasets():
         logger.error(f'批量删除数据集失败: {str(e)}')
         return jsonify({'error': str(e)}), 500
 
+
 @bp.route('/batch-delete-models', methods=['POST'])
 def batch_delete_models():
     """
@@ -979,6 +1134,7 @@ def batch_delete_models():
         logger.error(f'批量删除模型失败: {str(e)}')
         return jsonify({'error': str(e)}), 500
 
+
 @bp.route('/kriging_interpolation', methods=['POST'])
 def kriging_interpolation():
     try:
@@ -1001,107 +1157,28 @@ def kriging_interpolation():
     except Exception as e:
         return jsonify({"error": str(e)}), 500
 
-# List the models for display and switching
-@bp.route('/models', methods=['GET'])
-def get_models():
-    session = None
-    try:
-        # Create a session
-        Session = sessionmaker(bind=db.engine)
-        session = Session()
-
-        # Query all models
-        models = session.query(Models).all()
-
-        logger.debug(f"Models found: {models}")  # log the models returned by the query
-
-        if not models:
-            return jsonify({'message': 'No models found'}), 404
-
-        # Convert the model data into a list of dicts
-        models_list = [
-            {
-                'ModelID': model.ModelID,
-                'ModelName': model.Model_name,
-                'ModelType': model.Model_type,
-                'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
-                'Description': model.Description,
-                'DatasetID': model.DatasetID,
-                'ModelFilePath': model.ModelFilePath,
-                'DataType': model.Data_type,
-                'PerformanceScore': model.Performance_score,
-                'MAE': model.MAE,
-                'RMSE': model.RMSE
-            }
-            for model in models
-        ]
-
-        return jsonify(models_list), 200
-
-    except Exception as e:
-        return jsonify({'error': str(e)}), 400
-    finally:
-        if session:
-            session.close()
-
-
-# API endpoint that serves a database table for display
-@bp.route('/table', methods=['POST'])
-def get_table():
-    data = request.get_json()
-    table_name = data.get('table')
-    if not table_name:
-        return jsonify({'error': '需要表名'}), 400
-
-    try:
-        # Create a sessionmaker instance
-        Session = sessionmaker(bind=db.engine)
-        session = Session()
-
-        # Dynamically reflect the table's metadata
-        metadata = MetaData()
-        table = Table(table_name, metadata, autoload_with=db.engine)
-
-        # Query all records from the database
-        query = select(table)
-        result = session.execute(query).fetchall()
-
-        # Convert the result into a list of dicts
-        rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
-
-        # Get the column names
-        headers = [column.name for column in table.columns]
-
-        return jsonify(rows=rows, headers=headers), 200
-
-    except Exception as e:
-        return jsonify({'error': str(e)}), 400
-    finally:
-        # Close the session
-        session.close()
-
 
 @bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
 def get_model_scatter_data(model_id):
     """
     Scatter-plot data (true values vs. predicted values) for the specified model
-
+
     @param model_id: the model ID
     @return: JSON response containing the scatter-plot data
     """
     Session = sessionmaker(bind=db.engine)
     session = Session()
-
+
     try:
         # Look up the model record
         model = session.query(Models).filter_by(ModelID=model_id).first()
         if not model:
             return jsonify({'error': '未找到指定模型'}), 404
-
+
         # Load the model
         with open(model.ModelFilePath, 'rb') as f:
             ML_model = pickle.load(f)
-
+
         # Load the test data that matches the model's data type
         if model.Data_type == 'reflux':
             X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
@@ -1111,23 +1188,23 @@ def get_model_scatter_data(model_id):
             Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
         else:
             return jsonify({'error': '不支持的数据类型'}), 400
-
+
         # Get the predictions
         y_pred = ML_model.predict(X_test)
-
+
         # Build the scatter-plot data
         scatter_data = [
-            [float(true), float(pred)]
+            [float(true), float(pred)]
             for true, pred in zip(Y_test.iloc[:, 0], y_pred)
         ]
-
+
         # Compute the R² score
        r2 = r2_score(Y_test, y_pred)
-
+
         # Get the data range, used to draw the diagonal reference line
         y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
         y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))
-
+
         return jsonify({
             'scatter_data': scatter_data,
             'r2_score': float(r2),
@@ -1135,10 +1212,10 @@ def get_model_scatter_data(model_id):
             'model_name': model.Model_name,
             'model_type': model.Model_type
         }), 200
-
+
     except Exception as e:
         logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True)
         return jsonify({'error': f'获取数据失败: {str(e)}'}), 500
-
+
     finally:
         session.close()