import sqlite3
from io import BytesIO
import pickle
from flask import Blueprint, request, jsonify, current_app, send_file
from werkzeug.security import check_password_hash, generate_password_hash
from werkzeug.utils import secure_filename
from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
import pandas as pd
from . import db  # import the db instance from the app package
from sqlalchemy.engine.reflection import Inspector
from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
import os
from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
    clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
    predict_to_Q, Q_to_t_ha, create_kriging
from sqlalchemy.orm import sessionmaker
import logging
from sqlalchemy import text, select, MetaData, Table, func
from .tasks import train_model_task
from datetime import datetime
from sklearn.metrics import r2_score

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Create a Blueprint to keep the routes separate from the app factory
bp = Blueprint('routes', __name__)

# Password hashing
def hash_password(password):
    return generate_password_hash(password)


def get_db():
    """Return a raw SQLite connection to the configured database."""
    return sqlite3.connect(current_app.config['DATABASE'])

# Helper that checks the dataset size against the training threshold
def check_and_trigger_training(session, dataset_type, dataset_df):
    """
    Check whether the current dataset size has crossed a new threshold
    point and, if so, trigger an automatic training run.

    Args:
        session: database session
        dataset_type: dataset type ('reduce' or 'reflux')
        dataset_df: the newly inserted data as a DataFrame

    Returns:
        tuple: (whether training was triggered, task ID)
    """
    try:
        # Pick the table that matches the dataset type
        table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux

        # Current number of records (after the new rows were inserted)
        current_count = session.query(func.count()).select_from(table).scalar()

        # Number of rows just added (passed in by the caller)
        new_records = len(dataset_df)

        # Record count before the new rows were added
        previous_count = current_count - new_records

        # Pick the threshold for this dataset type
        if dataset_type == 'reduce':
            THRESHOLD = current_app.config['THRESHOLD_REDUCE']
        else:  # reflux
            THRESHOLD = current_app.config['THRESHOLD_REFLUX']

        # Last threshold point reached before the insert
        last_threshold = previous_count // THRESHOLD * THRESHOLD
        # Threshold point reached after the insert
        current_threshold = current_count // THRESHOLD * THRESHOLD

        # Trigger training only when the insert crossed a new threshold point
        if current_threshold > last_threshold and current_count >= THRESHOLD:
            # Kick off the asynchronous training task
            task = train_model_task.delay(
                model_type=current_app.config['DEFAULT_MODEL_TYPE'],
                model_name=f'auto_trained_{dataset_type}_{current_threshold}',
                model_description=f'Auto trained model at {current_threshold} records threshold',
                data_type=dataset_type
            )
            return True, task.id

        return False, None

    except Exception as e:
        logger.error(f"Failed to check and trigger training: {str(e)}")
        return False, None

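# Worked example of the threshold check above (hypothetical numbers): with
# THRESHOLD = 50, an upload of 5 rows on top of previous_count = 98 gives
# current_count = 103, last_threshold = 98 // 50 * 50 = 50 and
# current_threshold = 103 // 50 * 50 = 100. Since 100 > 50 and 103 >= 50,
# one training run fires for the 100-record mark; a further 5-row upload
# (current_count = 108) stays at threshold 100 and triggers nothing.
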
@bp.route('/upload-dataset', methods=['POST'])
def upload_dataset():
    # Create a session
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '' or not allowed_file(file.filename):
            return jsonify({'error': 'No selected file or invalid file type'}), 400

        dataset_name = request.form.get('dataset_name')
        dataset_description = request.form.get('dataset_description', 'No description provided')
        dataset_type = request.form.get('dataset_type')
        if not dataset_type:
            return jsonify({'error': 'Dataset type is required'}), 400

        new_dataset = Datasets(
            Dataset_name=dataset_name,
            Dataset_description=dataset_description,
            Row_count=0,
            Status='Datasets_upgraded',
            Dataset_type=dataset_type,
            Uploaded_at=datetime.now()
        )
        session.add(new_dataset)
        session.commit()

        unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
        upload_folder = current_app.config['UPLOAD_FOLDER']
        file_path = os.path.join(upload_folder, unique_filename)
        file.save(file_path)

        dataset_df = pd.read_excel(file_path)
        new_dataset.Row_count = len(dataset_df)
        new_dataset.Status = 'excel_file_saved'
        session.commit()

        # Normalize column names
        dataset_df = clean_column_names(dataset_df)
        dataset_df = rename_columns_for_model(dataset_df, dataset_type)

        column_types = infer_column_types(dataset_df)
        dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
        insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)

        # Drop duplicate rows within the uploaded dataset itself
        original_count = len(dataset_df)
        dataset_df = dataset_df.drop_duplicates()
        duplicates_in_file = original_count - len(dataset_df)

        # Check for duplicates against the existing data
        duplicates_with_existing = 0
        if dataset_type in ['reduce', 'reflux']:
            # Resolve the table name
            table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'

            # Load the existing data from that table
            existing_data = pd.read_sql_table(table_name, session.bind)
            if 'id' in existing_data.columns:
                existing_data = existing_data.drop('id', axis=1)

            # Columns shared by both frames are used for the comparison
            compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]

            original_df_len = len(dataset_df)

            # Concatenate and flag duplicates; upload rows repeating existing
            # rows (or earlier upload rows) are marked True
            all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
            duplicates_mask = all_data.duplicated(keep='first')
            # Slice positionally: the tail of the mask corresponds to the upload
            upload_mask = duplicates_mask.iloc[len(existing_data):]
            duplicates_with_existing = int(upload_mask.sum())

            # Keep only the non-duplicate rows
            dataset_df = dataset_df[~upload_mask.values]

            logger.info(f"Original rows: {original_df_len}, duplicates of existing data: {duplicates_with_existing}, kept: {len(dataset_df)}")

        # Check for overlap with the test set
        test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)

        # If any rows overlap with the test set, remove them from the dataset
        if test_overlap_count > 0:
            # Boolean mask marking rows whose index is NOT in the overlap set
            mask = ~dataset_df.index.isin(test_overlap_indices)
            # Apply the mask, keeping only the non-overlapping rows
            dataset_df = dataset_df[mask]
            logger.warning(f"Removed {test_overlap_count} rows overlapping the test set")

        # Insert into the matching current table based on dataset_type
        if dataset_type == 'reduce':
            insert_data_into_existing_table(session, dataset_df, CurrentReduce)
        elif dataset_type == 'reflux':
            insert_data_into_existing_table(session, dataset_df, CurrentReflux)

        session.commit()

        # After the insert completes, check whether training should be triggered
        training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)

        response_data = {
            'message': f'Dataset {dataset_name} uploaded successfully!',
            'dataset_id': new_dataset.Dataset_ID,
            'filename': unique_filename,
            'training_triggered': training_triggered,
            'data_stats': {
                'original_count': original_count,
                'duplicates_in_file': duplicates_in_file,
                'duplicates_with_existing': duplicates_with_existing,
                'test_overlap_count': test_overlap_count,
                'final_count': len(dataset_df)
            }
        }

        if training_triggered:
            response_data['task_id'] = task_id
            response_data['message'] += ' Automatic training has been triggered.'

        # Report deduplication in the message
        if duplicates_with_existing > 0:
            response_data['message'] += f' Removed {duplicates_with_existing} rows duplicating existing data.'

        # Report test-set overlap in the message
        if test_overlap_count > 0:
            response_data['message'] += f' Removed {test_overlap_count} rows overlapping the test set.'

        return jsonify(response_data), 201

    except Exception as e:
        session.rollback()
        logger.error('Failed to process the dataset upload:', exc_info=True)
        return jsonify({'error': str(e)}), 500

    finally:
        # Make sure the session is always closed
        if session:
            session.close()

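# Example request (hypothetical host, port, and file name; the form fields
# match what the route parses above):
#
#   curl -X POST http://localhost:5000/upload-dataset \
#        -F "file=@samples.xlsx" \
#        -F "dataset_name=spring_survey" \
#        -F "dataset_description=Spring field samples" \
#        -F "dataset_type=reduce"
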
@bp.route('/train-and-save-model', methods=['POST'])
def train_and_save_model_endpoint():
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        data = request.get_json()
        # Parse the parameters from the request
        model_type = data.get('model_type')
        model_name = data.get('model_name')
        model_description = data.get('model_description')
        data_type = data.get('data_type')
        dataset_id = data.get('dataset_id', None)  # defaults to None when absent

        # Train and persist the model
        result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)

        model_id = result[1] if result else None

        # Score the model
        if model_id:
            model_info = session.query(Models).filter(Models.ModelID == model_id).first()
            if model_info:
                # Compute the score metrics
                score_metrics = calculate_model_score(model_info)
                # Update the model's performance score
                model_info.Performance_score = score_metrics['r2']
                # Persist the additional metrics
                model_info.MAE = score_metrics['mae']
                model_info.RMSE = score_metrics['rmse']
                # CV_score is already set inside train_and_save_model; not updated here
                session.commit()
                result = {
                    'model_id': model_id,
                    'model_score': score_metrics['r2'],
                    'mae': score_metrics['mae'],
                    'rmse': score_metrics['rmse'],
                    'cv_score': result[3]
                }

        # Return a success response
        return jsonify({
            'message': 'Model trained and saved successfully',
            'result': result
        }), 200

    except Exception as e:
        session.rollback()
        logger.error('Failed to process the model training:', exc_info=True)
        return jsonify({
            'error': 'Failed to train and save model',
            'message': str(e)
        }), 500
    finally:
        session.close()

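# Example request body (hypothetical values; the keys match the fields read above):
#
#   POST /train-and-save-model
#   {
#       "model_type": "RandomForest",
#       "model_name": "reflux_v1",
#       "model_description": "Baseline model",
#       "data_type": "reflux",
#       "dataset_id": 12
#   }
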
@bp.route('/predict', methods=['POST'])
def predict_route():
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        data = request.get_json()
        model_id = data.get('model_id')  # model identifier
        parameters = data.get('parameters', {})  # all input variables

        # Look up the model's Data_type by model_id
        model_info = session.query(Models).filter(Models.ModelID == model_id).first()
        if not model_info:
            return jsonify({'error': 'Model not found'}), 404
        data_type = model_info.Data_type

        input_data = pd.DataFrame([parameters])  # convert the parameters to a DataFrame

        # For 'reduce' models, target_pH is not a model input
        if data_type == 'reduce':
            # Read the init_pH and target_pH parameters
            init_ph = float(parameters.get('init_pH', 0.0))  # default 0.0 guards against None
            target_ph = float(parameters.get('target_pH', 0.0))  # default 0.0 guards against None
            # Drop the 'target_pH' column; errors='ignore' tolerates its absence
            input_data = input_data.drop('target_pH', axis=1, errors='ignore')

        input_data_rename = rename_columns_for_model_predict(input_data, data_type)  # rename columns to match the model's fields
        predictions = predict(session, input_data_rename, model_id)  # run the prediction

        if data_type == 'reduce':
            predictions = predictions[0]
            # Convert the prediction to Q, then Q to t/ha
            Q = predict_to_Q(predictions, init_ph, target_ph)
            predictions = Q_to_t_ha(Q)

        logger.info(f'Prediction result: {predictions}')
        return jsonify({'result': predictions}), 200
    except Exception as e:
        logger.error('Failed to predict:', exc_info=True)
        return jsonify({'error': str(e)}), 400
    finally:
        session.close()

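# Example request body for a 'reduce' model (the model_id and any parameter
# names other than init_pH/target_pH are hypothetical; the actual feature
# names depend on the dataset schema):
#
#   POST /predict
#   {
#       "model_id": 3,
#       "parameters": {"init_pH": 5.2, "target_pH": 6.5, "OM": 20.1}
#   }
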
# Compute score metrics for a given model; requires model_id
@bp.route('/score-model/<int:model_id>', methods=['POST'])
def score_model(model_id):
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        model_info = session.query(Models).filter(Models.ModelID == model_id).first()
        if not model_info:
            return jsonify({'error': 'Model not found'}), 404

        # Compute the score metrics
        score_metrics = calculate_model_score(model_info)

        # Update the stored scores (the cross-validation score is left untouched)
        model_info.Performance_score = score_metrics['r2']
        model_info.MAE = score_metrics['mae']
        model_info.RMSE = score_metrics['rmse']
        session.commit()

        return jsonify({
            'message': 'Model scored successfully',
            'r2_score': score_metrics['r2'],
            'mae': score_metrics['mae'],
            'rmse': score_metrics['rmse'],
        }), 200
    except Exception as e:
        logger.error('Failed to process model scoring:', exc_info=True)
        return jsonify({'error': str(e)}), 400
    finally:
        session.close()

@bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
def delete_dataset_endpoint(dataset_id):
    """
    API endpoint that deletes a dataset.

    @param dataset_id: ID of the dataset to delete
    @return: JSON response
    """
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the dataset
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if not dataset:
            return jsonify({'error': 'Dataset not found'}), 404

        # Check whether any model uses this dataset
        models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
        if models_using_dataset:
            models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
            return jsonify({
                'error': 'Cannot delete the dataset because the following models use it',
                'models': models_info
            }), 400

        # Delete the Excel file
        filename = f"dataset_{dataset.Dataset_ID}.xlsx"
        file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f'Failed to delete file: {str(e)}')
                return jsonify({'error': f'Failed to delete file: {str(e)}'}), 500

        # Drop the dataset's dynamic table
        table_name = f"dataset_{dataset.Dataset_ID}"
        session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))

        # Delete the dataset record
        session.delete(dataset)
        session.commit()

        return jsonify({
            'message': 'Dataset deleted successfully',
            'deleted_files': [filename]
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'Failed to delete dataset {dataset_id}:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()

@bp.route('/tables', methods=['GET'])
def list_tables():
    engine = db.engine  # use the engine from the db instance
    inspector = Inspector.from_engine(engine)  # create an Inspector
    table_names = inspector.get_table_names()  # fetch all table names
    return jsonify(table_names)  # return the list of table names as JSON

@bp.route('/models/<int:model_id>', methods=['GET'])
def get_model(model_id):
    """
    API endpoint that returns a single model's information.

    @param model_id: model ID
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if model:
            return jsonify({
                'ModelID': model.ModelID,
                'Model_name': model.Model_name,
                'Model_type': model.Model_type,
                'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'Description': model.Description,
                'Performance_score': float(model.Performance_score) if model.Performance_score is not None else None,
                'MAE': float(model.MAE) if model.MAE is not None else None,
                'CV_score': float(model.CV_score) if model.CV_score is not None else None,
                'RMSE': float(model.RMSE) if model.RMSE is not None else None,
                'Data_type': model.Data_type
            })
        else:
            return jsonify({'message': 'Model not found'}), 404

    except Exception as e:
        logger.error(f'Failed to fetch model info: {str(e)}')
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500

    finally:
        session.close()

@bp.route('/models', methods=['GET'])
def get_all_models():
    """
    API endpoint that returns information for all models.

    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        models = session.query(Models).all()
        if models:
            result = [
                {
                    'ModelID': model.ModelID,
                    'Model_name': model.Model_name,
                    'Model_type': model.Model_type,
                    'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'Description': model.Description,
                    'Performance_score': float(model.Performance_score) if model.Performance_score is not None else None,
                    'MAE': float(model.MAE) if model.MAE is not None else None,
                    'CV_score': float(model.CV_score) if model.CV_score is not None else None,
                    'RMSE': float(model.RMSE) if model.RMSE is not None else None,
                    'Data_type': model.Data_type
                }
                for model in models
            ]
            return jsonify(result)
        else:
            return jsonify({'message': 'No models found'}), 404

    except Exception as e:
        logger.error(f'Failed to fetch all model info: {str(e)}')
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500

    finally:
        session.close()

@bp.route('/model-parameters', methods=['GET'])
def get_all_model_parameters():
    """
    API endpoint that returns all model parameters.

    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        parameters = session.query(ModelParameters).all()
        if parameters:
            result = [
                {
                    'ParamID': param.ParamID,
                    'ModelID': param.ModelID,
                    'ParamName': param.ParamName,
                    'ParamValue': param.ParamValue
                }
                for param in parameters
            ]
            return jsonify(result)
        else:
            return jsonify({'message': 'No parameters found'}), 404

    except Exception as e:
        logger.error(f'Failed to fetch all model parameters: {str(e)}')
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500

    finally:
        session.close()

@bp.route('/models/<int:model_id>/parameters', methods=['GET'])
def get_model_parameters(model_id):
    try:
        model = Models.query.filter_by(ModelID=model_id).first()
        if model:
            # Collect all parameters attached to this model
            parameters = [
                {
                    'ParamID': param.ParamID,
                    'ParamName': param.ParamName,
                    'ParamValue': param.ParamValue
                }
                for param in model.parameters
            ]

            # Return the model's parameter information (attribute names follow
            # the Models schema used by the other endpoints above)
            return jsonify({
                'ModelID': model.ModelID,
                'ModelName': model.Model_name,
                'ModelType': model.Model_type,
                'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'Description': model.Description,
                'Parameters': parameters
            })
        else:
            return jsonify({'message': 'Model not found'}), 404
    except Exception as e:
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500

@bp.route('/train-model-async', methods=['POST'])
def train_model_async():
    """
    API endpoint that trains a model asynchronously.
    """
    try:
        data = request.get_json()

        # Read the parameters from the request
        model_type = data.get('model_type')
        model_name = data.get('model_name')
        model_description = data.get('model_description')
        data_type = data.get('data_type')
        dataset_id = data.get('dataset_id', None)

        # Validate the required parameters
        if not all([model_type, model_name, data_type]):
            return jsonify({
                'error': 'Missing required parameters'
            }), 400

        # If a dataset_id was supplied, verify that the dataset exists
        if dataset_id:
            Session = sessionmaker(bind=db.engine)
            session = Session()
            try:
                dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
                if not dataset:
                    return jsonify({
                        'error': f'Dataset with ID {dataset_id} not found'
                    }), 404
            finally:
                session.close()

        # Start the asynchronous task
        task = train_model_task.delay(
            model_type=model_type,
            model_name=model_name,
            model_description=model_description,
            data_type=data_type,
            dataset_id=dataset_id
        )

        # Return the task ID
        return jsonify({
            'task_id': task.id,
            'message': 'Model training started'
        }), 202

    except Exception as e:
        logger.error('Failed to start async training task:', exc_info=True)
        return jsonify({
            'error': str(e)
        }), 500

@bp.route('/task-status/<task_id>', methods=['GET'])
def get_task_status(task_id):
    """
    API endpoint that reports the status of an asynchronous task.
    """
    try:
        task = train_model_task.AsyncResult(task_id)

        if task.state == 'PENDING':
            response = {
                'state': task.state,
                'status': 'Task is waiting for execution'
            }
        elif task.state == 'FAILURE':
            response = {
                'state': task.state,
                'status': 'Task failed',
                'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
            }
        elif task.state == 'SUCCESS':
            response = {
                'state': task.state,
                'status': 'Task completed successfully',
                'result': task.get()
            }
        else:
            response = {
                'state': task.state,
                'status': 'Task is in progress'
            }

        return jsonify(response), 200

    except Exception as e:
        return jsonify({
            'error': str(e)
        }), 500

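# Minimal client-side polling sketch for the two endpoints above (hypothetical
# base URL and payload; uses the 'requests' package, which this module does
# not import):
#
#   task_id = requests.post(f'{BASE}/train-model-async', json=payload).json()['task_id']
#   while True:
#       status = requests.get(f'{BASE}/task-status/{task_id}').json()
#       if status['state'] in ('SUCCESS', 'FAILURE'):
#           break
#       time.sleep(5)
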
@bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
def delete_model_route(model_id):
    # Convert the URL parameter to a boolean
    delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'

    # Delegate to the underlying function
    return delete_model(model_id, delete_dataset=delete_dataset_param)


def delete_model(model_id, delete_dataset=False):
    """
    Delete the specified model.

    @param model_id: ID of the model to delete
    @query_param delete_dataset: boolean, whether to also delete the associated dataset (default False)
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        # Look up the model
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if not model:
            return jsonify({'error': 'Model not found'}), 404

        # Capture these before the ORM object is deleted and expired
        dataset_id = model.DatasetID
        model_path = model.ModelFilePath

        # 1. Delete the model file first, so a failure here leaves the database
        #    record untouched (a rollback after commit would be a no-op)
        try:
            if os.path.exists(model_path):
                os.remove(model_path)
            else:
                logger.warning(f'Model file not found: {model_path}')
        except OSError as e:
            logger.error(f'Failed to delete model file: {str(e)}')
            return jsonify({'error': f'Failed to delete model file: {str(e)}'}), 500

        # 2. Delete the model record
        session.delete(model)
        session.commit()

        # 3. Optionally delete the associated dataset
        if delete_dataset and dataset_id:
            try:
                dataset_response = delete_dataset_endpoint(dataset_id)
                if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
                    return jsonify({
                        'error': 'Failed to delete the associated dataset',
                        'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
                    }), 500
            except Exception as e:
                logger.error(f'Failed to delete the associated dataset: {str(e)}')
                return jsonify({'error': f'Failed to delete the associated dataset: {str(e)}'}), 500

        response_data = {
            'message': 'Model deleted successfully',
            'deleted_files': [model_path]
        }

        if delete_dataset:
            response_data['dataset_info'] = {
                'dataset_id': dataset_id,
                'message': 'Associated dataset deleted'
            }
        return jsonify(response_data), 200

    except Exception as e:
        session.rollback()
        logger.error(f'Failed to delete model {model_id}:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()

# API endpoint that clears the specified current dataset
@bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
def clear_dataset(data_type):
    """
    Clear the dataset of the given type and reset its row counter.

    @param data_type: dataset type ('reduce' or 'reflux')
    @return: JSON response
    """
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        # Pick the table for the dataset type
        if data_type == 'reduce':
            table = CurrentReduce
            table_name = 'current_reduce'
        elif data_type == 'reflux':
            table = CurrentReflux
            table_name = 'current_reflux'
        else:
            return jsonify({'error': 'Invalid dataset type'}), 400

        # Empty the table
        session.query(table).delete()

        # Reset the SQLite autoincrement counter for this table
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))

        session.commit()

        return jsonify({'message': f'{data_type} dataset cleared and counter reset'}), 200

    except Exception as e:
        session.rollback()
        return jsonify({'error': str(e)}), 500

    finally:
        session.close()

@bp.route('/update-threshold', methods=['POST'])
def update_threshold():
    """
    API endpoint that updates the training threshold.

    @body_param threshold: new threshold value (positive number)
    @body_param data_type: data type ('reduce' or 'reflux')
    @return: JSON response
    """
    try:
        data = request.get_json()
        new_threshold = data.get('threshold')
        data_type = data.get('data_type')

        # Validate the new threshold
        if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
            return jsonify({
                'error': 'Invalid threshold; it must be a positive number'
            }), 400

        # Validate the data type
        if data_type not in ['reduce', 'reflux']:
            return jsonify({
                'error': 'Invalid data type; must be "reduce" or "reflux"'
            }), 400

        # Update the running app's threshold configuration
        if data_type == 'reduce':
            current_app.config['THRESHOLD_REDUCE'] = int(new_threshold)
        else:  # reflux
            current_app.config['THRESHOLD_REFLUX'] = int(new_threshold)

        return jsonify({
            'success': True,
            'message': f'{data_type} threshold updated to {new_threshold}',
            'data_type': data_type,
            'new_threshold': new_threshold
        })

    except Exception as e:
        logger.error(f"Failed to update threshold: {str(e)}")
        return jsonify({
            'error': f'Failed to update threshold: {str(e)}'
        }), 500

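# Example request body (hypothetical value; note that this updates the running
# process's config only and is not persisted across restarts):
#
#   POST /update-threshold
#   {"threshold": 100, "data_type": "reduce"}
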
@bp.route('/get-threshold', methods=['GET'])
def get_threshold():
    """
    API endpoint that returns the current training threshold(s).

    @query_param data_type: optional data type ('reduce' or 'reflux')
    @return: JSON response
    """
    try:
        data_type = request.args.get('data_type')

        if data_type and data_type not in ['reduce', 'reflux']:
            return jsonify({
                'error': 'Invalid data type; must be "reduce" or "reflux"'
            }), 400

        response = {}

        if data_type == 'reduce' or data_type is None:
            response['reduce'] = {
                'current_threshold': current_app.config['THRESHOLD_REDUCE'],
                'default_threshold': current_app.config['DEFAULT_THRESHOLD_REDUCE']
            }

        if data_type == 'reflux' or data_type is None:
            response['reflux'] = {
                'current_threshold': current_app.config['THRESHOLD_REFLUX'],
                'default_threshold': current_app.config['DEFAULT_THRESHOLD_REFLUX']
            }

        return jsonify(response)

    except Exception as e:
        logger.error(f"Failed to fetch threshold: {str(e)}")
        return jsonify({
            'error': f'Failed to fetch threshold: {str(e)}'
        }), 500

@bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
def set_current_dataset(data_type, dataset_id):
    """
    Make the specified dataset the current dataset.

    @param data_type: dataset type ('reduce' or 'reflux')
    @param dataset_id: ID of the dataset to promote
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        # Verify the dataset exists and its type matches
        dataset = session.query(Datasets)\
            .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
            .first()

        if not dataset:
            return jsonify({
                'error': f'No dataset found with ID {dataset_id} and type {data_type}'
            }), 404

        # Pick the table for the data type
        if data_type == 'reduce':
            table = CurrentReduce
            table_name = 'current_reduce'
        elif data_type == 'reflux':
            table = CurrentReflux
            table_name = 'current_reflux'
        else:
            return jsonify({'error': 'Invalid dataset type'}), 400

        # Empty the current table
        session.query(table).delete()

        # Reset the SQLite autoincrement counter
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))

        # Copy the rows from the chosen dataset into the current table
        # (assumes both tables share the same column order)
        dataset_table_name = f"dataset_{dataset_id}"
        copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
        session.execute(copy_sql)

        session.commit()

        return jsonify({
            'message': f'{data_type} current dataset set to dataset ID: {dataset_id}',
            'dataset_id': dataset_id,
            'dataset_name': dataset.Dataset_name,
            'row_count': dataset.Row_count
        }), 200

    except Exception as e:
        session.rollback()
        logger.error(f'Failed to set current dataset: {str(e)}')
        return jsonify({'error': str(e)}), 500

    finally:
        session.close()

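# Example request (hypothetical host, port, and dataset ID):
#
#   curl -X POST http://localhost:5000/set-current-dataset/reduce/12
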
@bp.route('/get-model-history/<string:data_type>', methods=['GET'])
def get_model_history(data_type):
    """
    API endpoint that returns the model training history.

    @param data_type: dataset type ('reduce' or 'reflux')
    @return: JSON response containing time-series model performance data
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        # Fetch all auto-generated datasets, ordered by upload time
        datasets = session.query(Datasets).filter(
            Datasets.Dataset_type == data_type,
            Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
        ).order_by(Datasets.Uploaded_at).all()

        history_data = []
        for dataset in datasets:
            # Find the matching auto-trained model
            model = session.query(Models).filter(
                Models.DatasetID == dataset.Dataset_ID,
                Models.Model_name.like(f'auto_trained_{data_type}_%')
            ).first()

            if model and model.Performance_score is not None:
                # Use the stored timestamp as-is (same timezone as created_at)
                created_at = model.Created_at.isoformat() if model.Created_at else None

                history_data.append({
                    'dataset_id': dataset.Dataset_ID,
                    'row_count': dataset.Row_count,
                    'model_id': model.ModelID,
                    'model_name': model.Model_name,
                    'performance_score': float(model.Performance_score),
                    'timestamp': created_at
                })

        # Sort by timestamp
        history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')

        # Split the metrics into separate series for easy front-end plotting
        response_data = {
            'data_type': data_type,
            'timestamps': [item['timestamp'] for item in history_data],
            'row_counts': [item['row_count'] for item in history_data],
            'performance_scores': [item['performance_score'] for item in history_data],
            'model_details': history_data  # full records for the front end
        }

        return jsonify(response_data), 200

    except Exception as e:
        logger.error(f'Failed to fetch model history: {str(e)}', exc_info=True)
        return jsonify({'error': str(e)}), 500

    finally:
        session.close()

@bp.route('/batch-delete-datasets', methods=['POST'])
def batch_delete_datasets():
    """
    API endpoint that deletes datasets in batch.

    @body_param dataset_ids: list of dataset IDs to delete
    @return: JSON response
    """
    try:
        data = request.get_json()
        dataset_ids = data.get('dataset_ids', [])

        if not dataset_ids:
            return jsonify({'error': 'No dataset ID list provided'}), 400

        results = {
            'success': [],
            'failed': [],
            'protected': []  # datasets still referenced by models
        }

        for dataset_id in dataset_ids:
            try:
                # Reuse the single-delete endpoint
                response = delete_dataset_endpoint(dataset_id)

                # Interpret the response
                if response[1] == 200:
                    results['success'].append(dataset_id)
                elif response[1] == 400 and 'models' in response[0].json:
                    # The dataset is protected by one or more models
                    results['protected'].append({
                        'id': dataset_id,
                        'models': response[0].json['models']
                    })
                else:
                    results['failed'].append({
                        'id': dataset_id,
                        'reason': response[0].json.get('error', 'Deletion failed')
                    })

            except Exception as e:
                logger.error(f'Failed to delete dataset {dataset_id}: {str(e)}')
                results['failed'].append({
                    'id': dataset_id,
                    'reason': str(e)
                })

        # Build the response message
        message = f"Deleted {len(results['success'])} dataset(s)"
        if results['protected']:
            message += f", {len(results['protected'])} dataset(s) protected"
        if results['failed']:
            message += f", {len(results['failed'])} dataset(s) failed"

        return jsonify({
            'message': message,
            'results': results
        }), 200

    except Exception as e:
        logger.error(f'Batch dataset deletion failed: {str(e)}')
        return jsonify({'error': str(e)}), 500

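# Example request body (hypothetical IDs):
#
#   POST /batch-delete-datasets
#   {"dataset_ids": [4, 7, 9]}
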
@bp.route('/batch-delete-models', methods=['POST'])
def batch_delete_models():
    """
    API endpoint that deletes models in batch.

    @body_param model_ids: list of model IDs to delete
    @query_param delete_datasets: boolean, whether to also delete the associated datasets (default False)
    @return: JSON response
    """
    try:
        data = request.get_json()
        model_ids = data.get('model_ids', [])
        delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'

        if not model_ids:
            return jsonify({'error': 'No model ID list provided'}), 400

        results = {
            'success': [],
            'failed': [],
            'datasets_deleted': []  # dataset IDs removed when delete_datasets is true
        }

        for model_id in model_ids:
            try:
                # Reuse the single-delete function
                response = delete_model(model_id, delete_dataset=delete_datasets)

                # Interpret the response
                if response[1] == 200:
                    results['success'].append(model_id)
                    # Record the dataset ID when the associated dataset was deleted
                    if 'dataset_info' in response[0].json:
                        results['datasets_deleted'].append(
                            response[0].json['dataset_info']['dataset_id']
                        )
                else:
                    results['failed'].append({
                        'id': model_id,
                        'reason': response[0].json.get('error', 'Deletion failed')
                    })

            except Exception as e:
                logger.error(f'Failed to delete model {model_id}: {str(e)}')
                results['failed'].append({
                    'id': model_id,
                    'reason': str(e)
                })

        # Build the response message
        message = f"Deleted {len(results['success'])} model(s)"
        if results['datasets_deleted']:
            message += f", {len(results['datasets_deleted'])} associated dataset(s)"
        if results['failed']:
            message += f", {len(results['failed'])} model(s) failed"

        return jsonify({
            'message': message,
            'results': results
        }), 200

    except Exception as e:
        logger.error(f'Batch model deletion failed: {str(e)}')
        return jsonify({'error': str(e)}), 500

@bp.route('/kriging_interpolation', methods=['POST'])
def kriging_interpolation():
    try:
        data = request.get_json()
        required = ['file_name', 'emission_column', 'points']
        if not all(k in data for k in required):
            return jsonify({"error": "Missing parameters"}), 400

        # Validate the point format (each point is an [x, y] numeric pair)
        points = data['points']
        if not all(
            len(pt) == 2 and isinstance(pt[0], (int, float)) and isinstance(pt[1], (int, float))
            for pt in points
        ):
            return jsonify({"error": "Invalid points format"}), 400

        result = create_kriging(
            data['file_name'],
            data['emission_column'],
            data['points']
        )
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

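# Example request body (hypothetical file and column names; points are the
# [x, y] coordinate pairs validated above):
#
#   POST /kriging_interpolation
#   {
#       "file_name": "emissions.xlsx",
#       "emission_column": "CO2_flux",
#       "points": [[116.3, 39.9], [116.4, 40.0]]
#   }
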
@bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
def get_model_scatter_data(model_id):
    """
    Return scatter-plot data (true vs. predicted values) for the given model.

    @param model_id: model ID
    @return: JSON response containing the scatter-plot data
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()

    try:
        # Look up the model
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if not model:
            return jsonify({'error': 'Model not found'}), 404

        # Load the trained model from disk
        with open(model.ModelFilePath, 'rb') as f:
            ML_model = pickle.load(f)

        # Load the test data for the model's data type
        if model.Data_type == 'reflux':
            X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
            Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
        elif model.Data_type == 'reduce':
            X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
            Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
        else:
            return jsonify({'error': 'Unsupported data type'}), 400

        # Predict on the test set
        y_pred = ML_model.predict(X_test)

        # Build the scatter-plot pairs [true, predicted]
        scatter_data = [
            [float(true), float(pred)]
            for true, pred in zip(Y_test.iloc[:, 0], y_pred)
        ]

        # Compute the R² score
        r2 = r2_score(Y_test, y_pred)

        # Data range, used by the front end to draw the diagonal
        y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
        y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))

        return jsonify({
            'scatter_data': scatter_data,
            'r2_score': float(r2),
            'y_range': [float(y_min), float(y_max)],
            'model_name': model.Model_name,
            'model_type': model.Model_type
        }), 200

    except Exception as e:
        logger.error(f'Failed to fetch model scatter data: {str(e)}', exc_info=True)
        return jsonify({'error': f'Failed to fetch data: {str(e)}'}), 500

    finally:
        session.close()

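# Example response shape (hypothetical values; the keys mirror the jsonify
# call above):
#
#   {
#       "scatter_data": [[5.1, 5.0], [6.2, 6.4]],
#       "r2_score": 0.91,
#       "y_range": [4.8, 6.6],
#       "model_name": "auto_trained_reduce_100",
#       "model_type": "RandomForest"
#   }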