import sqlite3 from io import BytesIO import pickle from flask import Blueprint, request, jsonify, current_app, send_file from werkzeug.security import check_password_hash, generate_password_hash from werkzeug.utils import secure_filename from .model import predict, train_and_save_model, calculate_model_score import pandas as pd from . import db # 从 app 包导入 db 实例 from sqlalchemy.engine.reflection import Inspector from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux import os from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \ clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \ predict_to_Q, Q_to_t_ha, create_kriging from sqlalchemy.orm import sessionmaker import logging from sqlalchemy import text, select, MetaData, Table, func from .tasks import train_model_task from datetime import datetime from sklearn.metrics import r2_score # 配置日志 logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) # 创建蓝图 (Blueprint),用于分离路由 bp = Blueprint('routes', __name__) # 密码加密 def hash_password(password): return generate_password_hash(password) def get_db(): """ 获取数据库连接 """ return sqlite3.connect(current_app.config['DATABASE']) # 添加一个新的辅助函数来检查数据集大小并触发训练 def check_and_trigger_training(session, dataset_type, dataset_df): """ 检查当前数据集大小是否跨越新的阈值点并触发训练 Args: session: 数据库会话 dataset_type: 数据集类型 ('reduce' 或 'reflux') dataset_df: 数据集 DataFrame Returns: tuple: (是否触发训练, 任务ID) """ try: # 根据数据集类型选择表 table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux # 获取当前记录数 current_count = session.query(func.count()).select_from(table).scalar() # 获取新增的记录数(从request.files中获取的DataFrame长度) new_records = len(dataset_df) # 需要从上层函数传入 # 计算新增数据前的记录数 previous_count = current_count - new_records # 设置阈值 THRESHOLD = current_app.config['THRESHOLD'] # 计算上一个阈值点(基于新增前的数据量) last_threshold = previous_count // THRESHOLD * THRESHOLD # 计算当前所在阈值点 current_threshold = current_count // THRESHOLD * THRESHOLD # 检查是否跨越了新的阈值点 if current_threshold > last_threshold and current_count >= THRESHOLD: # 触发异步训练任务 task = train_model_task.delay( model_type=current_app.config['DEFAULT_MODEL_TYPE'], model_name=f'auto_trained_{dataset_type}_{current_threshold}', model_description=f'Auto trained model at {current_threshold} records threshold', data_type=dataset_type ) return True, task.id return False, None except Exception as e: logging.error(f"检查并触发训练失败: {str(e)}") return False, None @bp.route('/upload-dataset', methods=['POST']) def upload_dataset(): # 创建 session Session = sessionmaker(bind=db.engine) session = Session() try: if 'file' not in request.files: return jsonify({'error': 'No file part'}), 400 file = request.files['file'] if file.filename == '' or not allowed_file(file.filename): return jsonify({'error': 'No selected file or invalid file type'}), 400 dataset_name = request.form.get('dataset_name') dataset_description = request.form.get('dataset_description', 'No description provided') dataset_type = request.form.get('dataset_type') if not dataset_type: return jsonify({'error': 'Dataset type is required'}), 400 new_dataset = Datasets( Dataset_name=dataset_name, Dataset_description=dataset_description, Row_count=0, Status='Datasets_upgraded', Dataset_type=dataset_type, Uploaded_at=datetime.now() ) session.add(new_dataset) session.commit() unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx" upload_folder = current_app.config['UPLOAD_FOLDER'] file_path = os.path.join(upload_folder, unique_filename) file.save(file_path) dataset_df = pd.read_excel(file_path) new_dataset.Row_count = len(dataset_df) new_dataset.Status = 'excel_file_saved success' session.commit() # 处理列名 dataset_df = clean_column_names(dataset_df) dataset_df = rename_columns_for_model(dataset_df, dataset_type) column_types = infer_column_types(dataset_df) dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types) insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class) # 根据 dataset_type 决定插入到哪个已有表 if dataset_type == 'reduce': insert_data_into_existing_table(session, dataset_df, CurrentReduce) elif dataset_type == 'reflux': insert_data_into_existing_table(session, dataset_df, CurrentReflux) session.commit() # 在完成数据插入后,检查是否需要触发训练 training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df) response_data = { 'message': f'Dataset {dataset_name} uploaded successfully!', 'dataset_id': new_dataset.Dataset_ID, 'filename': unique_filename, 'training_triggered': training_triggered } if training_triggered: response_data['task_id'] = task_id response_data['message'] += ' Auto-training has been triggered.' return jsonify(response_data), 201 except Exception as e: session.rollback() logging.error('Failed to process the dataset upload:', exc_info=True) return jsonify({'error': str(e)}), 500 finally: # 确保 session 总是被关闭 if session: session.close() @bp.route('/train-and-save-model', methods=['POST']) def train_and_save_model_endpoint(): # 创建 sessionmaker 实例 Session = sessionmaker(bind=db.engine) session = Session() data = request.get_json() # 从请求中解析参数 model_type = data.get('model_type') model_name = data.get('model_name') model_description = data.get('model_description') data_type = data.get('data_type') dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供 try: # 调用训练和保存模型的函数 result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id) model_id = result[1] if result else None # 计算模型评分 if model_id: model_info = session.query(Models).filter(Models.ModelID == model_id).first() if model_info: score = calculate_model_score(model_info) # 更新模型评分 model_info.Performance_score = score session.commit() result = {'model_id': model_id, 'model_score': score} # 返回成功响应 return jsonify({ 'message': 'Model trained and saved successfully', 'result': result }), 200 except Exception as e: session.rollback() logging.error('Failed to process the model training:', exc_info=True) return jsonify({ 'error': 'Failed to train and save model', 'message': str(e) }), 500 finally: session.close() @bp.route('/predict', methods=['POST']) def predict_route(): # 创建 sessionmaker 实例 Session = sessionmaker(bind=db.engine) session = Session() try: data = request.get_json() model_id = data.get('model_id') # 提取模型名称 parameters = data.get('parameters', {}) # 提取所有变量 # 根据model_id获取模型Data_type model_info = session.query(Models).filter(Models.ModelID == model_id).first() if not model_info: return jsonify({'error': 'Model not found'}), 404 data_type = model_info.Data_type input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame # 如果为reduce,则不需要传入target_ph if data_type == 'reduce': # 获取传入的init_ph、target_ph参数 init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误 target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误 # 从输入数据中删除'target_pH'列 input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错 input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段 predictions = predict(session, input_data_rename, model_id) # 调用预测函数 if data_type == 'reduce': predictions = predictions[0] # 将预测结果转换为Q Q = predict_to_Q(predictions, init_ph, target_ph) predictions = Q_to_t_ha(Q) # 将Q转换为t/ha print(predictions) return jsonify({'result': predictions}), 200 except Exception as e: logging.error('Failed to predict:', exc_info=True) return jsonify({'error': str(e)}), 400 # 为指定模型计算评分Performance_score,需要提供model_id @bp.route('/score-model/', methods=['POST']) def score_model(model_id): # 创建 sessionmaker 实例 Session = sessionmaker(bind=db.engine) session = Session() try: model_info = session.query(Models).filter(Models.ModelID == model_id).first() if not model_info: return jsonify({'error': 'Model not found'}), 404 # 计算模型评分 score = calculate_model_score(model_info) # 更新模型记录中的评分 model_info.Performance_score = score session.commit() return jsonify({'message': 'Model scored successfully', 'score': score}), 200 except Exception as e: logging.error('Failed to process the dataset upload:', exc_info=True) return jsonify({'error': str(e)}), 400 finally: session.close() @bp.route('/delete-dataset/', methods=['DELETE']) def delete_dataset_endpoint(dataset_id): """ 删除数据集的API接口 @param dataset_id: 要删除的数据集ID @return: JSON响应 """ # 创建 sessionmaker 实例 Session = sessionmaker(bind=db.engine) session = Session() try: # 查询数据集 dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first() if not dataset: return jsonify({'error': '未找到数据集'}), 404 # 检查是否有模型使用了该数据集 models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all() if models_using_dataset: models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset] return jsonify({ 'error': '无法删除数据集,因为以下模型正在使用它', 'models': models_info }), 400 # 删除Excel文件 filename = f"dataset_{dataset.Dataset_ID}.xlsx" file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename) if os.path.exists(file_path): try: os.remove(file_path) except OSError as e: logger.error(f'删除文件失败: {str(e)}') return jsonify({'error': f'删除文件失败: {str(e)}'}), 500 # 删除数据表 table_name = f"dataset_{dataset.Dataset_ID}" session.execute(text(f"DROP TABLE IF EXISTS {table_name}")) # 删除数据集记录 session.delete(dataset) session.commit() return jsonify({ 'message': '数据集删除成功', 'deleted_files': [filename] }), 200 except Exception as e: session.rollback() logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True) return jsonify({'error': str(e)}), 500 finally: session.close() @bp.route('/tables', methods=['GET']) def list_tables(): engine = db.engine # 使用 db 实例的 engine inspector = Inspector.from_engine(engine) # 创建 Inspector 对象 table_names = inspector.get_table_names() # 获取所有表名 return jsonify(table_names) # 以 JSON 形式返回表名列表 @bp.route('/models/', methods=['GET']) def get_model(model_id): """ 获取单个模型信息的API接口 @param model_id: 模型ID @return: JSON响应 """ Session = sessionmaker(bind=db.engine) session = Session() try: model = session.query(Models).filter_by(ModelID=model_id).first() if model: return jsonify({ 'ModelID': model.ModelID, 'Model_name': model.Model_name, 'Model_type': model.Model_type, 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'), 'Description': model.Description, 'Performance_score': float(model.Performance_score) if model.Performance_score else None, 'Data_type': model.Data_type }) else: return jsonify({'message': '未找到模型'}), 404 except Exception as e: logger.error(f'获取模型信息失败: {str(e)}') return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500 finally: session.close() @bp.route('/models', methods=['GET']) def get_all_models(): """ 获取所有模型信息的API接口 @return: JSON响应 """ Session = sessionmaker(bind=db.engine) session = Session() try: models = session.query(Models).all() if models: result = [ { 'ModelID': model.ModelID, 'Model_name': model.Model_name, 'Model_type': model.Model_type, 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'), 'Description': model.Description, 'Performance_score': float(model.Performance_score) if model.Performance_score else None, 'Data_type': model.Data_type } for model in models ] return jsonify(result) else: return jsonify({'message': '未找到任何模型'}), 404 except Exception as e: logger.error(f'获取所有模型信息失败: {str(e)}') return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500 finally: session.close() @bp.route('/model-parameters', methods=['GET']) def get_all_model_parameters(): """ 获取所有模型参数的API接口 @return: JSON响应 """ Session = sessionmaker(bind=db.engine) session = Session() try: parameters = session.query(ModelParameters).all() if parameters: result = [ { 'ParamID': param.ParamID, 'ModelID': param.ModelID, 'ParamName': param.ParamName, 'ParamValue': param.ParamValue } for param in parameters ] return jsonify(result) else: return jsonify({'message': '未找到任何参数'}), 404 except Exception as e: logger.error(f'获取所有模型参数失败: {str(e)}') return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500 finally: session.close() @bp.route('/models//parameters', methods=['GET']) def get_model_parameters(model_id): try: model = Models.query.filter_by(ModelID=model_id).first() if model: # 获取该模型的所有参数 parameters = [ { 'ParamID': param.ParamID, 'ParamName': param.ParamName, 'ParamValue': param.ParamValue } for param in model.parameters ] # 返回模型参数信息 return jsonify({ 'ModelID': model.ModelID, 'ModelName': model.ModelName, 'ModelType': model.ModelType, 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'), 'Description': model.Description, 'Parameters': parameters }) else: return jsonify({'message': 'Model not found'}), 404 except Exception as e: return jsonify({'error': 'Internal server error', 'message': str(e)}), 500 # 定义添加数据库记录的 API 接口 @bp.route('/add_item', methods=['POST']) def add_item(): """ 接收 JSON 格式的请求体,包含表名和要插入的数据。 尝试将数据插入到指定的表中,并进行字段查重。 :return: """ try: # 确保请求体是 JSON 格式 data = request.get_json() if not data: raise ValueError("No JSON data provided") table_name = data.get('table') item_data = data.get('item') if not table_name or not item_data: return jsonify({'error': 'Missing table name or item data'}), 400 # 定义各个表的字段查重规则 duplicate_check_rules = { 'users': ['email', 'username'], 'products': ['product_code'], 'current_reduce': [ 'Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'], 'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'], # 其他表和规则 } # 获取该表的查重字段 duplicate_columns = duplicate_check_rules.get(table_name) if not duplicate_columns: return jsonify({'error': 'No duplicate check rule for this table'}), 400 # 动态构建查询条件,逐一检查是否有重复数据 condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns]) duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1" result = db.session.execute(text(duplicate_query), item_data).fetchone() if result: return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409 # 动态构建 SQL 语句,进行插入操作 columns = ', '.join(item_data.keys()) placeholders = ', '.join([f":{key}" for key in item_data.keys()]) sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})" # 直接执行插入操作,无需显式的事务管理 db.session.execute(text(sql), item_data) # 提交事务 db.session.commit() # 返回成功响应 return jsonify({'success': True, 'message': 'Item added successfully'}), 201 except ValueError as e: return jsonify({'error': str(e)}), 400 except KeyError as e: return jsonify({'error': f'Missing data field: {e}'}), 400 except sqlite3.IntegrityError as e: return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409 except sqlite3.Error as e: return jsonify({'error': '数据库错误', 'details': str(e)}), 500 @bp.route('/delete_item', methods=['POST']) def delete_item(): """ 删除数据库记录的 API 接口 """ data = request.get_json() table_name = data.get('table') condition = data.get('condition') # 检查表名和条件是否提供 if not table_name or not condition: return jsonify({ "success": False, "message": "缺少表名或条件参数" }), 400 # 尝试从条件字符串中解析键和值 try: key, value = condition.split('=') key = key.strip() # 去除多余的空格 value = value.strip().strip("'\"") # 去除多余的空格和引号 except ValueError: return jsonify({ "success": False, "message": "条件格式错误,应为 'key=value'" }), 400 # 准备 SQL 删除语句 sql = f"DELETE FROM {table_name} WHERE {key} = :value" try: # 使用 SQLAlchemy 执行删除 with db.session.begin(): result = db.session.execute(text(sql), {"value": value}) # 检查是否有记录被删除 if result.rowcount == 0: return jsonify({ "success": False, "message": "未找到符合条件的记录" }), 404 return jsonify({ "success": True, "message": "记录删除成功" }), 200 except Exception as e: return jsonify({ "success": False, "message": f"删除失败: {e}" }), 500 # 定义修改数据库记录的 API 接口 @bp.route('/update_item', methods=['PUT']) def update_record(): """ 接收 JSON 格式的请求体,包含表名和更新的数据。 尝试更新指定的记录。 """ data = request.get_json() # 检查必要的数据是否提供 if not data or 'table' not in data or 'item' not in data: return jsonify({ "success": False, "message": "请求数据不完整" }), 400 table_name = data['table'] item = data['item'] # 假设 item 的第一个键是 ID id_key = next(iter(item.keys())) # 获取第一个键 record_id = item.get(id_key) if not record_id: return jsonify({ "success": False, "message": "缺少记录 ID" }), 400 # 获取更新的字段和值 updates = {key: value for key, value in item.items() if key != id_key} if not updates: return jsonify({ "success": False, "message": "没有提供需要更新的字段" }), 400 # 动态构建 SQL set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()]) sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value" # 添加 ID 到参数 updates['id_value'] = record_id try: # 使用 SQLAlchemy 执行更新 with db.session.begin(): result = db.session.execute(text(sql), updates) # 检查是否有更新的记录 if result.rowcount == 0: return jsonify({ "success": False, "message": "未找到要更新的记录" }), 404 return jsonify({ "success": True, "message": "数据更新成功" }), 200 except Exception as e: # 捕获所有异常并返回 return jsonify({ "success": False, "message": f"更新失败: {str(e)}" }), 500 # 定义查询数据库记录的 API 接口 @bp.route('/search/record', methods=['GET']) def sql_search(): """ 接收 JSON 格式的请求体,包含表名和要查询的 ID。 尝试查询指定 ID 的记录并返回结果。 :return: """ try: data = request.get_json() # 表名 sql_table = data['table'] # 要搜索的 ID Id = data['id'] # 连接到数据库 cur = db.cursor() # 构造查询语句 sql = f"SELECT * FROM {sql_table} WHERE id = ?" # 执行查询 cur.execute(sql, (Id,)) # 获取查询结果 rows = cur.fetchall() column_names = [desc[0] for desc in cur.description] # 检查是否有结果 if not rows: return jsonify({'error': '未查找到对应数据。'}), 400 # 构造响应数据 results = [] for row in rows: result = {column_names[i]: row[i] for i in range(len(row))} results.append(result) # 关闭游标和数据库连接 cur.close() db.close() # 返回 JSON 响应 return jsonify(results), 200 except sqlite3.Error as e: # 如果发生数据库错误,返回错误信息 return jsonify({'error': str(e)}), 400 except KeyError as e: # 如果请求数据中缺少必要的键,返回错误信息 return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400 # 定义提供数据库列表,用于展示表格的 API 接口 @bp.route('/table', methods=['POST']) def get_table(): data = request.get_json() table_name = data.get('table') if not table_name: return jsonify({'error': '需要表名'}), 400 try: # 创建 sessionmaker 实例 Session = sessionmaker(bind=db.engine) session = Session() # 动态获取表的元数据 metadata = MetaData() table = Table(table_name, metadata, autoload_with=db.engine) # 从数据库中查询所有记录 query = select(table) result = session.execute(query).fetchall() # 将结果转换为列表字典形式 rows = [dict(zip([column.name for column in table.columns], row)) for row in result] # 获取列名 headers = [column.name for column in table.columns] return jsonify(rows=rows, headers=headers), 200 except Exception as e: return jsonify({'error': str(e)}), 400 finally: # 关闭 session session.close() @bp.route('/train-model-async', methods=['POST']) def train_model_async(): """ 异步训练模型的API接口 """ try: data = request.get_json() # 从请求中获取参数 model_type = data.get('model_type') model_name = data.get('model_name') model_description = data.get('model_description') data_type = data.get('data_type') dataset_id = data.get('dataset_id', None) # 验证必要参数 if not all([model_type, model_name, data_type]): return jsonify({ 'error': 'Missing required parameters' }), 400 # 如果提供了dataset_id,验证数据集是否存在 if dataset_id: Session = sessionmaker(bind=db.engine) session = Session() try: dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first() if not dataset: return jsonify({ 'error': f'Dataset with ID {dataset_id} not found' }), 404 finally: session.close() # 启动异步任务 task = train_model_task.delay( model_type=model_type, model_name=model_name, model_description=model_description, data_type=data_type, dataset_id=dataset_id ) # 返回任务ID return jsonify({ 'task_id': task.id, 'message': 'Model training started' }), 202 except Exception as e: logging.error('Failed to start async training task:', exc_info=True) return jsonify({ 'error': str(e) }), 500 @bp.route('/task-status/', methods=['GET']) def get_task_status(task_id): """ 获取异步任务状态的API接口 """ try: task = train_model_task.AsyncResult(task_id) if task.state == 'PENDING': response = { 'state': task.state, 'status': 'Task is waiting for execution' } elif task.state == 'FAILURE': response = { 'state': task.state, 'status': 'Task failed', 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info) } elif task.state == 'SUCCESS': response = { 'state': task.state, 'status': 'Task completed successfully', 'result': task.get() } else: response = { 'state': task.state, 'status': 'Task is in progress' } return jsonify(response), 200 except Exception as e: return jsonify({ 'error': str(e) }), 500 @bp.route('/delete-model/', methods=['DELETE']) def delete_model_route(model_id): # 将URL参数转换为布尔值 delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true' # 调用原始函数 return delete_model(model_id, delete_dataset=delete_dataset_param) def delete_model(model_id, delete_dataset=False): """ 删除指定模型的API接口 @param model_id: 要删除的模型ID @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False @return: JSON响应 """ Session = sessionmaker(bind=db.engine) session = Session() try: # 查询模型信息 model = session.query(Models).filter_by(ModelID=model_id).first() if not model: return jsonify({'error': '未找到指定模型'}), 404 dataset_id = model.DatasetID # 1. 先删除模型记录 session.delete(model) session.commit() # 2. 删除模型文件 model_path = model.ModelFilePath try: if os.path.exists(model_path): os.remove(model_path) else: # 如果删除文件失败,回滚数据库操作 session.rollback() logger.warning(f'模型文件不存在: {model_path}') except OSError as e: # 如果删除文件失败,回滚数据库操作 session.rollback() logger.error(f'删除模型文件失败: {str(e)}') return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500 # 3. 如果需要删除关联的数据集 if delete_dataset and dataset_id: try: dataset_response = delete_dataset_endpoint(dataset_id) if not isinstance(dataset_response, tuple) or dataset_response[1] != 200: # 如果删除数据集失败,回滚之前的操作 session.rollback() return jsonify({ 'error': '删除关联数据集失败', 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0]) }), 500 except Exception as e: session.rollback() logger.error(f'删除关联数据集失败: {str(e)}') return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500 response_data = { 'message': '模型删除成功', 'deleted_files': [model_path] } if delete_dataset: response_data['dataset_info'] = { 'dataset_id': dataset_id, 'message': '关联数据集已删除' } return jsonify(response_data), 200 except Exception as e: session.rollback() logger.error(f'删除模型 {model_id} 失败:', exc_info=True) return jsonify({'error': str(e)}), 500 finally: session.close() # 添加一个新的API端点来清空指定数据集 @bp.route('/clear-dataset/', methods=['DELETE']) def clear_dataset(data_type): """ 清空指定类型的数据集并递增计数 @param data_type: 数据集类型 ('reduce' 或 'reflux') @return: JSON响应 """ # 创建 sessionmaker 实例 Session = sessionmaker(bind=db.engine) session = Session() try: # 根据数据集类型选择表 if data_type == 'reduce': table = CurrentReduce table_name = 'current_reduce' elif data_type == 'reflux': table = CurrentReflux table_name = 'current_reflux' else: return jsonify({'error': '无效的数据集类型'}), 400 # 清空表内容 session.query(table).delete() # 重置自增主键计数器 session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'")) session.commit() return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200 except Exception as e: session.rollback() return jsonify({'error': str(e)}), 500 finally: session.close() @bp.route('/login', methods=['POST']) def login_user(): # 获取前端传来的数据 data = request.get_json() name = data.get('name') # 用户名 password = data.get('password') # 密码 logger.info(f"Login request received: name={name}") # 检查用户名和密码是否为空 if not name or not password: logger.warning("用户名和密码不能为空") return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400 try: # 查询数据库验证用户名 query = "SELECT * FROM users WHERE name = :name" conn = get_db() user = conn.execute(query, {"name": name}).fetchone() if not user: logger.warning(f"用户名 '{name}' 不存在") return jsonify({"success": False, "message": "用户名不存在"}), 400 # 获取数据库中存储的密码(假设密码是哈希存储的) stored_password = user[2] # 假设密码存储在数据库的第三列 user_id = user[0] # 假设 id 存储在数据库的第一列 # 校验密码是否正确 if check_password_hash(stored_password, password): logger.info(f"User '{name}' logged in successfully.") return jsonify({ "success": True, "message": "登录成功", "userId": user_id # 返回用户 ID }) else: logger.warning(f"Invalid password for user '{name}'") return jsonify({"success": False, "message": "用户名或密码错误"}), 400 except Exception as e: # 记录错误日志并返回错误信息 logger.error(f"Error during login: {e}", exc_info=True) return jsonify({"success": False, "message": "登录失败"}), 500 # 更新用户信息接口 @bp.route('/update_user', methods=['POST']) def update_user(): # 获取前端传来的数据 data = request.get_json() # 打印收到的请求数据 current_app.logger.info(f"Received data: {data}") user_id = data.get('userId') # 用户ID name = data.get('name') # 用户名 old_password = data.get('oldPassword') # 旧密码 new_password = data.get('newPassword') # 新密码 logger.info(f"Update request received: user_id={user_id}, name={name}") # 校验传入的用户名和密码是否为空 if not name or not old_password: logger.warning("用户名和旧密码不能为空") return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400 # 新密码和旧密码不能相同 if new_password and old_password == new_password: logger.warning(f"新密码与旧密码相同:{name}") return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400 try: # 查询数据库验证用户ID query = "SELECT * FROM users WHERE id = :user_id" conn = get_db() user = conn.execute(query, {"user_id": user_id}).fetchone() if not user: logger.warning(f"用户ID '{user_id}' 不存在") return jsonify({"success": False, "message": "用户不存在"}), 400 # 获取数据库中存储的密码(假设密码是哈希存储的) stored_password = user[2] # 假设密码存储在数据库的第三列 # 校验旧密码是否正确 if not check_password_hash(stored_password, old_password): logger.warning(f"旧密码错误:{name}") return jsonify({"success": False, "message": "旧密码错误"}), 400 # 如果新密码非空,则更新新密码 if new_password: hashed_new_password = hash_password(new_password) update_query = "UPDATE users SET password = :new_password WHERE id = :user_id" conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id}) conn.commit() logger.info(f"User ID '{user_id}' password updated successfully.") # 如果用户名发生更改,则更新用户名 if name != user[1]: update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id" conn.execute(update_name_query, {"new_name": name, "user_id": user_id}) conn.commit() logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.") return jsonify({"success": True, "message": "用户信息更新成功"}) except Exception as e: # 记录错误日志并返回错误信息 logger.error(f"Error updating user: {e}", exc_info=True) return jsonify({"success": False, "message": "更新失败"}), 500 # 注册用户 @bp.route('/register', methods=['POST']) def register_user(): # 获取前端传来的数据 data = request.get_json() name = data.get('name') # 用户名 password = data.get('password') # 密码 logger.info(f"Register request received: name={name}") # 检查用户名和密码是否为空 if not name or not password: logger.warning("用户名和密码不能为空") return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400 # 动态获取数据库表的列名 columns = get_column_names('users') logger.info(f"Database columns for 'users' table: {columns}") # 检查前端传来的数据是否包含数据库表中所有的必填字段 for column in ['name', 'password']: if column not in columns: logger.error(f"缺少必填字段:{column}") return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400 # 对密码进行哈希处理 hashed_password = hash_password(password) logger.info(f"Password hashed for user: {name}") # 插入到数据库 try: # 检查用户是否已经存在 query = "SELECT * FROM users WHERE name = :name" conn = get_db() user = conn.execute(query, {"name": name}).fetchone() if user: logger.warning(f"用户名 '{name}' 已存在") return jsonify({"success": False, "message": "用户名已存在"}), 400 # 向数据库插入数据 query = "INSERT INTO users (name, password) VALUES (:name, :password)" conn.execute(query, {"name": name, "password": hashed_password}) conn.commit() logger.info(f"User '{name}' registered successfully.") return jsonify({"success": True, "message": "注册成功"}) except Exception as e: # 记录错误日志并返回错误信息 logger.error(f"Error registering user: {e}", exc_info=True) return jsonify({"success": False, "message": "注册失败"}), 500 def get_column_names(table_name): """ 动态获取数据库表的列名。 """ try: conn = get_db() query = f"PRAGMA table_info({table_name});" result = conn.execute(query).fetchall() conn.close() return [row[1] for row in result] # 第二列是列名 except Exception as e: logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True) return [] # 导出数据 @bp.route('/export_data', methods=['GET']) def export_data(): table_name = request.args.get('table') file_format = request.args.get('format', 'excel').lower() if not table_name: return jsonify({'error': '缺少表名参数'}), 400 if not table_name.isidentifier(): return jsonify({'error': '无效的表名'}), 400 try: conn = get_db() query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;" table_exists = conn.execute(query, (table_name,)).fetchone() if not table_exists: return jsonify({'error': f"表 {table_name} 不存在"}), 404 query = f"SELECT * FROM {table_name};" df = pd.read_sql(query, conn) output = BytesIO() if file_format == 'csv': df.to_csv(output, index=False, encoding='utf-8') output.seek(0) return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv') elif file_format == 'excel': df.to_excel(output, index=False, engine='openpyxl') output.seek(0) return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx', mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet') else: return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400 except Exception as e: logger.error(f"Error in export_data: {e}", exc_info=True) return jsonify({'error': str(e)}), 500 # 导入数据接口 @bp.route('/import_data', methods=['POST']) def import_data(): logger.debug("Import data endpoint accessed.") if 'file' not in request.files: logger.error("No file in request.") return jsonify({'success': False, 'message': '文件缺失'}), 400 file = request.files['file'] table_name = request.form.get('table') if not table_name: logger.error("Missing table name parameter.") return jsonify({'success': False, 'message': '缺少表名参数'}), 400 if file.filename == '': logger.error("No file selected.") return jsonify({'success': False, 'message': '未选择文件'}), 400 try: # 保存文件到临时路径 temp_path = os.path.join(current_app.config['UPLOAD_FOLDER'], secure_filename(file.filename)) file.save(temp_path) logger.debug(f"File saved to temporary path: {temp_path}") # 根据文件类型读取文件 if file.filename.endswith('.xlsx'): df = pd.read_excel(temp_path) elif file.filename.endswith('.csv'): df = pd.read_csv(temp_path) else: logger.error("Unsupported file format.") return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400 # 获取数据库列名 db_columns = get_column_names(table_name) if 'id' in db_columns: db_columns.remove('id') # 假设 id 列是自增的,不需要处理 if not set(db_columns).issubset(set(df.columns)): logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}") return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400 # 清洗数据并删除空值行 df_cleaned = df[db_columns].dropna() # 统一数据类型,避免 int 和 float 合并问题 df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce') # 获取现有的数据 conn = get_db() with conn: existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn) # 查找重复数据 duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner') # 如果有重复数据,删除它们 df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)] logger.warning(f"Duplicate data detected and removed: {duplicates}") # 获取导入前后的数据量 total_data = len(df_cleaned) + len(duplicates) new_data = len(df_cleaned) duplicate_data = len(duplicates) # 导入不重复的数据 df_cleaned.to_sql(table_name, conn, if_exists='append', index=False) logger.debug(f"Imported {new_data} new records into the database.") # 删除临时文件 os.remove(temp_path) logger.debug(f"Temporary file removed: {temp_path}") # 返回结果 return jsonify({ 'success': True, 'message': '数据导入成功', 'total_data': total_data, 'new_data': new_data, 'duplicate_data': duplicate_data }), 200 except Exception as e: logger.error(f"Import failed: {e}", exc_info=True) return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500 # 模板下载接口 @bp.route('/download_template', methods=['GET']) def download_template(): """ 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。 """ table_name = request.args.get('table') if not table_name: return jsonify({'error': '表名参数缺失'}), 400 columns = get_column_names(table_name) if not columns: return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404 # 不包括 ID 列 if 'id' in columns: columns.remove('id') df = pd.DataFrame(columns=columns) file_format = request.args.get('format', 'excel').lower() try: if file_format == 'csv': output = BytesIO() df.to_csv(output, index=False, encoding='utf-8') output.seek(0) return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv', mimetype='text/csv') else: output = BytesIO() df.to_excel(output, index=False, engine='openpyxl') output.seek(0) return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx', mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet') except Exception as e: logger.error(f"Failed to generate template: {e}", exc_info=True) return jsonify({'error': '生成模板文件失败'}), 500 @bp.route('/update-threshold', methods=['POST']) def update_threshold(): """ 更新训练阈值的API接口 @body_param threshold: 新的阈值值(整数) @return: JSON响应 """ try: data = request.get_json() new_threshold = data.get('threshold') # 验证新阈值 if not isinstance(new_threshold, (int, float)) or new_threshold <= 0: return jsonify({ 'error': '无效的阈值值,必须为正数' }), 400 # 更新当前应用的阈值配置 current_app.config['THRESHOLD'] = int(new_threshold) return jsonify({ 'success': True, 'message': f'阈值已更新为 {new_threshold}', 'new_threshold': new_threshold }) except Exception as e: logging.error(f"更新阈值失败: {str(e)}") return jsonify({ 'error': f'更新阈值失败: {str(e)}' }), 500 @bp.route('/get-threshold', methods=['GET']) def get_threshold(): """ 获取当前训练阈值的API接口 @return: JSON响应 """ try: current_threshold = current_app.config['THRESHOLD'] default_threshold = current_app.config['DEFAULT_THRESHOLD'] return jsonify({ 'current_threshold': current_threshold, 'default_threshold': default_threshold }) except Exception as e: logging.error(f"获取阈值失败: {str(e)}") return jsonify({ 'error': f'获取阈值失败: {str(e)}' }), 500 @bp.route('/set-current-dataset//', methods=['POST']) def set_current_dataset(data_type, dataset_id): """ 将指定数据集设置为current数据集 @param data_type: 数据集类型 ('reduce' 或 'reflux') @param dataset_id: 要设置为current的数据集ID @return: JSON响应 """ Session = sessionmaker(bind=db.engine) session = Session() try: # 验证数据集存在且类型匹配 dataset = session.query(Datasets)\ .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\ .first() if not dataset: return jsonify({ 'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集' }), 404 # 根据数据类型选择表 if data_type == 'reduce': table = CurrentReduce table_name = 'current_reduce' elif data_type == 'reflux': table = CurrentReflux table_name = 'current_reflux' else: return jsonify({'error': '无效的数据集类型'}), 400 # 清空current表 session.query(table).delete() # 重置自增主键计数器 session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'")) # 从指定数据集复制数据到current表 dataset_table_name = f"dataset_{dataset_id}" copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}") session.execute(copy_sql) session.commit() return jsonify({ 'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}', 'dataset_id': dataset_id, 'dataset_name': dataset.Dataset_name, 'row_count': dataset.Row_count }), 200 except Exception as e: session.rollback() logger.error(f'设置current数据集失败: {str(e)}') return jsonify({'error': str(e)}), 500 finally: session.close() @bp.route('/get-model-history/', methods=['GET']) def get_model_history(data_type): """ 获取模型训练历史数据的API接口 @param data_type: 数据集类型 ('reduce' 或 'reflux') @return: JSON响应,包含时间序列的模型性能数据 """ Session = sessionmaker(bind=db.engine) session = Session() try: # 查询所有自动生成的数据集,按时间排序 datasets = session.query(Datasets).filter( Datasets.Dataset_type == data_type, Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}" ).order_by(Datasets.Uploaded_at).all() history_data = [] for dataset in datasets: # 查找对应的自动训练模型 model = session.query(Models).filter( Models.DatasetID == dataset.Dataset_ID, Models.Model_name.like(f'auto_trained_{data_type}_%') ).first() if model and model.Performance_score is not None: # 直接使用数据库中的时间,不进行格式化(保持与created_at相同的时区) created_at = model.Created_at.isoformat() if model.Created_at else None history_data.append({ 'dataset_id': dataset.Dataset_ID, 'row_count': dataset.Row_count, 'model_id': model.ModelID, 'model_name': model.Model_name, 'performance_score': float(model.Performance_score), 'timestamp': created_at }) # 按时间戳排序 history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '') # 构建返回数据,分离各个指标序列便于前端绘图 response_data = { 'data_type': data_type, 'timestamps': [item['timestamp'] for item in history_data], 'row_counts': [item['row_count'] for item in history_data], 'performance_scores': [item['performance_score'] for item in history_data], 'model_details': history_data # 保留完整数据供前端使用 } return jsonify(response_data), 200 except Exception as e: logger.error(f'获取模型历史数据失败: {str(e)}', exc_info=True) return jsonify({'error': str(e)}), 500 finally: session.close() @bp.route('/batch-delete-datasets', methods=['POST']) def batch_delete_datasets(): """ 批量删除数据集的API接口 @body_param dataset_ids: 要删除的数据集ID列表 @return: JSON响应 """ try: data = request.get_json() dataset_ids = data.get('dataset_ids', []) if not dataset_ids: return jsonify({'error': '未提供数据集ID列表'}), 400 results = { 'success': [], 'failed': [], 'protected': [] # 被模型使用的数据集 } for dataset_id in dataset_ids: try: # 调用单个删除接口 response = delete_dataset_endpoint(dataset_id) # 解析响应 if response[1] == 200: results['success'].append(dataset_id) elif response[1] == 400 and 'models' in response[0].json: # 数据集被模型保护 results['protected'].append({ 'id': dataset_id, 'models': response[0].json['models'] }) else: results['failed'].append({ 'id': dataset_id, 'reason': response[0].json.get('error', '删除失败') }) except Exception as e: logger.error(f'删除数据集 {dataset_id} 失败: {str(e)}') results['failed'].append({ 'id': dataset_id, 'reason': str(e) }) # 构建响应消息 message = f"成功删除 {len(results['success'])} 个数据集" if results['protected']: message += f", {len(results['protected'])} 个数据集被保护" if results['failed']: message += f", {len(results['failed'])} 个数据集删除失败" return jsonify({ 'message': message, 'results': results }), 200 except Exception as e: logger.error(f'批量删除数据集失败: {str(e)}') return jsonify({'error': str(e)}), 500 @bp.route('/batch-delete-models', methods=['POST']) def batch_delete_models(): """ 批量删除模型的API接口 @body_param model_ids: 要删除的模型ID列表 @query_param delete_datasets: 布尔值,是否同时删除关联的数据集,默认为False @return: JSON响应 """ try: data = request.get_json() model_ids = data.get('model_ids', []) delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true' if not model_ids: return jsonify({'error': '未提供模型ID列表'}), 400 results = { 'success': [], 'failed': [], 'datasets_deleted': [] # 如果delete_datasets为true,记录被删除的数据集 } for model_id in model_ids: try: # 调用单个删除接口 response = delete_model(model_id, delete_dataset=delete_datasets) # 解析响应 if response[1] == 200: results['success'].append(model_id) # 如果删除了关联数据集,记录数据集ID if 'dataset_info' in response[0].json: results['datasets_deleted'].append( response[0].json['dataset_info']['dataset_id'] ) else: results['failed'].append({ 'id': model_id, 'reason': response[0].json.get('error', '删除失败') }) except Exception as e: logger.error(f'删除模型 {model_id} 失败: {str(e)}') results['failed'].append({ 'id': model_id, 'reason': str(e) }) # 构建响应消息 message = f"成功删除 {len(results['success'])} 个模型" if results['datasets_deleted']: message += f", {len(results['datasets_deleted'])} 个关联数据集" if results['failed']: message += f", {len(results['failed'])} 个模型删除失败" return jsonify({ 'message': message, 'results': results }), 200 except Exception as e: logger.error(f'批量删除模型失败: {str(e)}') return jsonify({'error': str(e)}), 500 @bp.route('/kriging_interpolation', methods=['POST']) def kriging_interpolation(): try: data = request.get_json() required = ['file_name', 'emission_column', 'points'] if not all(k in data for k in required): return jsonify({"error": "Missing parameters"}), 400 # 添加坐标顺序验证 points = data['points'] if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points): return jsonify({"error": "Invalid points format"}), 400 result = create_kriging( data['file_name'], data['emission_column'], data['points'] ) return jsonify(result) except Exception as e: return jsonify({"error": str(e)}), 500 @bp.route('/model-scatter-data/', methods=['GET']) def get_model_scatter_data(model_id): """ 获取指定模型的散点图数据(真实值vs预测值) @param model_id: 模型ID @return: JSON响应,包含散点图数据 """ Session = sessionmaker(bind=db.engine) session = Session() try: # 查询模型信息 model = session.query(Models).filter_by(ModelID=model_id).first() if not model: return jsonify({'error': '未找到指定模型'}), 404 # 加载模型 with open(model.ModelFilePath, 'rb') as f: ML_model = pickle.load(f) # 根据数据类型加载测试数据 if model.Data_type == 'reflux': X_test = pd.read_csv('uploads/data/X_test_reflux.csv') Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv') elif model.Data_type == 'reduce': X_test = pd.read_csv('uploads/data/X_test_reduce.csv') Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv') else: return jsonify({'error': '不支持的数据类型'}), 400 # 获取预测值 y_pred = ML_model.predict(X_test) # 生成散点图数据 scatter_data = [ [float(true), float(pred)] for true, pred in zip(Y_test.iloc[:, 0], y_pred) ] # 计算R²分数 r2 = r2_score(Y_test, y_pred) # 获取数据范围,用于绘制对角线 y_min = min(min(Y_test.iloc[:, 0]), min(y_pred)) y_max = max(max(Y_test.iloc[:, 0]), max(y_pred)) return jsonify({ 'scatter_data': scatter_data, 'r2_score': float(r2), 'y_range': [float(y_min), float(y_max)], 'model_name': model.Model_name, 'model_type': model.Model_type }), 200 except Exception as e: logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True) return jsonify({'error': f'获取数据失败: {str(e)}'}), 500 finally: session.close()