- import sqlite3
- from flask import current_app
- from werkzeug.security import generate_password_hash, check_password_hash
- from flask import Blueprint, request, jsonify, current_app as app
- from .model import predict, train_and_save_model, calculate_model_score
- import pandas as pd
- from . import db  # db instance provided by the app package
- from sqlalchemy import inspect
- from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
- import os
- from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
- clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
- predict_to_Q, Q_to_t_ha
- from sqlalchemy.orm import sessionmaker
- import logging
- from sqlalchemy import text, select, MetaData, Table, func
- from .tasks import train_model_task
- # Configure logging
- logging.basicConfig(level=logging.DEBUG)
- logger = logging.getLogger(__name__)
- # Create a Blueprint so routes are kept separate from the app factory
- bp = Blueprint('routes', __name__)
- # Password hashing helper
- def hash_password(password):
- return generate_password_hash(password)
- def get_db():
- """Return a raw sqlite3 connection to the configured database."""
- return sqlite3.connect(app.config['DATABASE'])
- # Helper that checks whether the dataset size has crossed a new threshold and, if so, triggers training
- def check_and_trigger_training(session, dataset_type, dataset_df):
- """
- Check whether the current dataset size has crossed a new threshold point and trigger training if it has.
- 
- Args:
- session: database session
- dataset_type: dataset type ('reduce' or 'reflux')
- dataset_df: dataset DataFrame
- 
- Returns:
- tuple: (whether training was triggered, task ID)
- """
- try:
- # Select the table based on the dataset type
- table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
- 
- # Current number of records in the table
- current_count = session.query(func.count()).select_from(table).scalar()
- 
- # Number of newly added records (length of the uploaded DataFrame, passed in by the caller)
- new_records = len(dataset_df)
- 
- # Record count before the new data was inserted
- previous_count = current_count - new_records
- 
- # Threshold for auto-training
- THRESHOLD = current_app.config['THRESHOLD']
- 
- # Last threshold point reached before the insert
- last_threshold = previous_count // THRESHOLD * THRESHOLD
- # Threshold point reached after the insert
- current_threshold = current_count // THRESHOLD * THRESHOLD
- 
- # Trigger training only if a new threshold point has been crossed
- if current_threshold > last_threshold and current_count >= THRESHOLD:
- # Kick off the asynchronous training task
- task = train_model_task.delay(
- model_type=current_app.config['DEFAULT_MODEL_TYPE'],
- model_name=f'auto_trained_{dataset_type}_{current_threshold}',
- model_description=f'Auto trained model at {current_threshold} records threshold',
- data_type=dataset_type
- )
- return True, task.id
- 
- return False, None
- 
- except Exception as e:
- logging.error(f"Failed to check and trigger training: {str(e)}")
- return False, None
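- # Worked example of the threshold logic above (illustrative; assumes THRESHOLD = 100):
- #   previous_count = 95, new_records = 10  ->  current_count = 105
- #   last_threshold    = 95  // 100 * 100 = 0
- #   current_threshold = 105 // 100 * 100 = 100
- #   Since 100 > 0 and 105 >= 100, an auto-training task is queued for the 100-record mark;
- #   uploading 10 more rows (115 total) would not trigger again until 200 is crossed.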
- @bp.route('/upload-dataset', methods=['POST'])
- def upload_dataset():
- # Create the session before the try block so the finally clause can always close it
- Session = sessionmaker(bind=db.engine)
- session = Session()
- try:
- if 'file' not in request.files:
- return jsonify({'error': 'No file part'}), 400
- file = request.files['file']
- if file.filename == '' or not allowed_file(file.filename):
- return jsonify({'error': 'No selected file or invalid file type'}), 400
- dataset_name = request.form.get('dataset_name')
- dataset_description = request.form.get('dataset_description', 'No description provided')
- dataset_type = request.form.get('dataset_type')
- if not dataset_type:
- return jsonify({'error': 'Dataset type is required'}), 400
- new_dataset = Datasets(
- Dataset_name=dataset_name,
- Dataset_description=dataset_description,
- Row_count=0,
- Status='Datasets_upgraded',
- Dataset_type=dataset_type
- )
- session.add(new_dataset)
- session.commit()
- unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
- upload_folder = current_app.config['UPLOAD_FOLDER']
- file_path = os.path.join(upload_folder, unique_filename)
- file.save(file_path)
- dataset_df = pd.read_excel(file_path)
- new_dataset.Row_count = len(dataset_df)
- new_dataset.Status = 'excel_file_saved success'
- session.commit()
- # Normalize and rename columns to match the model schema
- dataset_df = clean_column_names(dataset_df)
- dataset_df = rename_columns_for_model(dataset_df, dataset_type)
- column_types = infer_column_types(dataset_df)
- dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
- insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
- # Choose the existing table to append to based on dataset_type
- if dataset_type == 'reduce':
- insert_data_into_existing_table(session, dataset_df, CurrentReduce)
- elif dataset_type == 'reflux':
- insert_data_into_existing_table(session, dataset_df, CurrentReflux)
- session.commit()
- # After the data has been inserted, check whether auto-training should be triggered
- training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
- response_data = {
- 'message': f'Dataset {dataset_name} uploaded successfully!',
- 'dataset_id': new_dataset.Dataset_ID,
- 'filename': unique_filename,
- 'training_triggered': training_triggered
- }
-
- if training_triggered:
- response_data['task_id'] = task_id
- response_data['message'] += ' Auto-training has been triggered.'
- return jsonify(response_data), 201
- except Exception as e:
- session.rollback()
- logging.error('Failed to process the dataset upload:', exc_info=True)
- return jsonify({'error': str(e)}), 500
- finally:
- session.close()
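- # Illustrative client call for this endpoint (a sketch; the host, port, and example
- # file name are assumptions, not part of this codebase):
- #   curl -X POST http://localhost:5000/upload-dataset \
- #        -F "file=@example_dataset.xlsx" \
- #        -F "dataset_name=demo_reflux" \
- #        -F "dataset_description=sample upload" \
- #        -F "dataset_type=reflux"
- # A 201 response echoes the dataset_id and, if a threshold was crossed, the task_id of the auto-training job.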
- @bp.route('/train-and-save-model', methods=['POST'])
- def train_and_save_model_endpoint():
- # Create a sessionmaker instance
- Session = sessionmaker(bind=db.engine)
- session = Session()
- data = request.get_json()
- # Parse parameters from the request body
- model_type = data.get('model_type')
- model_name = data.get('model_name')
- model_description = data.get('model_description')
- data_type = data.get('data_type')
- dataset_id = data.get('dataset_id', None)  # defaults to None when not provided
- try:
- # Train the model and persist it
- result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
-
- model_id = result[1] if result else None
-
- # Compute the model score
- if model_id:
- model_info = session.query(Models).filter(Models.ModelID == model_id).first()
- if model_info:
- score = calculate_model_score(model_info)
- # Store the score on the model record
- model_info.Performance_score = score
- session.commit()
- result = {'model_id': model_id, 'model_score': score}
- # Return a success response
- return jsonify({
- 'message': 'Model trained and saved successfully',
- 'result': result
- }), 200
- except Exception as e:
- session.rollback()
- logging.error('Failed to process the model training:', exc_info=True)
- return jsonify({
- 'error': 'Failed to train and save model',
- 'message': str(e)
- }), 500
- finally:
- session.close()
- @bp.route('/predict', methods=['POST'])
- def predict_route():
- # Create a sessionmaker instance
- Session = sessionmaker(bind=db.engine)
- session = Session()
- try:
- data = request.get_json()
- model_id = data.get('model_id')  # ID of the model to use
- parameters = data.get('parameters', {})  # all input variables
- # Look up the model's Data_type from its model_id
- model_info = session.query(Models).filter(Models.ModelID == model_id).first()
- if not model_info:
- return jsonify({'error': 'Model not found'}), 404
- data_type = model_info.Data_type
- input_data = pd.DataFrame([parameters])  # convert the parameters into a single-row DataFrame
- # For 'reduce' models, target_pH is not fed to the model itself
- if data_type == 'reduce':
- # Read the supplied init_pH and target_pH values
- init_ph = float(parameters.get('init_pH', 0.0))  # default 0.0 guards against None
- target_ph = float(parameters.get('target_pH', 0.0))  # default 0.0 guards against None
- # Drop the 'target_pH' column from the model input
- input_data = input_data.drop('target_pH', axis=1, errors='ignore')  # errors='ignore' tolerates a missing column
- input_data_rename = rename_columns_for_model_predict(input_data, data_type)  # rename columns to match the model's fields
- predictions = predict(session, input_data_rename, model_id)  # run the prediction
- if data_type == 'reduce':
- predictions = predictions[0]
- # Convert the raw prediction to Q, then to t/ha
- Q = predict_to_Q(predictions, init_ph, target_ph)
- predictions = Q_to_t_ha(Q)
- logger.debug(predictions)
- return jsonify({'result': predictions}), 200
- except Exception as e:
- logging.error('Failed to predict:', exc_info=True)
- return jsonify({'error': str(e)}), 400
- finally:
- session.close()
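- # Example request body for /predict (illustrative; the model_id and any feature columns
- # beyond init_pH/target_pH depend on the trained model and are assumptions):
- #   {
- #     "model_id": 1,
- #     "parameters": {
- #       "init_pH": 5.2,
- #       "target_pH": 6.5,
- #       ... remaining feature columns expected by the model ...
- #     }
- #   }
- # For a 'reduce' model the response is the prediction converted via predict_to_Q and Q_to_t_ha (t/ha).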
- # Compute the Performance_score for the model identified by model_id
- @bp.route('/score-model/<int:model_id>', methods=['POST'])
- def score_model(model_id):
- # Create a sessionmaker instance
- Session = sessionmaker(bind=db.engine)
- session = Session()
- try:
- model_info = session.query(Models).filter(Models.ModelID == model_id).first()
- if not model_info:
- return jsonify({'error': 'Model not found'}), 404
- # Compute the model score
- score = calculate_model_score(model_info)
- # Store the score on the model record
- model_info.Performance_score = score
- session.commit()
- return jsonify({'message': 'Model scored successfully', 'score': score}), 200
- except Exception as e:
- logging.error('Failed to score the model:', exc_info=True)
- return jsonify({'error': str(e)}), 400
- finally:
- session.close()
- @bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
- def delete_dataset_endpoint(dataset_id):
- """
- 删除数据集的API接口
-
- @param dataset_id: 要删除的数据集ID
- @return: JSON响应
- """
- # 创建 sessionmaker 实例
- Session = sessionmaker(bind=db.engine)
- session = Session()
- try:
- # 查询数据集
- dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
- if not dataset:
- return jsonify({'error': '未找到数据集'}), 404
- # 检查是否有模型使用了该数据集
- models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
- if models_using_dataset:
- models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
- return jsonify({
- 'error': '无法删除数据集,因为以下模型正在使用它',
- 'models': models_info
- }), 400
- # 删除Excel文件
- filename = f"dataset_{dataset.Dataset_ID}.xlsx"
- file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
- if os.path.exists(file_path):
- try:
- os.remove(file_path)
- except OSError as e:
- logger.error(f'删除文件失败: {str(e)}')
- return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
- # 删除数据表
- table_name = f"dataset_{dataset.Dataset_ID}"
- session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
- # 删除数据集记录
- session.delete(dataset)
- session.commit()
- return jsonify({
- 'message': '数据集删除成功',
- 'deleted_files': [filename]
- }), 200
- except Exception as e:
- session.rollback()
- logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
- return jsonify({'error': str(e)}), 500
- finally:
- session.close()
- @bp.route('/tables', methods=['GET'])
- def list_tables():
- engine = db.engine  # engine from the shared db instance
- inspector = inspect(engine)  # create an Inspector for the engine
- table_names = inspector.get_table_names()  # all table names
- return jsonify(table_names)  # return the list of table names as JSON
- @bp.route('/models/<int:model_id>', methods=['GET'])
- def get_model(model_id):
- try:
- model = Models.query.filter_by(ModelID=model_id).first()
- if model:
- return jsonify({
- 'ModelID': model.ModelID,
- 'ModelName': model.ModelName,
- 'ModelType': model.ModelType,
- 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
- 'Description': model.Description
- })
- else:
- return jsonify({'message': 'Model not found'}), 404
- except Exception as e:
- return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
- @bp.route('/models', methods=['GET'])
- def get_all_models():
- try:
- models = Models.query.all()  # fetch all model records
- if models:
- result = [
- {
- 'ModelID': model.ModelID,
- 'ModelName': model.ModelName,
- 'ModelType': model.ModelType,
- 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
- 'Description': model.Description
- }
- for model in models
- ]
- return jsonify(result)
- else:
- return jsonify({'message': 'No models found'}), 404
- except Exception as e:
- return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
- @bp.route('/model-parameters', methods=['GET'])
- def get_all_model_parameters():
- try:
- parameters = ModelParameters.query.all()  # fetch all parameter records
- if parameters:
- result = [
- {
- 'ParamID': param.ParamID,
- 'ModelID': param.ModelID,
- 'ParamName': param.ParamName,
- 'ParamValue': param.ParamValue
- }
- for param in parameters
- ]
- return jsonify(result)
- else:
- return jsonify({'message': 'No parameters found'}), 404
- except Exception as e:
- return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
- @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
- def get_model_parameters(model_id):
- try:
- model = Models.query.filter_by(ModelID=model_id).first()
- if model:
- # Collect all parameters for this model
- parameters = [
- {
- 'ParamID': param.ParamID,
- 'ParamName': param.ParamName,
- 'ParamValue': param.ParamValue
- }
- for param in model.parameters
- ]
-
- # Return the model together with its parameters
- return jsonify({
- 'ModelID': model.ModelID,
- 'ModelName': model.ModelName,
- 'ModelType': model.ModelType,
- 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
- 'Description': model.Description,
- 'Parameters': parameters
- })
- else:
- return jsonify({'message': 'Model not found'}), 404
- except Exception as e:
- return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
- # API endpoint for adding a database record
- @bp.route('/add_item', methods=['POST'])
- def add_item():
- """
- Accepts a JSON body containing a table name and the data to insert,
- and attempts to insert that data into the given table.
- :return:
- """
- # Use a raw sqlite3 connection here; the SQLAlchemy `db` object has no cursor() method
- conn = get_db()
- try:
- # Make sure the request body is JSON
- data = request.get_json()
- if not data:
- raise ValueError("No JSON data provided")
- table_name = data.get('table')
- item_data = data.get('item')
- if not table_name or not item_data:
- return jsonify({'error': 'Missing table name or item data'}), 400
- cur = conn.cursor()
- # Build the SQL statement dynamically
- columns = ', '.join(item_data.keys())
- placeholders = ', '.join(['?'] * len(item_data))
- sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
- cur.execute(sql, tuple(item_data.values()))
- conn.commit()
- # Return a detailed success response
- return jsonify({'success': True, 'message': 'Item added successfully'}), 201
- except ValueError as e:
- return jsonify({'error': str(e)}), 400
- except KeyError as e:
- return jsonify({'error': f'Missing data field: {e}'}), 400
- except sqlite3.IntegrityError as e:
- # Handle database integrity errors such as unique-constraint violations
- return jsonify({'error': 'Database integrity error', 'details': str(e)}), 409
- except sqlite3.Error as e:
- # Handle other database errors
- return jsonify({'error': 'Database error', 'details': str(e)}), 500
- finally:
- conn.close()
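- # Example request body for /add_item (illustrative; the column values are placeholders):
- #   {"table": "users", "item": {"name": "alice", "password": "<hashed password>"}}
- # The keys of "item" become the column list and its values the bound parameters.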
- # API endpoint for deleting a database record
- @bp.route('/delete_item', methods=['POST'])
- def delete_item():
- data = request.get_json()
- table_name = data.get('table')
- condition = data.get('condition')
- # Check that both the table name and the condition were provided
- if not table_name or not condition:
- return jsonify({
- "success": False,
- "message": "Missing table name or condition parameter"
- }), 400
- # Try to split the condition string into a key and a value
- try:
- key, value = condition.split('=')
- except ValueError:
- return jsonify({
- "success": False,
- "message": "Invalid condition format, expected 'key=value'"
- }), 400
- # Use a raw sqlite3 connection; the SQLAlchemy `db` object has no cursor() method
- conn = get_db()
- cur = conn.cursor()
- try:
- # Execute the delete
- cur.execute(f"DELETE FROM {table_name} WHERE {key} = ?", (value,))
- conn.commit()
- # No error occurred, return a success response
- return jsonify({
- "success": True,
- "message": "Record deleted successfully"
- }), 200
- except sqlite3.Error as e:
- # An error occurred, roll back the transaction
- conn.rollback()
- # Return a failure response including the error message
- return jsonify({
- "success": False,
- "message": f"Delete failed: {e}"
- }), 400
- finally:
- conn.close()
- # API endpoint for updating a database record
- @bp.route('/update_item', methods=['PUT'])
- def update_record():
- data = request.get_json()
- # Check that the required data was provided
- if not data or 'table' not in data or 'item' not in data:
- return jsonify({
- "success": False,
- "message": "Incomplete request data"
- }), 400
- table_name = data['table']
- item = data['item']
- # The first key of item is assumed to be the ID column
- if not item or next(iter(item.keys())) is None:
- return jsonify({
- "success": False,
- "message": "Record data is empty"
- }), 400
- # Separate the ID from the other field values
- id_key = next(iter(item.keys()))
- record_id = item[id_key]
- updates = {key: value for key, value in item.items() if key != id_key}  # everything except the ID
- # Use a raw sqlite3 connection; the SQLAlchemy `db` object has no cursor() method
- conn = get_db()
- cur = conn.cursor()
- try:
- record_id = int(record_id)  # make sure the ID is an integer
- except ValueError:
- return jsonify({
- "success": False,
- "message": "ID must be an integer"
- }), 400
- # Build the parameter list: updated values followed by the ID
- parameters = list(updates.values()) + [record_id]
- # Build and execute the update statement
- set_clause = ','.join([f"{k} = ?" for k in updates.keys()])
- sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = ?"
- try:
- cur.execute(sql, parameters)
- conn.commit()
- if cur.rowcount == 0:
- return jsonify({
- "success": False,
- "message": "No record found to update"
- }), 404
- return jsonify({
- "success": True,
- "message": "Data updated successfully"
- }), 200
- except sqlite3.Error as e:
- conn.rollback()
- return jsonify({
- "success": False,
- "message": f"Update failed: {e}"
- }), 400
- finally:
- conn.close()
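- # Example request body for /update_item (illustrative; table and columns are placeholders).
- # The first key of "item" is treated as the ID column used in the WHERE clause,
- # and every other key becomes part of the SET clause:
- #   {"table": "users", "item": {"id": 3, "name": "alice_renamed"}}
- #   -> UPDATE users SET name = ? WHERE id = ?   with parameters ("alice_renamed", 3)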
- # API endpoint for querying a database record
- @bp.route('/search/record', methods=['GET'])
- def sql_search():
- """
- Accepts a JSON body containing a table name and the ID to look up,
- queries the record with that ID and returns the result.
- :return:
- """
- try:
- data = request.get_json()
- # Table name
- sql_table = data['table']
- # ID to search for
- Id = data['id']
- # Connect to the database (raw sqlite3; the SQLAlchemy `db` object has no cursor() method)
- conn = get_db()
- cur = conn.cursor()
- # Build the query
- sql = f"SELECT * FROM {sql_table} WHERE id = ?"
- # Execute the query
- cur.execute(sql, (Id,))
- # Fetch the results
- rows = cur.fetchall()
- column_names = [desc[0] for desc in cur.description]
- # Check whether anything was found
- if not rows:
- return jsonify({'error': 'No matching data found.'}), 400
- # Build the response data
- results = []
- for row in rows:
- result = {column_names[i]: row[i] for i in range(len(row))}
- results.append(result)
- # Close the cursor and the database connection
- cur.close()
- conn.close()
- # Return the JSON response
- return jsonify(results), 200
- except sqlite3.Error as e:
- # Return the error message if a database error occurred
- return jsonify({'error': str(e)}), 400
- except KeyError as e:
- # Return an error if a required key is missing from the request data
- return jsonify({'error': f'Missing required data field: {e}'}), 400
- # API endpoint that returns the contents of a table for display
- @bp.route('/table', methods=['POST'])
- def get_table():
- data = request.get_json()
- table_name = data.get('table')
- if not table_name:
- return jsonify({'error': 'Table name is required'}), 400
- # Create the session before the try block so the finally clause can always close it
- Session = sessionmaker(bind=db.engine)
- session = Session()
- try:
- # Reflect the table metadata dynamically
- metadata = MetaData()
- table = Table(table_name, metadata, autoload_with=db.engine)
- # Query all records from the table
- query = select(table)
- result = session.execute(query).fetchall()
- # Convert the result into a list of dicts
- rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
- # Column names
- headers = [column.name for column in table.columns]
- return jsonify(rows=rows, headers=headers), 200
- except Exception as e:
- return jsonify({'error': str(e)}), 400
- finally:
- # Close the session
- session.close()
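- # Example request body for /table (illustrative): {"table": "current_reduce"}
- # The response contains "headers" (column names) and "rows" (one dict per record).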
- @bp.route('/train-model-async', methods=['POST'])
- def train_model_async():
- """
- 异步训练模型的API接口
- """
- try:
- data = request.get_json()
-
- # 从请求中获取参数
- model_type = data.get('model_type')
- model_name = data.get('model_name')
- model_description = data.get('model_description')
- data_type = data.get('data_type')
- dataset_id = data.get('dataset_id', None)
-
- # 验证必要参数
- if not all([model_type, model_name, data_type]):
- return jsonify({
- 'error': 'Missing required parameters'
- }), 400
-
- # 启动异步任务
- task = train_model_task.delay(
- model_type,
- model_name,
- model_description,
- data_type,
- dataset_id
- )
-
- # 返回任务ID
- return jsonify({
- 'task_id': task.id,
- 'message': 'Model training started'
- }), 202
-
- except Exception as e:
- logging.error('Failed to start async training task:', exc_info=True)
- return jsonify({
- 'error': str(e)
- }), 500
- @bp.route('/task-status/<task_id>', methods=['GET'])
- def get_task_status(task_id):
- """
- 获取异步任务状态的API接口
- """
- try:
- task = train_model_task.AsyncResult(task_id)
-
- if task.state == 'PENDING':
- response = {
- 'state': task.state,
- 'status': 'Task is waiting for execution'
- }
- elif task.state == 'FAILURE':
- response = {
- 'state': task.state,
- 'status': 'Task failed',
- 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
- }
- elif task.state == 'SUCCESS':
- response = {
- 'state': task.state,
- 'status': 'Task completed successfully',
- 'result': task.get()
- }
- else:
- response = {
- 'state': task.state,
- 'status': 'Task is in progress'
- }
-
- return jsonify(response), 200
-
- except Exception as e:
- return jsonify({
- 'error': str(e)
- }), 500
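- # Illustrative workflow for the two endpoints above (host, port, and payload values are
- # assumptions): POST /train-model-async with
- #   {"model_type": "rf", "model_name": "demo", "data_type": "reflux"}
- # returns {"task_id": "..."} with HTTP 202; the client then polls
- #   GET /task-status/<task_id>
- # until "state" becomes SUCCESS (result included) or FAILURE (error included).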
- @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
- def delete_model(model_id):
- """
- 删除指定模型的API接口
-
- @param model_id: 要删除的模型ID
- @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
- @return: JSON响应
- """
- Session = sessionmaker(bind=db.engine)
- session = Session()
-
- try:
- # Look up the model record
- model = session.query(Models).filter_by(ModelID=model_id).first()
- if not model:
- return jsonify({'error': 'Model not found'}), 404
- 
- dataset_id = model.DatasetID
- 
- # 1. Delete the model record first (committed immediately so a linked dataset can be removed below)
- session.delete(model)
- session.commit()
- 
- # 2. Delete the model file
- model_file = f"rf_model_{model_id}.pkl"
- model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
- if os.path.exists(model_path):
- try:
- os.remove(model_path)
- except OSError as e:
- # Report the failure if the file cannot be removed
- session.rollback()
- logger.error(f'Failed to delete model file: {str(e)}')
- return jsonify({'error': f'Failed to delete model file: {str(e)}'}), 500
- # 3. Delete the associated dataset if requested
- delete_dataset = request.args.get('delete_dataset', 'false').lower() == 'true'
- if delete_dataset and dataset_id:
- try:
- dataset_response = delete_dataset_endpoint(dataset_id)
- if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
- # Deleting the dataset failed; roll back any pending changes
- session.rollback()
- return jsonify({
- 'error': 'Failed to delete the associated dataset',
- 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
- }), 500
- except Exception as e:
- session.rollback()
- logger.error(f'Failed to delete the associated dataset: {str(e)}')
- return jsonify({'error': f'Failed to delete the associated dataset: {str(e)}'}), 500
- response_data = {
- 'message': 'Model deleted successfully',
- 'deleted_files': [model_file]
- }
- 
- if delete_dataset:
- response_data['dataset_info'] = {
- 'dataset_id': dataset_id,
- 'message': 'Associated dataset deleted'
- }
-
- return jsonify(response_data), 200
-
- except Exception as e:
- session.rollback()
- logger.error(f'Failed to delete model {model_id}:', exc_info=True)
- return jsonify({'error': str(e)}), 500
- finally:
- session.close()
- # API endpoint that clears the specified dataset
- @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
- def clear_dataset(data_type):
- """
- Clear the dataset of the given type and reset its auto-increment counter.
- @param data_type: dataset type ('reduce' or 'reflux')
- @return: JSON response
- """
- # Create a sessionmaker instance
- Session = sessionmaker(bind=db.engine)
- session = Session()
- 
- try:
- # Select the table based on the dataset type
- if data_type == 'reduce':
- table = CurrentReduce
- table_name = 'current_reduce'
- elif data_type == 'reflux':
- table = CurrentReflux
- table_name = 'current_reflux'
- else:
- return jsonify({'error': 'Invalid dataset type'}), 400
- 
- # Delete all rows from the table
- session.query(table).delete()
- 
- # Reset the SQLite auto-increment counter for the table
- session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
- 
- session.commit()
- 
- return jsonify({'message': f'{data_type} dataset cleared and counter reset'}), 200
-
- except Exception as e:
- session.rollback()
- return jsonify({'error': str(e)}), 500
-
- finally:
- session.close()
- @bp.route('/login', methods=['POST'])
- def login_user():
- # Read the data sent by the frontend
- data = request.get_json()
- name = data.get('name')  # username
- password = data.get('password')  # password
- logger.info(f"Login request received: name={name}")
- # Make sure neither the username nor the password is empty
- if not name or not password:
- logger.warning("Username and password must not be empty")
- return jsonify({"success": False, "message": "Username and password must not be empty"}), 400
- try:
- # Look up the username in the database
- query = "SELECT * FROM users WHERE name = :name"
- conn = get_db()
- user = conn.execute(query, {"name": name}).fetchone()
- if not user:
- logger.warning(f"Username '{name}' does not exist")
- return jsonify({"success": False, "message": "Username does not exist"}), 400
- # Stored password from the database (assumed to be a hash)
- stored_password = user[2]  # password is assumed to be in the third column
- user_id = user[0]  # id is assumed to be in the first column
- # Verify the password
- if check_password_hash(stored_password, password):
- logger.info(f"User '{name}' logged in successfully.")
- return jsonify({
- "success": True,
- "message": "Login successful",
- "userId": user_id  # return the user ID
- })
- else:
- logger.warning(f"Invalid password for user '{name}'")
- return jsonify({"success": False, "message": "Incorrect username or password"}), 400
- except Exception as e:
- # Log the error and return an error message
- logger.error(f"Error during login: {e}", exc_info=True)
- return jsonify({"success": False, "message": "Login failed"}), 500
- # Endpoint for updating user information
- @bp.route('/update_user', methods=['POST'])
- def update_user():
- # Read the data sent by the frontend
- data = request.get_json()
- # Log the received request data
- app.logger.info(f"Received data: {data}")
- user_id = data.get('userId')  # user ID
- name = data.get('name')  # username
- old_password = data.get('oldPassword')  # old password
- new_password = data.get('newPassword')  # new password
- logger.info(f"Update request received: user_id={user_id}, name={name}")
- # Make sure the username and the old password are not empty
- if not name or not old_password:
- logger.warning("Username and old password must not be empty")
- return jsonify({"success": False, "message": "Username and old password must not be empty"}), 400
- # The new password must differ from the old one
- if new_password and old_password == new_password:
- logger.warning(f"New password is the same as the old password: {name}")
- return jsonify({"success": False, "message": "The new password must differ from the old password"}), 400
- try:
- # Look up the user by ID
- query = "SELECT * FROM users WHERE id = :user_id"
- conn = get_db()
- user = conn.execute(query, {"user_id": user_id}).fetchone()
- if not user:
- logger.warning(f"User ID '{user_id}' does not exist")
- return jsonify({"success": False, "message": "User does not exist"}), 400
- # Stored password from the database (assumed to be a hash)
- stored_password = user[2]  # password is assumed to be in the third column
- # Verify the old password
- if not check_password_hash(stored_password, old_password):
- logger.warning(f"Incorrect old password: {name}")
- return jsonify({"success": False, "message": "Incorrect old password"}), 400
- # If a new password was supplied, update it
- if new_password:
- hashed_new_password = hash_password(new_password)
- update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
- conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
- conn.commit()
- logger.info(f"User ID '{user_id}' password updated successfully.")
- # If the username has changed, update it as well
- if name != user[1]:
- update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
- conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
- conn.commit()
- logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
- return jsonify({"success": True, "message": "User information updated successfully"})
- except Exception as e:
- # Log the error and return an error message
- logger.error(f"Error updating user: {e}", exc_info=True)
- return jsonify({"success": False, "message": "Update failed"}), 500
- # Register a new user
- @bp.route('/register', methods=['POST'])
- def register_user():
- # Read the data sent by the frontend
- data = request.get_json()
- name = data.get('name')  # username
- password = data.get('password')  # password
- logger.info(f"Register request received: name={name}")
- # Make sure neither the username nor the password is empty
- if not name or not password:
- logger.warning("Username and password must not be empty")
- return jsonify({"success": False, "message": "Username and password must not be empty"}), 400
- # Dynamically read the column names of the users table
- columns = get_column_names('users')
- logger.info(f"Database columns for 'users' table: {columns}")
- # Make sure the users table actually has the required columns
- for column in ['name', 'password']:
- if column not in columns:
- logger.error(f"Missing required column: {column}")
- return jsonify({"success": False, "message": f"Missing required column: {column}"}), 400
- # Hash the password
- hashed_password = hash_password(password)
- logger.info(f"Password hashed for user: {name}")
- # Insert the user into the database
- try:
- # Check whether the user already exists
- query = "SELECT * FROM users WHERE name = :name"
- conn = get_db()
- user = conn.execute(query, {"name": name}).fetchone()
- if user:
- logger.warning(f"Username '{name}' already exists")
- return jsonify({"success": False, "message": "Username already exists"}), 400
- # Insert the new row
- query = "INSERT INTO users (name, password) VALUES (:name, :password)"
- conn.execute(query, {"name": name, "password": hashed_password})
- conn.commit()
- logger.info(f"User '{name}' registered successfully.")
- return jsonify({"success": True, "message": "Registration successful"})
- except Exception as e:
- # Log the error and return an error message
- logger.error(f"Error registering user: {e}", exc_info=True)
- return jsonify({"success": False, "message": "Registration failed"}), 500
- def get_column_names(table_name):
- """
- 动态获取数据库表的列名。
- """
- try:
- conn = get_db()
- query = f"PRAGMA table_info({table_name});"
- result = conn.execute(query).fetchall()
- conn.close()
- return [row[1] for row in result]  # the second column of PRAGMA output is the column name
- except Exception as e:
- logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
- return []
|