|
@@ -1,9 +1,5 @@
|
|
|
import sqlite3
|
|
|
-from flask import current_app, send_file
|
|
|
-from werkzeug.security import generate_password_hash, check_password_hash
|
|
|
-from flask import Blueprint, request, jsonify, current_app as app
|
|
|
-from werkzeug.utils import secure_filename
|
|
|
-from io import BytesIO
|
|
|
+from flask import Blueprint, request, jsonify, current_app
|
|
|
from .model import predict, train_and_save_model, calculate_model_score
|
|
|
import pandas as pd
|
|
|
from . import db # 从 app 包导入 db 实例
|
|
@@ -89,6 +85,10 @@ def check_and_trigger_training(session, dataset_type, dataset_df):
|
|
|
|
|
|
@bp.route('/upload-dataset', methods=['POST'])
|
|
|
def upload_dataset():
|
|
|
+ # 创建 session
|
|
|
+ Session = sessionmaker(bind=db.engine)
|
|
|
+ session = Session()
|
|
|
+
|
|
|
try:
|
|
|
if 'file' not in request.files:
|
|
|
return jsonify({'error': 'No file part'}), 400
|
|
@@ -102,9 +102,6 @@ def upload_dataset():
|
|
|
if not dataset_type:
|
|
|
return jsonify({'error': 'Dataset type is required'}), 400
|
|
|
|
|
|
- # 创建 sessionmaker 实例
|
|
|
- Session = sessionmaker(bind=db.engine)
|
|
|
- session = Session()
|
|
|
new_dataset = Datasets(
|
|
|
Dataset_name=dataset_name,
|
|
|
Dataset_description=dataset_description,
|
|
@@ -161,8 +158,11 @@ def upload_dataset():
|
|
|
session.rollback()
|
|
|
logging.error('Failed to process the dataset upload:', exc_info=True)
|
|
|
return jsonify({'error': str(e)}), 500
|
|
|
+
|
|
|
finally:
|
|
|
- session.close()
|
|
|
+ # 确保 session 总是被关闭
|
|
|
+ if session:
|
|
|
+ session.close()
|
|
|
|
|
|
|
|
|
@bp.route('/train-and-save-model', methods=['POST'])
|
|
@@ -740,13 +740,26 @@ def train_model_async():
|
|
|
'error': 'Missing required parameters'
|
|
|
}), 400
|
|
|
|
|
|
+ # 如果提供了dataset_id,验证数据集是否存在
|
|
|
+ if dataset_id:
|
|
|
+ Session = sessionmaker(bind=db.engine)
|
|
|
+ session = Session()
|
|
|
+ try:
|
|
|
+ dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
|
|
|
+ if not dataset:
|
|
|
+ return jsonify({
|
|
|
+ 'error': f'Dataset with ID {dataset_id} not found'
|
|
|
+ }), 404
|
|
|
+ finally:
|
|
|
+ session.close()
|
|
|
+
|
|
|
# 启动异步任务
|
|
|
task = train_model_task.delay(
|
|
|
- model_type,
|
|
|
- model_name,
|
|
|
- model_description,
|
|
|
- data_type,
|
|
|
- dataset_id
|
|
|
+ model_type=model_type,
|
|
|
+ model_name=model_name,
|
|
|
+ model_description=model_description,
|
|
|
+ data_type=data_type,
|
|
|
+ dataset_id=dataset_id
|
|
|
)
|
|
|
|
|
|
# 返回任务ID
|
|
@@ -1264,4 +1277,120 @@ def download_template():
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to generate template: {e}", exc_info=True)
|
|
|
return jsonify({'error': '生成模板文件失败'}), 500
|
|
|
-
|
|
|
+
|
|
|
@bp.route('/update-threshold', methods=['POST'])
def update_threshold():
    """
    Update the training-threshold configuration value.

    @body_param threshold: new threshold value (positive number)
    @return: JSON response
    """
    try:
        # silent=True: a missing or non-JSON body yields None instead of
        # raising, so we can answer with a 400 rather than a generic 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({
                'error': '无效的阈值值,必须为正数'
            }), 400

        new_threshold = data.get('threshold')

        # bool is a subclass of int, so reject it explicitly — otherwise
        # {"threshold": true} would slip through and become 1.
        if isinstance(new_threshold, bool) \
                or not isinstance(new_threshold, (int, float)) \
                or new_threshold <= 0:
            return jsonify({
                'error': '无效的阈值值,必须为正数'
            }), 400

        # Store as int, matching how THRESHOLD is read back elsewhere.
        current_app.config['THRESHOLD'] = int(new_threshold)

        return jsonify({
            'success': True,
            'message': f'阈值已更新为 {new_threshold}',
            'new_threshold': new_threshold
        })

    except Exception as e:
        logging.error(f"更新阈值失败: {str(e)}")
        return jsonify({
            'error': f'更新阈值失败: {str(e)}'
        }), 500
|
|
|
+
|
|
|
+
|
|
|
@bp.route('/get-threshold', methods=['GET'])
def get_threshold():
    """
    Return the current and default training thresholds.

    @return: JSON response carrying 'current_threshold' and
             'default_threshold'
    """
    try:
        config = current_app.config
        payload = {
            'current_threshold': config['THRESHOLD'],
            'default_threshold': config['DEFAULT_THRESHOLD'],
        }
        return jsonify(payload)

    except Exception as e:
        logging.error(f"获取阈值失败: {str(e)}")
        return jsonify({'error': f'获取阈值失败: {str(e)}'}), 500
|
|
|
+
|
|
|
@bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
def set_current_dataset(data_type, dataset_id):
    """
    Promote the given dataset to be the 'current' dataset of its type.

    Clears the matching current_* table, resets its sqlite autoincrement
    counter, and copies all rows from the dataset's backing table.

    @param data_type: dataset type ('reduce' or 'reflux')
    @param dataset_id: ID of the dataset to set as current
    @return: JSON response
    """
    # Validate the dataset type up front: this avoids a pointless DB query
    # and returns the intended 400 (instead of a misleading 404) when the
    # type is unknown.
    if data_type == 'reduce':
        table = CurrentReduce
        table_name = 'current_reduce'
    elif data_type == 'reflux':
        table = CurrentReflux
        table_name = 'current_reflux'
    else:
        return jsonify({'error': '无效的数据集类型'}), 400

    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        # Ensure the dataset exists and its stored type matches the URL.
        dataset = session.query(Datasets)\
            .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
            .first()

        if not dataset:
            return jsonify({
                'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
            }), 404

        # Empty the current table before repopulating it.
        session.query(table).delete()

        # Reset the sqlite autoincrement counter. table_name is one of two
        # hard-coded values above, so the f-string is not injectable.
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))

        # Copy every row from the dataset's backing table. dataset_id comes
        # from an <int:> route converter, so it is guaranteed to be an int.
        dataset_table_name = f"dataset_{dataset_id}"
        copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
        session.execute(copy_sql)

        session.commit()

        return jsonify({
            'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
            'dataset_id': dataset_id,
            'dataset_name': dataset.Dataset_name,
            'row_count': dataset.Row_count
        }), 200

    except Exception as e:
        session.rollback()
        logger.error(f'设置current数据集失败: {str(e)}')
        return jsonify({'error': str(e)}), 500

    finally:
        # Always release the session, success or failure.
        session.close()
|