|
@@ -1,12 +1,9 @@
|
|
|
import sqlite3
|
|
|
-from io import BytesIO
|
|
|
-import pickle
|
|
|
|
|
|
-from flask import Blueprint, request, jsonify, current_app, send_file
|
|
|
-from werkzeug.security import check_password_hash, generate_password_hash
|
|
|
-from werkzeug.utils import secure_filename
|
|
|
+from flask import Blueprint, request, jsonify, current_app
|
|
|
+from werkzeug.security import generate_password_hash
|
|
|
|
|
|
-from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
|
|
|
+from .model import predict, train_and_save_model, calculate_model_score
|
|
|
import pandas as pd
|
|
|
from . import db # 从 app 包导入 db 实例
|
|
|
from sqlalchemy.engine.reflection import Inspector
|
|
@@ -17,10 +14,10 @@ from .utils import create_dynamic_table, allowed_file, infer_column_types, renam
|
|
|
predict_to_Q, Q_to_t_ha, create_kriging
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
import logging
|
|
|
-from sqlalchemy import text, select, MetaData, Table, func
|
|
|
+from sqlalchemy import text, func
|
|
|
from .tasks import train_model_task
|
|
|
from datetime import datetime
|
|
|
-from sklearn.metrics import r2_score
|
|
|
+
|
|
|
|
|
|
# 配置日志
|
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
@@ -28,6 +25,9 @@ logger = logging.getLogger(__name__)
|
|
|
# 创建蓝图 (Blueprint),用于分离路由
|
|
|
bp = Blueprint('routes', __name__)
|
|
|
|
|
|
+# 封装数据库连接函数
|
|
|
+def get_db_connection():
|
|
|
+ return sqlite3.connect('software_intro.db')
|
|
|
|
|
|
# 密码加密
|
|
|
def hash_password(password):
|
|
@@ -139,49 +139,6 @@ def upload_dataset():
|
|
|
dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
|
|
|
insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
|
|
|
|
|
|
- # 去除上传数据集内部的重复项
|
|
|
- original_count = len(dataset_df)
|
|
|
- dataset_df = dataset_df.drop_duplicates()
|
|
|
- duplicates_in_file = original_count - len(dataset_df)
|
|
|
-
|
|
|
- # 检查与现有数据的重复
|
|
|
- duplicates_with_existing = 0
|
|
|
- if dataset_type in ['reduce', 'reflux']:
|
|
|
- # 确定表名
|
|
|
- table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
|
|
|
-
|
|
|
- # 从表加载现有数据
|
|
|
- existing_data = pd.read_sql_table(table_name, session.bind)
|
|
|
- if 'id' in existing_data.columns:
|
|
|
- existing_data = existing_data.drop('id', axis=1)
|
|
|
-
|
|
|
- # 确定用于比较的列
|
|
|
- compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
|
|
|
-
|
|
|
- # 计算重复行数
|
|
|
- original_df_len = len(dataset_df)
|
|
|
-
|
|
|
- # 使用concat和drop_duplicates找出非重复行
|
|
|
- all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
|
|
|
- duplicates_mask = all_data.duplicated(keep='first')
|
|
|
- duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
|
|
|
-
|
|
|
- # 保留非重复行
|
|
|
- dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
|
|
|
-
|
|
|
- logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
|
|
|
-
|
|
|
- # 检查与测试集的重叠
|
|
|
- test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)
|
|
|
-
|
|
|
- # 如果有与测试集重叠的数据,从数据集中移除
|
|
|
- if test_overlap_count > 0:
|
|
|
- # 创建一个布尔掩码,标记不在重叠索引中的行
|
|
|
- mask = ~dataset_df.index.isin(test_overlap_indices)
|
|
|
- # 应用掩码,只保留不重叠的行
|
|
|
- dataset_df = dataset_df[mask]
|
|
|
- logger.warning(f"移除了 {test_overlap_count} 行与测试集重叠的数据")
|
|
|
-
|
|
|
# 根据 dataset_type 决定插入到哪个已有表
|
|
|
if dataset_type == 'reduce':
|
|
|
insert_data_into_existing_table(session, dataset_df, CurrentReduce)
|
|
@@ -194,30 +151,15 @@ def upload_dataset():
|
|
|
training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
|
|
|
|
|
|
response_data = {
|
|
|
- 'message': f'数据集 {dataset_name} 上传成功!',
|
|
|
+ 'message': f'Dataset {dataset_name} uploaded successfully!',
|
|
|
'dataset_id': new_dataset.Dataset_ID,
|
|
|
'filename': unique_filename,
|
|
|
- 'training_triggered': training_triggered,
|
|
|
- 'data_stats': {
|
|
|
- 'original_count': original_count,
|
|
|
- 'duplicates_in_file': duplicates_in_file,
|
|
|
- 'duplicates_with_existing': duplicates_with_existing,
|
|
|
- 'test_overlap_count': test_overlap_count,
|
|
|
- 'final_count': len(dataset_df)
|
|
|
- }
|
|
|
+ 'training_triggered': training_triggered
|
|
|
}
|
|
|
|
|
|
if training_triggered:
|
|
|
response_data['task_id'] = task_id
|
|
|
- response_data['message'] += ' 自动训练已触发。'
|
|
|
-
|
|
|
- # 添加去重信息到消息中
|
|
|
- if duplicates_with_existing > 0:
|
|
|
- response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
|
|
|
-
|
|
|
- # 添加测试集重叠信息到消息中
|
|
|
- if test_overlap_count > 0:
|
|
|
- response_data['message'] += f' 已移除 {test_overlap_count} 个与测试集重叠的项。'
|
|
|
+ response_data['message'] += ' Auto-training has been triggered.'
|
|
|
|
|
|
return jsonify(response_data), 201
|
|
|
|
|
@@ -256,22 +198,11 @@ def train_and_save_model_endpoint():
|
|
|
if model_id:
|
|
|
model_info = session.query(Models).filter(Models.ModelID == model_id).first()
|
|
|
if model_info:
|
|
|
- # 计算多种评分指标
|
|
|
- score_metrics = calculate_model_score(model_info)
|
|
|
+ score = calculate_model_score(model_info)
|
|
|
# 更新模型评分
|
|
|
- model_info.Performance_score = score_metrics['r2']
|
|
|
- # 添加新的评分指标到数据库
|
|
|
- model_info.MAE = score_metrics['mae']
|
|
|
- model_info.RMSE = score_metrics['rmse']
|
|
|
- # CV_score 已在 train_and_save_model 中设置,此处不再更新
|
|
|
+ model_info.Performance_score = score
|
|
|
session.commit()
|
|
|
- result = {
|
|
|
- 'model_id': model_id,
|
|
|
- 'model_score': score_metrics['r2'],
|
|
|
- 'mae': score_metrics['mae'],
|
|
|
- 'rmse': score_metrics['rmse'],
|
|
|
- 'cv_score': result[3]
|
|
|
- }
|
|
|
+ result = {'model_id': model_id, 'model_score': score}
|
|
|
|
|
|
# 返回成功响应
|
|
|
return jsonify({
|
|
@@ -333,7 +264,7 @@ def predict_route():
|
|
|
return jsonify({'error': str(e)}), 400
|
|
|
|
|
|
|
|
|
-# 为指定模型计算指标评分,需要提供model_id
|
|
|
+# 为指定模型计算评分Performance_score,需要提供model_id
|
|
|
@bp.route('/score-model/<int:model_id>', methods=['POST'])
|
|
|
def score_model(model_id):
|
|
|
# 创建 sessionmaker 实例
|
|
@@ -345,23 +276,15 @@ def score_model(model_id):
|
|
|
return jsonify({'error': 'Model not found'}), 404
|
|
|
|
|
|
# 计算模型评分
|
|
|
- score_metrics = calculate_model_score(model_info)
|
|
|
-
|
|
|
- # 更新模型记录中的评分(不包括交叉验证得分)
|
|
|
- model_info.Performance_score = score_metrics['r2']
|
|
|
- model_info.MAE = score_metrics['mae']
|
|
|
- model_info.RMSE = score_metrics['rmse']
|
|
|
+ score = calculate_model_score(model_info)
|
|
|
|
|
|
+ # 更新模型记录中的评分
|
|
|
+ model_info.Performance_score = score
|
|
|
session.commit()
|
|
|
|
|
|
- return jsonify({
|
|
|
- 'message': 'Model scored successfully',
|
|
|
- 'r2_score': score_metrics['r2'],
|
|
|
- 'mae': score_metrics['mae'],
|
|
|
- 'rmse': score_metrics['rmse'],
|
|
|
- }), 200
|
|
|
+ return jsonify({'message': 'Model scored successfully', 'score': score}), 200
|
|
|
except Exception as e:
|
|
|
- logging.error('Failed to process model scoring:', exc_info=True)
|
|
|
+        logging.error('Failed to process model scoring:', exc_info=True)
|
|
|
return jsonify({'error': str(e)}), 400
|
|
|
finally:
|
|
|
session.close()
|
|
@@ -466,43 +389,6 @@ def get_model(model_id):
|
|
|
session.close()
|
|
|
|
|
|
|
|
|
-@bp.route('/models', methods=['GET'])
|
|
|
-def get_all_models():
|
|
|
- """
|
|
|
- 获取所有模型信息的API接口
|
|
|
-
|
|
|
- @return: JSON响应
|
|
|
- """
|
|
|
- Session = sessionmaker(bind=db.engine)
|
|
|
- session = Session()
|
|
|
-
|
|
|
- try:
|
|
|
- models = session.query(Models).all()
|
|
|
- if models:
|
|
|
- result = [
|
|
|
- {
|
|
|
- 'ModelID': model.ModelID,
|
|
|
- 'Model_name': model.Model_name,
|
|
|
- 'Model_type': model.Model_type,
|
|
|
- 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
- 'Description': model.Description,
|
|
|
- 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
|
|
|
- 'Data_type': model.Data_type
|
|
|
- }
|
|
|
- for model in models
|
|
|
- ]
|
|
|
- return jsonify(result)
|
|
|
- else:
|
|
|
- return jsonify({'message': '未找到任何模型'}), 404
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.error(f'获取所有模型信息失败: {str(e)}')
|
|
|
- return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
|
|
|
-
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
-
|
|
|
-
|
|
|
@bp.route('/model-parameters', methods=['GET'])
|
|
|
def get_all_model_parameters():
|
|
|
"""
|
|
@@ -567,288 +453,6 @@ def get_model_parameters(model_id):
|
|
|
return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
|
|
|
|
|
|
|
|
|
-# 定义添加数据库记录的 API 接口
|
|
|
-@bp.route('/add_item', methods=['POST'])
|
|
|
-def add_item():
|
|
|
- """
|
|
|
- 接收 JSON 格式的请求体,包含表名和要插入的数据。
|
|
|
- 尝试将数据插入到指定的表中,并进行字段查重。
|
|
|
- :return:
|
|
|
- """
|
|
|
- try:
|
|
|
- # 确保请求体是 JSON 格式
|
|
|
- data = request.get_json()
|
|
|
- if not data:
|
|
|
- raise ValueError("No JSON data provided")
|
|
|
-
|
|
|
- table_name = data.get('table')
|
|
|
- item_data = data.get('item')
|
|
|
-
|
|
|
- if not table_name or not item_data:
|
|
|
- return jsonify({'error': 'Missing table name or item data'}), 400
|
|
|
-
|
|
|
- # 定义各个表的字段查重规则
|
|
|
- duplicate_check_rules = {
|
|
|
- 'users': ['email', 'username'],
|
|
|
- 'products': ['product_code'],
|
|
|
- 'current_reduce': [ 'Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'],
|
|
|
- 'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'],
|
|
|
- # 其他表和规则
|
|
|
- }
|
|
|
-
|
|
|
- # 获取该表的查重字段
|
|
|
- duplicate_columns = duplicate_check_rules.get(table_name)
|
|
|
-
|
|
|
- if not duplicate_columns:
|
|
|
- return jsonify({'error': 'No duplicate check rule for this table'}), 400
|
|
|
-
|
|
|
- # 动态构建查询条件,逐一检查是否有重复数据
|
|
|
- condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns])
|
|
|
- duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1"
|
|
|
-
|
|
|
- result = db.session.execute(text(duplicate_query), item_data).fetchone()
|
|
|
-
|
|
|
- if result:
|
|
|
- return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409
|
|
|
-
|
|
|
- # 动态构建 SQL 语句,进行插入操作
|
|
|
- columns = ', '.join(item_data.keys())
|
|
|
- placeholders = ', '.join([f":{key}" for key in item_data.keys()])
|
|
|
- sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
|
|
|
-
|
|
|
- # 直接执行插入操作,无需显式的事务管理
|
|
|
- db.session.execute(text(sql), item_data)
|
|
|
-
|
|
|
- # 提交事务
|
|
|
- db.session.commit()
|
|
|
-
|
|
|
- # 返回成功响应
|
|
|
- return jsonify({'success': True, 'message': 'Item added successfully'}), 201
|
|
|
-
|
|
|
- except ValueError as e:
|
|
|
- return jsonify({'error': str(e)}), 400
|
|
|
- except KeyError as e:
|
|
|
- return jsonify({'error': f'Missing data field: {e}'}), 400
|
|
|
- except sqlite3.IntegrityError as e:
|
|
|
- return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409
|
|
|
- except sqlite3.Error as e:
|
|
|
- return jsonify({'error': '数据库错误', 'details': str(e)}), 500
|
|
|
-
|
|
|
-
|
|
|
-@bp.route('/delete_item', methods=['POST'])
|
|
|
-def delete_item():
|
|
|
- """
|
|
|
- 删除数据库记录的 API 接口
|
|
|
- """
|
|
|
- data = request.get_json()
|
|
|
- table_name = data.get('table')
|
|
|
- condition = data.get('condition')
|
|
|
-
|
|
|
- # 检查表名和条件是否提供
|
|
|
- if not table_name or not condition:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "缺少表名或条件参数"
|
|
|
- }), 400
|
|
|
-
|
|
|
- # 尝试从条件字符串中解析键和值
|
|
|
- try:
|
|
|
- key, value = condition.split('=')
|
|
|
- key = key.strip() # 去除多余的空格
|
|
|
- value = value.strip().strip("'\"") # 去除多余的空格和引号
|
|
|
- except ValueError:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "条件格式错误,应为 'key=value'"
|
|
|
- }), 400
|
|
|
-
|
|
|
- # 准备 SQL 删除语句
|
|
|
- sql = f"DELETE FROM {table_name} WHERE {key} = :value"
|
|
|
-
|
|
|
- try:
|
|
|
- # 使用 SQLAlchemy 执行删除
|
|
|
- with db.session.begin():
|
|
|
- result = db.session.execute(text(sql), {"value": value})
|
|
|
-
|
|
|
- # 检查是否有记录被删除
|
|
|
- if result.rowcount == 0:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "未找到符合条件的记录"
|
|
|
- }), 404
|
|
|
-
|
|
|
- return jsonify({
|
|
|
- "success": True,
|
|
|
- "message": "记录删除成功"
|
|
|
- }), 200
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": f"删除失败: {e}"
|
|
|
- }), 500
|
|
|
-
|
|
|
-# 定义修改数据库记录的 API 接口
|
|
|
-@bp.route('/update_item', methods=['PUT'])
|
|
|
-def update_record():
|
|
|
- """
|
|
|
- 接收 JSON 格式的请求体,包含表名和更新的数据。
|
|
|
- 尝试更新指定的记录。
|
|
|
- """
|
|
|
- data = request.get_json()
|
|
|
-
|
|
|
- # 检查必要的数据是否提供
|
|
|
- if not data or 'table' not in data or 'item' not in data:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "请求数据不完整"
|
|
|
- }), 400
|
|
|
-
|
|
|
- table_name = data['table']
|
|
|
- item = data['item']
|
|
|
-
|
|
|
- # 假设 item 的第一个键是 ID
|
|
|
- id_key = next(iter(item.keys())) # 获取第一个键
|
|
|
- record_id = item.get(id_key)
|
|
|
-
|
|
|
- if not record_id:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "缺少记录 ID"
|
|
|
- }), 400
|
|
|
-
|
|
|
- # 获取更新的字段和值
|
|
|
- updates = {key: value for key, value in item.items() if key != id_key}
|
|
|
-
|
|
|
- if not updates:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "没有提供需要更新的字段"
|
|
|
- }), 400
|
|
|
-
|
|
|
- # 动态构建 SQL
|
|
|
- set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()])
|
|
|
- sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value"
|
|
|
-
|
|
|
- # 添加 ID 到参数
|
|
|
- updates['id_value'] = record_id
|
|
|
-
|
|
|
- try:
|
|
|
- # 使用 SQLAlchemy 执行更新
|
|
|
- with db.session.begin():
|
|
|
- result = db.session.execute(text(sql), updates)
|
|
|
-
|
|
|
- # 检查是否有更新的记录
|
|
|
- if result.rowcount == 0:
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": "未找到要更新的记录"
|
|
|
- }), 404
|
|
|
-
|
|
|
- return jsonify({
|
|
|
- "success": True,
|
|
|
- "message": "数据更新成功"
|
|
|
- }), 200
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- # 捕获所有异常并返回
|
|
|
- return jsonify({
|
|
|
- "success": False,
|
|
|
- "message": f"更新失败: {str(e)}"
|
|
|
- }), 500
|
|
|
-
|
|
|
-
|
|
|
-# 定义查询数据库记录的 API 接口
|
|
|
-@bp.route('/search/record', methods=['GET'])
|
|
|
-def sql_search():
|
|
|
- """
|
|
|
- 接收 JSON 格式的请求体,包含表名和要查询的 ID。
|
|
|
- 尝试查询指定 ID 的记录并返回结果。
|
|
|
- :return:
|
|
|
- """
|
|
|
- try:
|
|
|
- data = request.get_json()
|
|
|
-
|
|
|
- # 表名
|
|
|
- sql_table = data['table']
|
|
|
-
|
|
|
- # 要搜索的 ID
|
|
|
- Id = data['id']
|
|
|
-
|
|
|
- # 连接到数据库
|
|
|
- cur = db.cursor()
|
|
|
-
|
|
|
- # 构造查询语句
|
|
|
- sql = f"SELECT * FROM {sql_table} WHERE id = ?"
|
|
|
-
|
|
|
- # 执行查询
|
|
|
- cur.execute(sql, (Id,))
|
|
|
-
|
|
|
- # 获取查询结果
|
|
|
- rows = cur.fetchall()
|
|
|
- column_names = [desc[0] for desc in cur.description]
|
|
|
-
|
|
|
- # 检查是否有结果
|
|
|
- if not rows:
|
|
|
- return jsonify({'error': '未查找到对应数据。'}), 400
|
|
|
-
|
|
|
- # 构造响应数据
|
|
|
- results = []
|
|
|
- for row in rows:
|
|
|
- result = {column_names[i]: row[i] for i in range(len(row))}
|
|
|
- results.append(result)
|
|
|
-
|
|
|
- # 关闭游标和数据库连接
|
|
|
- cur.close()
|
|
|
- db.close()
|
|
|
-
|
|
|
- # 返回 JSON 响应
|
|
|
- return jsonify(results), 200
|
|
|
-
|
|
|
- except sqlite3.Error as e:
|
|
|
- # 如果发生数据库错误,返回错误信息
|
|
|
- return jsonify({'error': str(e)}), 400
|
|
|
- except KeyError as e:
|
|
|
- # 如果请求数据中缺少必要的键,返回错误信息
|
|
|
- return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
|
|
|
-
|
|
|
-
|
|
|
-# 定义提供数据库列表,用于展示表格的 API 接口
|
|
|
-@bp.route('/table', methods=['POST'])
|
|
|
-def get_table():
|
|
|
- data = request.get_json()
|
|
|
- table_name = data.get('table')
|
|
|
- if not table_name:
|
|
|
- return jsonify({'error': '需要表名'}), 400
|
|
|
-
|
|
|
- try:
|
|
|
- # 创建 sessionmaker 实例
|
|
|
- Session = sessionmaker(bind=db.engine)
|
|
|
- session = Session()
|
|
|
-
|
|
|
- # 动态获取表的元数据
|
|
|
- metadata = MetaData()
|
|
|
- table = Table(table_name, metadata, autoload_with=db.engine)
|
|
|
-
|
|
|
- # 从数据库中查询所有记录
|
|
|
- query = select(table)
|
|
|
- result = session.execute(query).fetchall()
|
|
|
-
|
|
|
- # 将结果转换为列表字典形式
|
|
|
- rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
|
|
|
-
|
|
|
- # 获取列名
|
|
|
- headers = [column.name for column in table.columns]
|
|
|
-
|
|
|
- return jsonify(rows=rows, headers=headers), 200
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- return jsonify({'error': str(e)}), 400
|
|
|
- finally:
|
|
|
- # 关闭 session
|
|
|
- session.close()
|
|
|
-
|
|
|
-
|
|
|
@bp.route('/train-model-async', methods=['POST'])
|
|
|
def train_model_async():
|
|
|
"""
|
|
@@ -976,19 +580,16 @@ def delete_model(model_id, delete_dataset=False):
|
|
|
session.commit()
|
|
|
|
|
|
# 2. 删除模型文件
|
|
|
- model_path = model.ModelFilePath
|
|
|
- try:
|
|
|
- if os.path.exists(model_path):
|
|
|
+ model_file = f"rf_model_{model_id}.pkl"
|
|
|
+ model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
|
|
|
+ if os.path.exists(model_path):
|
|
|
+ try:
|
|
|
os.remove(model_path)
|
|
|
- else:
|
|
|
+ except OSError as e:
|
|
|
# 如果删除文件失败,回滚数据库操作
|
|
|
- session.rollback()
|
|
|
- logger.warning(f'模型文件不存在: {model_path}')
|
|
|
- except OSError as e:
|
|
|
- # 如果删除文件失败,回滚数据库操作
|
|
|
- session.rollback()
|
|
|
- logger.error(f'删除模型文件失败: {str(e)}')
|
|
|
- return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
|
|
|
+ session.rollback()
|
|
|
+ logger.error(f'删除模型文件失败: {str(e)}')
|
|
|
+ return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
|
|
|
|
|
|
# 3. 如果需要删除关联的数据集
|
|
|
if delete_dataset and dataset_id:
|
|
@@ -1008,7 +609,7 @@ def delete_model(model_id, delete_dataset=False):
|
|
|
|
|
|
response_data = {
|
|
|
'message': '模型删除成功',
|
|
|
- 'deleted_files': [model_path]
|
|
|
+ 'deleted_files': [model_file]
|
|
|
}
|
|
|
|
|
|
if delete_dataset:
|
|
@@ -1068,354 +669,6 @@ def clear_dataset(data_type):
|
|
|
session.close()
|
|
|
|
|
|
|
|
|
-@bp.route('/login', methods=['POST'])
|
|
|
-def login_user():
|
|
|
- # 获取前端传来的数据
|
|
|
- data = request.get_json()
|
|
|
- name = data.get('name') # 用户名
|
|
|
- password = data.get('password') # 密码
|
|
|
-
|
|
|
- logger.info(f"Login request received: name={name}")
|
|
|
-
|
|
|
- # 检查用户名和密码是否为空
|
|
|
- if not name or not password:
|
|
|
- logger.warning("用户名和密码不能为空")
|
|
|
- return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
|
|
|
-
|
|
|
- try:
|
|
|
- # 查询数据库验证用户名
|
|
|
- query = "SELECT * FROM users WHERE name = :name"
|
|
|
- conn = get_db()
|
|
|
- user = conn.execute(query, {"name": name}).fetchone()
|
|
|
-
|
|
|
- if not user:
|
|
|
- logger.warning(f"用户名 '{name}' 不存在")
|
|
|
- return jsonify({"success": False, "message": "用户名不存在"}), 400
|
|
|
-
|
|
|
- # 获取数据库中存储的密码(假设密码是哈希存储的)
|
|
|
- stored_password = user[2] # 假设密码存储在数据库的第三列
|
|
|
- user_id = user[0] # 假设 id 存储在数据库的第一列
|
|
|
-
|
|
|
- # 校验密码是否正确
|
|
|
- if check_password_hash(stored_password, password):
|
|
|
- logger.info(f"User '{name}' logged in successfully.")
|
|
|
- return jsonify({
|
|
|
- "success": True,
|
|
|
- "message": "登录成功",
|
|
|
- "userId": user_id # 返回用户 ID
|
|
|
- })
|
|
|
- else:
|
|
|
- logger.warning(f"Invalid password for user '{name}'")
|
|
|
- return jsonify({"success": False, "message": "用户名或密码错误"}), 400
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- # 记录错误日志并返回错误信息
|
|
|
- logger.error(f"Error during login: {e}", exc_info=True)
|
|
|
- return jsonify({"success": False, "message": "登录失败"}), 500
|
|
|
-
|
|
|
-# 更新用户信息接口
|
|
|
-
|
|
|
-
|
|
|
-@bp.route('/update_user', methods=['POST'])
|
|
|
-def update_user():
|
|
|
- # 获取前端传来的数据
|
|
|
- data = request.get_json()
|
|
|
-
|
|
|
- # 打印收到的请求数据
|
|
|
- current_app.logger.info(f"Received data: {data}")
|
|
|
-
|
|
|
- user_id = data.get('userId') # 用户ID
|
|
|
- name = data.get('name') # 用户名
|
|
|
- old_password = data.get('oldPassword') # 旧密码
|
|
|
- new_password = data.get('newPassword') # 新密码
|
|
|
-
|
|
|
- logger.info(f"Update request received: user_id={user_id}, name={name}")
|
|
|
-
|
|
|
- # 校验传入的用户名和密码是否为空
|
|
|
- if not name or not old_password:
|
|
|
- logger.warning("用户名和旧密码不能为空")
|
|
|
- return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400
|
|
|
-
|
|
|
- # 新密码和旧密码不能相同
|
|
|
- if new_password and old_password == new_password:
|
|
|
- logger.warning(f"新密码与旧密码相同:{name}")
|
|
|
- return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400
|
|
|
-
|
|
|
- try:
|
|
|
- # 查询数据库验证用户ID
|
|
|
- query = "SELECT * FROM users WHERE id = :user_id"
|
|
|
- conn = get_db()
|
|
|
- user = conn.execute(query, {"user_id": user_id}).fetchone()
|
|
|
-
|
|
|
- if not user:
|
|
|
- logger.warning(f"用户ID '{user_id}' 不存在")
|
|
|
- return jsonify({"success": False, "message": "用户不存在"}), 400
|
|
|
-
|
|
|
- # 获取数据库中存储的密码(假设密码是哈希存储的)
|
|
|
- stored_password = user[2] # 假设密码存储在数据库的第三列
|
|
|
-
|
|
|
- # 校验旧密码是否正确
|
|
|
- if not check_password_hash(stored_password, old_password):
|
|
|
- logger.warning(f"旧密码错误:{name}")
|
|
|
- return jsonify({"success": False, "message": "旧密码错误"}), 400
|
|
|
-
|
|
|
- # 如果新密码非空,则更新新密码
|
|
|
- if new_password:
|
|
|
- hashed_new_password = hash_password(new_password)
|
|
|
- update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
|
|
|
- conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
|
|
|
- conn.commit()
|
|
|
- logger.info(f"User ID '{user_id}' password updated successfully.")
|
|
|
-
|
|
|
- # 如果用户名发生更改,则更新用户名
|
|
|
- if name != user[1]:
|
|
|
- update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
|
|
|
- conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
|
|
|
- conn.commit()
|
|
|
- logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
|
|
|
-
|
|
|
- return jsonify({"success": True, "message": "用户信息更新成功"})
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- # 记录错误日志并返回错误信息
|
|
|
- logger.error(f"Error updating user: {e}", exc_info=True)
|
|
|
- return jsonify({"success": False, "message": "更新失败"}), 500
|
|
|
-
|
|
|
-
|
|
|
-# 注册用户
|
|
|
-@bp.route('/register', methods=['POST'])
|
|
|
-def register_user():
|
|
|
- # 获取前端传来的数据
|
|
|
- data = request.get_json()
|
|
|
- name = data.get('name') # 用户名
|
|
|
- password = data.get('password') # 密码
|
|
|
-
|
|
|
- logger.info(f"Register request received: name={name}")
|
|
|
-
|
|
|
- # 检查用户名和密码是否为空
|
|
|
- if not name or not password:
|
|
|
- logger.warning("用户名和密码不能为空")
|
|
|
- return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
|
|
|
-
|
|
|
- # 动态获取数据库表的列名
|
|
|
- columns = get_column_names('users')
|
|
|
- logger.info(f"Database columns for 'users' table: {columns}")
|
|
|
-
|
|
|
- # 检查前端传来的数据是否包含数据库表中所有的必填字段
|
|
|
- for column in ['name', 'password']:
|
|
|
- if column not in columns:
|
|
|
- logger.error(f"缺少必填字段:{column}")
|
|
|
- return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400
|
|
|
-
|
|
|
- # 对密码进行哈希处理
|
|
|
- hashed_password = hash_password(password)
|
|
|
- logger.info(f"Password hashed for user: {name}")
|
|
|
-
|
|
|
- # 插入到数据库
|
|
|
- try:
|
|
|
- # 检查用户是否已经存在
|
|
|
- query = "SELECT * FROM users WHERE name = :name"
|
|
|
- conn = get_db()
|
|
|
- user = conn.execute(query, {"name": name}).fetchone()
|
|
|
-
|
|
|
- if user:
|
|
|
- logger.warning(f"用户名 '{name}' 已存在")
|
|
|
- return jsonify({"success": False, "message": "用户名已存在"}), 400
|
|
|
-
|
|
|
- # 向数据库插入数据
|
|
|
- query = "INSERT INTO users (name, password) VALUES (:name, :password)"
|
|
|
- conn.execute(query, {"name": name, "password": hashed_password})
|
|
|
- conn.commit()
|
|
|
-
|
|
|
- logger.info(f"User '{name}' registered successfully.")
|
|
|
- return jsonify({"success": True, "message": "注册成功"})
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- # 记录错误日志并返回错误信息
|
|
|
- logger.error(f"Error registering user: {e}", exc_info=True)
|
|
|
- return jsonify({"success": False, "message": "注册失败"}), 500
|
|
|
-
|
|
|
-
|
|
|
-def get_column_names(table_name):
|
|
|
- """
|
|
|
- 动态获取数据库表的列名。
|
|
|
- """
|
|
|
- try:
|
|
|
- conn = get_db()
|
|
|
- query = f"PRAGMA table_info({table_name});"
|
|
|
- result = conn.execute(query).fetchall()
|
|
|
- conn.close()
|
|
|
- return [row[1] for row in result] # 第二列是列名
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
|
|
|
- return []
|
|
|
-
|
|
|
-
|
|
|
-# 导出数据
|
|
|
-@bp.route('/export_data', methods=['GET'])
|
|
|
-def export_data():
|
|
|
- table_name = request.args.get('table')
|
|
|
- file_format = request.args.get('format', 'excel').lower()
|
|
|
-
|
|
|
- if not table_name:
|
|
|
- return jsonify({'error': '缺少表名参数'}), 400
|
|
|
- if not table_name.isidentifier():
|
|
|
- return jsonify({'error': '无效的表名'}), 400
|
|
|
-
|
|
|
- try:
|
|
|
- conn = get_db()
|
|
|
- query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"
|
|
|
- table_exists = conn.execute(query, (table_name,)).fetchone()
|
|
|
- if not table_exists:
|
|
|
- return jsonify({'error': f"表 {table_name} 不存在"}), 404
|
|
|
-
|
|
|
- query = f"SELECT * FROM {table_name};"
|
|
|
- df = pd.read_sql(query, conn)
|
|
|
-
|
|
|
- output = BytesIO()
|
|
|
- if file_format == 'csv':
|
|
|
- df.to_csv(output, index=False, encoding='utf-8')
|
|
|
- output.seek(0)
|
|
|
- return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv')
|
|
|
- elif file_format == 'excel':
|
|
|
- df.to_excel(output, index=False, engine='openpyxl')
|
|
|
- output.seek(0)
|
|
|
- return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx',
|
|
|
- mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
|
|
|
- else:
|
|
|
- return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Error in export_data: {e}", exc_info=True)
|
|
|
- return jsonify({'error': str(e)}), 500
|
|
|
-
|
|
|
-
|
|
|
-# 导入数据接口
|
|
|
-@bp.route('/import_data', methods=['POST'])
|
|
|
-def import_data():
|
|
|
- logger.debug("Import data endpoint accessed.")
|
|
|
- if 'file' not in request.files:
|
|
|
- logger.error("No file in request.")
|
|
|
- return jsonify({'success': False, 'message': '文件缺失'}), 400
|
|
|
-
|
|
|
- file = request.files['file']
|
|
|
- table_name = request.form.get('table')
|
|
|
-
|
|
|
- if not table_name:
|
|
|
- logger.error("Missing table name parameter.")
|
|
|
- return jsonify({'success': False, 'message': '缺少表名参数'}), 400
|
|
|
-
|
|
|
- if file.filename == '':
|
|
|
- logger.error("No file selected.")
|
|
|
- return jsonify({'success': False, 'message': '未选择文件'}), 400
|
|
|
-
|
|
|
- try:
|
|
|
- # 保存文件到临时路径
|
|
|
- temp_path = os.path.join(current_app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
|
|
|
- file.save(temp_path)
|
|
|
- logger.debug(f"File saved to temporary path: {temp_path}")
|
|
|
-
|
|
|
- # 根据文件类型读取文件
|
|
|
- if file.filename.endswith('.xlsx'):
|
|
|
- df = pd.read_excel(temp_path)
|
|
|
- elif file.filename.endswith('.csv'):
|
|
|
- df = pd.read_csv(temp_path)
|
|
|
- else:
|
|
|
- logger.error("Unsupported file format.")
|
|
|
- return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400
|
|
|
-
|
|
|
- # 获取数据库列名
|
|
|
- db_columns = get_column_names(table_name)
|
|
|
- if 'id' in db_columns:
|
|
|
- db_columns.remove('id') # 假设 id 列是自增的,不需要处理
|
|
|
-
|
|
|
- if not set(db_columns).issubset(set(df.columns)):
|
|
|
- logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}")
|
|
|
- return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400
|
|
|
-
|
|
|
- # 清洗数据并删除空值行
|
|
|
- df_cleaned = df[db_columns].dropna()
|
|
|
-
|
|
|
- # 统一数据类型,避免 int 和 float 合并问题
|
|
|
- df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce')
|
|
|
-
|
|
|
- # 获取现有的数据
|
|
|
- conn = get_db()
|
|
|
- with conn:
|
|
|
- existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn)
|
|
|
-
|
|
|
- # 查找重复数据
|
|
|
- duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner')
|
|
|
-
|
|
|
- # 如果有重复数据,删除它们
|
|
|
- df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)]
|
|
|
- logger.warning(f"Duplicate data detected and removed: {duplicates}")
|
|
|
-
|
|
|
- # 获取导入前后的数据量
|
|
|
- total_data = len(df_cleaned) + len(duplicates)
|
|
|
- new_data = len(df_cleaned)
|
|
|
- duplicate_data = len(duplicates)
|
|
|
-
|
|
|
- # 导入不重复的数据
|
|
|
- df_cleaned.to_sql(table_name, conn, if_exists='append', index=False)
|
|
|
- logger.debug(f"Imported {new_data} new records into the database.")
|
|
|
-
|
|
|
- # 删除临时文件
|
|
|
- os.remove(temp_path)
|
|
|
- logger.debug(f"Temporary file removed: {temp_path}")
|
|
|
-
|
|
|
- # 返回结果
|
|
|
- return jsonify({
|
|
|
- 'success': True,
|
|
|
- 'message': '数据导入成功',
|
|
|
- 'total_data': total_data,
|
|
|
- 'new_data': new_data,
|
|
|
- 'duplicate_data': duplicate_data
|
|
|
- }), 200
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Import failed: {e}", exc_info=True)
|
|
|
- return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500
|
|
|
-
|
|
|
-
|
|
|
-# 模板下载接口
|
|
|
-@bp.route('/download_template', methods=['GET'])
|
|
|
-def download_template():
|
|
|
- """
|
|
|
- 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。
|
|
|
- """
|
|
|
- table_name = request.args.get('table')
|
|
|
- if not table_name:
|
|
|
- return jsonify({'error': '表名参数缺失'}), 400
|
|
|
-
|
|
|
- columns = get_column_names(table_name)
|
|
|
- if not columns:
|
|
|
- return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404
|
|
|
-
|
|
|
- # 不包括 ID 列
|
|
|
- if 'id' in columns:
|
|
|
- columns.remove('id')
|
|
|
-
|
|
|
- df = pd.DataFrame(columns=columns)
|
|
|
-
|
|
|
- file_format = request.args.get('format', 'excel').lower()
|
|
|
- try:
|
|
|
- if file_format == 'csv':
|
|
|
- output = BytesIO()
|
|
|
- df.to_csv(output, index=False, encoding='utf-8')
|
|
|
- output.seek(0)
|
|
|
- return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv',
|
|
|
- mimetype='text/csv')
|
|
|
- else:
|
|
|
- output = BytesIO()
|
|
|
- df.to_excel(output, index=False, engine='openpyxl')
|
|
|
- output.seek(0)
|
|
|
- return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx',
|
|
|
- mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Failed to generate template: {e}", exc_info=True)
|
|
|
- return jsonify({'error': '生成模板文件失败'}), 500
|
|
|
-
|
|
|
@bp.route('/update-threshold', methods=['POST'])
|
|
|
def update_threshold():
|
|
|
"""
|
|
@@ -1744,64 +997,44 @@ def kriging_interpolation():
|
|
|
except Exception as e:
|
|
|
return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
-@bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
|
|
|
-def get_model_scatter_data(model_id):
|
|
|
- """
|
|
|
- 获取指定模型的散点图数据(真实值vs预测值)
|
|
|
-
|
|
|
- @param model_id: 模型ID
|
|
|
- @return: JSON响应,包含散点图数据
|
|
|
- """
|
|
|
- Session = sessionmaker(bind=db.engine)
|
|
|
- session = Session()
|
|
|
-
|
|
|
+# 显示切换模型
|
|
|
+@bp.route('/models', methods=['GET'])
|
|
|
+def get_models():
|
|
|
+ session = None
|
|
|
try:
|
|
|
- # 查询模型信息
|
|
|
- model = session.query(Models).filter_by(ModelID=model_id).first()
|
|
|
- if not model:
|
|
|
- return jsonify({'error': '未找到指定模型'}), 404
|
|
|
-
|
|
|
- # 加载模型
|
|
|
- with open(model.ModelFilePath, 'rb') as f:
|
|
|
- ML_model = pickle.load(f)
|
|
|
-
|
|
|
- # 根据数据类型加载测试数据
|
|
|
- if model.Data_type == 'reflux':
|
|
|
- X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
|
|
|
- Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
|
|
|
- elif model.Data_type == 'reduce':
|
|
|
- X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
|
|
|
- Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
|
|
|
- else:
|
|
|
- return jsonify({'error': '不支持的数据类型'}), 400
|
|
|
-
|
|
|
- # 获取预测值
|
|
|
- y_pred = ML_model.predict(X_test)
|
|
|
-
|
|
|
- # 生成散点图数据
|
|
|
- scatter_data = [
|
|
|
- [float(true), float(pred)]
|
|
|
- for true, pred in zip(Y_test.iloc[:, 0], y_pred)
|
|
|
+ # 创建 session
|
|
|
+ Session = sessionmaker(bind=db.engine)
|
|
|
+ session = Session()
|
|
|
+
|
|
|
+ # 查询所有模型
|
|
|
+ models = session.query(Models).all()
|
|
|
+
|
|
|
+ logger.debug(f"Models found: {models}") # 打印查询的模型数据
|
|
|
+
|
|
|
+ if not models:
|
|
|
+ return jsonify({'message': 'No models found'}), 404
|
|
|
+
|
|
|
+ # 将模型数据转换为字典列表
|
|
|
+ models_list = [
|
|
|
+ {
|
|
|
+ 'ModelID': model.ModelID,
|
|
|
+ 'ModelName': model.Model_name,
|
|
|
+ 'ModelType': model.Model_type,
|
|
|
+ 'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'Description': model.Description,
|
|
|
+ 'DatasetID': model.DatasetID,
|
|
|
+ 'ModelFilePath': model.ModelFilePath,
|
|
|
+ 'DataType': model.Data_type,
|
|
|
+ 'PerformanceScore': model.Performance_score
|
|
|
+ }
|
|
|
+ for model in models
|
|
|
]
|
|
|
-
|
|
|
- # 计算R²分数
|
|
|
- r2 = r2_score(Y_test, y_pred)
|
|
|
-
|
|
|
- # 获取数据范围,用于绘制对角线
|
|
|
- y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
|
|
|
- y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))
|
|
|
-
|
|
|
- return jsonify({
|
|
|
- 'scatter_data': scatter_data,
|
|
|
- 'r2_score': float(r2),
|
|
|
- 'y_range': [float(y_min), float(y_max)],
|
|
|
- 'model_name': model.Model_name,
|
|
|
- 'model_type': model.Model_type
|
|
|
- }), 200
|
|
|
-
|
|
|
+
|
|
|
+ return jsonify(models_list), 200
|
|
|
+
|
|
|
except Exception as e:
|
|
|
- logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True)
|
|
|
- return jsonify({'error': f'获取数据失败: {str(e)}'}), 500
|
|
|
-
|
|
|
+ return jsonify({'error': str(e)}), 400
|
|
|
finally:
|
|
|
- session.close()
|
|
|
+ if session:
|
|
|
+ session.close()
|
|
|
+
|