routes.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. import pickle
  2. import sqlite3
  3. from flask import Blueprint, request, jsonify,current_app
  4. from werkzeug.security import generate_password_hash
  5. from sklearn.metrics import r2_score
  6. from .model import predict, train_and_save_model, calculate_model_score
  7. import pandas as pd
  8. from . import db # 从 app 包导入 db 实例
  9. from sqlalchemy.engine.reflection import Inspector
  10. from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
  11. import os
  12. from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
  13. clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
  14. predict_to_Q, Q_to_t_ha, create_kriging
  15. from sqlalchemy.orm import sessionmaker
  16. import logging
  17. from sqlalchemy import text, func, MetaData, Table, select
  18. from .tasks import train_model_task
  19. from datetime import datetime
# Configure logging (DEBUG level for the whole process)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Blueprint that groups all routes in this module
bp = Blueprint('routes', __name__)
  25. # 封装数据库连接函数
  26. def get_db_connection():
  27. return sqlite3.connect('software_intro.db')
  28. # 密码加密
  29. def hash_password(password):
  30. return generate_password_hash(password)
  31. def get_db():
  32. """ 获取数据库连接 """
  33. return sqlite3.connect(current_app.config['DATABASE'])
  34. # 添加一个新的辅助函数来检查数据集大小并触发训练
  35. def check_and_trigger_training(session, dataset_type, dataset_df):
  36. """
  37. 检查当前数据集大小是否跨越新的阈值点并触发训练
  38. Args:
  39. session: 数据库会话
  40. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  41. dataset_df: 数据集 DataFrame
  42. Returns:
  43. tuple: (是否触发训练, 任务ID)
  44. """
  45. try:
  46. # 根据数据集类型选择表
  47. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  48. # 获取当前记录数
  49. current_count = session.query(func.count()).select_from(table).scalar()
  50. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  51. new_records = len(dataset_df) # 需要从上层函数传入
  52. # 计算新增数据前的记录数
  53. previous_count = current_count - new_records
  54. # 设置阈值
  55. THRESHOLD = current_app.config['THRESHOLD']
  56. # 计算上一个阈值点(基于新增前的数据量)
  57. last_threshold = previous_count // THRESHOLD * THRESHOLD
  58. # 计算当前所在阈值点
  59. current_threshold = current_count // THRESHOLD * THRESHOLD
  60. # 检查是否跨越了新的阈值点
  61. if current_threshold > last_threshold and current_count >= THRESHOLD:
  62. # 触发异步训练任务
  63. task = train_model_task.delay(
  64. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  65. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  66. model_description=f'Auto trained model at {current_threshold} records threshold',
  67. data_type=dataset_type
  68. )
  69. return True, task.id
  70. return False, None
  71. except Exception as e:
  72. logging.error(f"检查并触发训练失败: {str(e)}")
  73. return False, None
@bp.route('/upload-dataset', methods=['POST'])
def upload_dataset():
    """
    Upload an Excel dataset, persist it, and possibly trigger auto-training.

    Form fields: file (.xlsx), dataset_name, dataset_description (optional),
    dataset_type ('reduce' or 'reflux').
    Returns 201 with dataset info (plus task_id when auto-training was
    triggered), 400 on bad input, 500 on failure.
    """
    # Create a session bound to the app's engine
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '' or not allowed_file(file.filename):
            return jsonify({'error': 'No selected file or invalid file type'}), 400
        dataset_name = request.form.get('dataset_name')
        dataset_description = request.form.get('dataset_description', 'No description provided')
        dataset_type = request.form.get('dataset_type')
        if not dataset_type:
            return jsonify({'error': 'Dataset type is required'}), 400
        # Insert the Datasets row first so its auto-generated ID can be used
        # for the file name and the dynamic table name below.
        new_dataset = Datasets(
            Dataset_name=dataset_name,
            Dataset_description=dataset_description,
            Row_count=0,
            Status='Datasets_upgraded',
            Dataset_type=dataset_type,
            Uploaded_at=datetime.now()
        )
        session.add(new_dataset)
        session.commit()
        # Save the uploaded workbook under a name derived from the dataset ID
        unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
        upload_folder = current_app.config['UPLOAD_FOLDER']
        file_path = os.path.join(upload_folder, unique_filename)
        file.save(file_path)
        dataset_df = pd.read_excel(file_path)
        new_dataset.Row_count = len(dataset_df)
        new_dataset.Status = 'excel_file_saved success'
        session.commit()
        # Normalise column names to the model's expected schema
        dataset_df = clean_column_names(dataset_df)
        dataset_df = rename_columns_for_model(dataset_df, dataset_type)
        # Create a per-dataset table and copy the data into it
        column_types = infer_column_types(dataset_df)
        dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
        insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
        # Also append the rows to the 'current' table matching the dataset type
        if dataset_type == 'reduce':
            insert_data_into_existing_table(session, dataset_df, CurrentReduce)
        elif dataset_type == 'reflux':
            insert_data_into_existing_table(session, dataset_df, CurrentReflux)
        session.commit()
        # After the insert completes, check whether auto-training should fire
        training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
        response_data = {
            'message': f'Dataset {dataset_name} uploaded successfully!',
            'dataset_id': new_dataset.Dataset_ID,
            'filename': unique_filename,
            'training_triggered': training_triggered
        }
        if training_triggered:
            response_data['task_id'] = task_id
            response_data['message'] += ' Auto-training has been triggered.'
        return jsonify(response_data), 201
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the dataset upload:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        # Make sure the session is always closed
        if session:
            session.close()
  140. @bp.route('/train-and-save-model', methods=['POST'])
  141. def train_and_save_model_endpoint():
  142. # 创建 sessionmaker 实例
  143. Session = sessionmaker(bind=db.engine)
  144. session = Session()
  145. data = request.get_json()
  146. # 从请求中解析参数
  147. model_type = data.get('model_type')
  148. model_name = data.get('model_name')
  149. model_description = data.get('model_description')
  150. data_type = data.get('data_type')
  151. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  152. try:
  153. # 调用训练和保存模型的函数
  154. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  155. model_id = result[1] if result else None
  156. # 计算模型评分
  157. if model_id:
  158. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  159. if model_info:
  160. score = calculate_model_score(model_info)
  161. # 更新模型评分
  162. model_info.Performance_score = score
  163. session.commit()
  164. result = {'model_id': model_id, 'model_score': score}
  165. # 返回成功响应
  166. return jsonify({
  167. 'message': 'Model trained and saved successfully',
  168. 'result': result
  169. }), 200
  170. except Exception as e:
  171. session.rollback()
  172. logging.error('Failed to process the model training:', exc_info=True)
  173. return jsonify({
  174. 'error': 'Failed to train and save model',
  175. 'message': str(e)
  176. }), 500
  177. finally:
  178. session.close()
  179. @bp.route('/predict', methods=['POST'])
  180. def predict_route():
  181. # 创建 sessionmaker 实例
  182. Session = sessionmaker(bind=db.engine)
  183. session = Session()
  184. try:
  185. data = request.get_json()
  186. model_id = data.get('model_id') # 提取模型名称
  187. parameters = data.get('parameters', {}) # 提取所有变量
  188. # 根据model_id获取模型Data_type
  189. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  190. if not model_info:
  191. return jsonify({'error': 'Model not found'}), 404
  192. data_type = model_info.Data_type
  193. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  194. # 如果为reduce,则不需要传入target_ph
  195. if data_type == 'reduce':
  196. # 获取传入的init_ph、target_ph参数
  197. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  198. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  199. # 从输入数据中删除'target_pH'列
  200. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  201. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  202. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  203. if data_type == 'reduce':
  204. predictions = predictions[0]
  205. # 将预测结果转换为Q
  206. Q = predict_to_Q(predictions, init_ph, target_ph)
  207. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  208. print(predictions)
  209. return jsonify({'result': predictions}), 200
  210. except Exception as e:
  211. logging.error('Failed to predict:', exc_info=True)
  212. return jsonify({'error': str(e)}), 400
  213. # 为指定模型计算评分Performance_score,需要提供model_id
  214. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  215. def score_model(model_id):
  216. # 创建 sessionmaker 实例
  217. Session = sessionmaker(bind=db.engine)
  218. session = Session()
  219. try:
  220. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  221. if not model_info:
  222. return jsonify({'error': 'Model not found'}), 404
  223. # 计算模型评分
  224. score = calculate_model_score(model_info)
  225. # 更新模型记录中的评分
  226. model_info.Performance_score = score
  227. session.commit()
  228. return jsonify({'message': 'Model scored successfully', 'score': score}), 200
  229. except Exception as e:
  230. logging.error('Failed to process the dataset upload:', exc_info=True)
  231. return jsonify({'error': str(e)}), 400
  232. finally:
  233. session.close()
@bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
def delete_dataset_endpoint(dataset_id):
    """
    API endpoint for deleting a dataset.
    @param dataset_id: ID of the dataset to delete
    @return: JSON response
    """
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the dataset record
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if not dataset:
            return jsonify({'error': '未找到数据集'}), 404
        # Refuse deletion while any model still references this dataset
        models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
        if models_using_dataset:
            models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
            return jsonify({
                'error': '无法删除数据集,因为以下模型正在使用它',
                'models': models_info
            }), 400
        # Delete the stored Excel file
        filename = f"dataset_{dataset.Dataset_ID}.xlsx"
        file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f'删除文件失败: {str(e)}')
                return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
        # Drop the per-dataset table (name is built from an int ID, so the
        # f-string SQL is not injectable here)
        table_name = f"dataset_{dataset.Dataset_ID}"
        session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
        # Remove the dataset record itself
        session.delete(dataset)
        session.commit()
        return jsonify({
            'message': '数据集删除成功',
            'deleted_files': [filename]
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
  282. @bp.route('/tables', methods=['GET'])
  283. def list_tables():
  284. engine = db.engine # 使用 db 实例的 engine
  285. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  286. table_names = inspector.get_table_names() # 获取所有表名
  287. return jsonify(table_names) # 以 JSON 形式返回表名列表
  288. @bp.route('/models/<int:model_id>', methods=['GET'])
  289. def get_model(model_id):
  290. """
  291. 获取单个模型信息的API接口
  292. @param model_id: 模型ID
  293. @return: JSON响应
  294. """
  295. Session = sessionmaker(bind=db.engine)
  296. session = Session()
  297. try:
  298. model = session.query(Models).filter_by(ModelID=model_id).first()
  299. if model:
  300. return jsonify({
  301. 'ModelID': model.ModelID,
  302. 'Model_name': model.Model_name,
  303. 'Model_type': model.Model_type,
  304. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  305. 'Description': model.Description,
  306. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  307. 'MAE': float(model.MAE) if model.MAE else None,
  308. 'RMSE': float(model.RMSE) if model.RMSE else None,
  309. 'Data_type': model.Data_type
  310. })
  311. else:
  312. return jsonify({'message': '未找到模型'}), 404
  313. except Exception as e:
  314. logger.error(f'获取模型信息失败: {str(e)}')
  315. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  316. finally:
  317. session.close()
  318. @bp.route('/model-parameters', methods=['GET'])
  319. def get_all_model_parameters():
  320. """
  321. 获取所有模型参数的API接口
  322. @return: JSON响应
  323. """
  324. Session = sessionmaker(bind=db.engine)
  325. session = Session()
  326. try:
  327. parameters = session.query(ModelParameters).all()
  328. if parameters:
  329. result = [
  330. {
  331. 'ParamID': param.ParamID,
  332. 'ModelID': param.ModelID,
  333. 'ParamName': param.ParamName,
  334. 'ParamValue': param.ParamValue
  335. }
  336. for param in parameters
  337. ]
  338. return jsonify(result)
  339. else:
  340. return jsonify({'message': '未找到任何参数'}), 404
  341. except Exception as e:
  342. logger.error(f'获取所有模型参数失败: {str(e)}')
  343. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  344. finally:
  345. session.close()
  346. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  347. def get_model_parameters(model_id):
  348. try:
  349. model = Models.query.filter_by(ModelID=model_id).first()
  350. if model:
  351. # 获取该模型的所有参数
  352. parameters = [
  353. {
  354. 'ParamID': param.ParamID,
  355. 'ParamName': param.ParamName,
  356. 'ParamValue': param.ParamValue
  357. }
  358. for param in model.parameters
  359. ]
  360. # 返回模型参数信息
  361. return jsonify({
  362. 'ModelID': model.ModelID,
  363. 'ModelName': model.ModelName,
  364. 'ModelType': model.ModelType,
  365. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  366. 'Description': model.Description,
  367. 'Parameters': parameters
  368. })
  369. else:
  370. return jsonify({'message': 'Model not found'}), 404
  371. except Exception as e:
  372. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
@bp.route('/train-model-async', methods=['POST'])
def train_model_async():
    """
    API endpoint that starts model training asynchronously.

    JSON body: model_type, model_name, data_type (required);
    model_description, dataset_id (optional).
    Returns 202 with the task id, 400/404 on bad input, 500 on failure.
    """
    try:
        data = request.get_json()
        # Pull parameters from the request body
        model_type = data.get('model_type')
        model_name = data.get('model_name')
        model_description = data.get('model_description')
        data_type = data.get('data_type')
        dataset_id = data.get('dataset_id', None)
        # Validate required parameters
        if not all([model_type, model_name, data_type]):
            return jsonify({
                'error': 'Missing required parameters'
            }), 400
        # When a dataset_id is supplied, verify the dataset exists
        if dataset_id:
            Session = sessionmaker(bind=db.engine)
            session = Session()
            try:
                dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
                if not dataset:
                    return jsonify({
                        'error': f'Dataset with ID {dataset_id} not found'
                    }), 404
            finally:
                session.close()
        # Launch the asynchronous training task
        task = train_model_task.delay(
            model_type=model_type,
            model_name=model_name,
            model_description=model_description,
            data_type=data_type,
            dataset_id=dataset_id
        )
        # Hand back the task id so the client can poll /task-status/<task_id>
        return jsonify({
            'task_id': task.id,
            'message': 'Model training started'
        }), 202
    except Exception as e:
        logging.error('Failed to start async training task:', exc_info=True)
        return jsonify({
            'error': str(e)
        }), 500
  421. @bp.route('/task-status/<task_id>', methods=['GET'])
  422. def get_task_status(task_id):
  423. """
  424. 获取异步任务状态的API接口
  425. """
  426. try:
  427. task = train_model_task.AsyncResult(task_id)
  428. if task.state == 'PENDING':
  429. response = {
  430. 'state': task.state,
  431. 'status': 'Task is waiting for execution'
  432. }
  433. elif task.state == 'FAILURE':
  434. response = {
  435. 'state': task.state,
  436. 'status': 'Task failed',
  437. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  438. }
  439. elif task.state == 'SUCCESS':
  440. response = {
  441. 'state': task.state,
  442. 'status': 'Task completed successfully',
  443. 'result': task.get()
  444. }
  445. else:
  446. response = {
  447. 'state': task.state,
  448. 'status': 'Task is in progress'
  449. }
  450. return jsonify(response), 200
  451. except Exception as e:
  452. return jsonify({
  453. 'error': str(e)
  454. }), 500
  455. @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
  456. def delete_model_route(model_id):
  457. # 将URL参数转换为布尔值
  458. delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'
  459. # 调用原始函数
  460. return delete_model(model_id, delete_dataset=delete_dataset_param)
  461. def delete_model(model_id, delete_dataset=False):
  462. """
  463. 删除指定模型的API接口
  464. @param model_id: 要删除的模型ID
  465. @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
  466. @return: JSON响应
  467. """
  468. Session = sessionmaker(bind=db.engine)
  469. session = Session()
  470. try:
  471. # 查询模型信息
  472. model = session.query(Models).filter_by(ModelID=model_id).first()
  473. if not model:
  474. return jsonify({'error': '未找到指定模型'}), 404
  475. dataset_id = model.DatasetID
  476. # 1. 先删除模型记录
  477. session.delete(model)
  478. session.commit()
  479. # 2. 删除模型文件
  480. model_file = f"rf_model_{model_id}.pkl"
  481. model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
  482. if os.path.exists(model_path):
  483. try:
  484. os.remove(model_path)
  485. except OSError as e:
  486. # 如果删除文件失败,回滚数据库操作
  487. session.rollback()
  488. logger.error(f'删除模型文件失败: {str(e)}')
  489. return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
  490. # 3. 如果需要删除关联的数据集
  491. if delete_dataset and dataset_id:
  492. try:
  493. dataset_response = delete_dataset_endpoint(dataset_id)
  494. if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
  495. # 如果删除数据集失败,回滚之前的操作
  496. session.rollback()
  497. return jsonify({
  498. 'error': '删除关联数据集失败',
  499. 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
  500. }), 500
  501. except Exception as e:
  502. session.rollback()
  503. logger.error(f'删除关联数据集失败: {str(e)}')
  504. return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
  505. response_data = {
  506. 'message': '模型删除成功',
  507. 'deleted_files': [model_file]
  508. }
  509. if delete_dataset:
  510. response_data['dataset_info'] = {
  511. 'dataset_id': dataset_id,
  512. 'message': '关联数据集已删除'
  513. }
  514. return jsonify(response_data), 200
  515. except Exception as e:
  516. session.rollback()
  517. logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
  518. return jsonify({'error': str(e)}), 500
  519. finally:
  520. session.close()
  521. # 添加一个新的API端点来清空指定数据集
  522. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  523. def clear_dataset(data_type):
  524. """
  525. 清空指定类型的数据集并递增计数
  526. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  527. @return: JSON响应
  528. """
  529. # 创建 sessionmaker 实例
  530. Session = sessionmaker(bind=db.engine)
  531. session = Session()
  532. try:
  533. # 根据数据集类型选择表
  534. if data_type == 'reduce':
  535. table = CurrentReduce
  536. table_name = 'current_reduce'
  537. elif data_type == 'reflux':
  538. table = CurrentReflux
  539. table_name = 'current_reflux'
  540. else:
  541. return jsonify({'error': '无效的数据集类型'}), 400
  542. # 清空表内容
  543. session.query(table).delete()
  544. # 重置自增主键计数器
  545. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  546. session.commit()
  547. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  548. except Exception as e:
  549. session.rollback()
  550. return jsonify({'error': str(e)}), 500
  551. finally:
  552. session.close()
  553. @bp.route('/update-threshold', methods=['POST'])
  554. def update_threshold():
  555. """
  556. 更新训练阈值的API接口
  557. @body_param threshold: 新的阈值值(整数)
  558. @return: JSON响应
  559. """
  560. try:
  561. data = request.get_json()
  562. new_threshold = data.get('threshold')
  563. # 验证新阈值
  564. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  565. return jsonify({
  566. 'error': '无效的阈值值,必须为正数'
  567. }), 400
  568. # 更新当前应用的阈值配置
  569. current_app.config['THRESHOLD'] = int(new_threshold)
  570. return jsonify({
  571. 'success': True,
  572. 'message': f'阈值已更新为 {new_threshold}',
  573. 'new_threshold': new_threshold
  574. })
  575. except Exception as e:
  576. logging.error(f"更新阈值失败: {str(e)}")
  577. return jsonify({
  578. 'error': f'更新阈值失败: {str(e)}'
  579. }), 500
  580. @bp.route('/get-threshold', methods=['GET'])
  581. def get_threshold():
  582. """
  583. 获取当前训练阈值的API接口
  584. @return: JSON响应
  585. """
  586. try:
  587. current_threshold = current_app.config['THRESHOLD']
  588. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  589. return jsonify({
  590. 'current_threshold': current_threshold,
  591. 'default_threshold': default_threshold
  592. })
  593. except Exception as e:
  594. logging.error(f"获取阈值失败: {str(e)}")
  595. return jsonify({
  596. 'error': f'获取阈值失败: {str(e)}'
  597. }), 500
@bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
def set_current_dataset(data_type, dataset_id):
    """
    Promote the specified dataset to be the 'current' dataset of its type.
    @param data_type: dataset type ('reduce' or 'reflux')
    @param dataset_id: ID of the dataset to make current
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Verify the dataset exists and its type matches the URL
        dataset = session.query(Datasets)\
            .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
            .first()
        if not dataset:
            return jsonify({
                'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
            }), 404
        # Choose the 'current' table by data type
        if data_type == 'reduce':
            table = CurrentReduce
            table_name = 'current_reduce'
        elif data_type == 'reflux':
            table = CurrentReflux
            table_name = 'current_reflux'
        else:
            return jsonify({'error': '无效的数据集类型'}), 400
        # Empty the current table
        session.query(table).delete()
        # Reset SQLite's autoincrement counter for that table
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
        # Copy rows from the per-dataset table into the current table
        # (table names derive from a validated type and an int ID)
        dataset_table_name = f"dataset_{dataset_id}"
        copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
        session.execute(copy_sql)
        session.commit()
        return jsonify({
            'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
            'dataset_id': dataset_id,
            'dataset_name': dataset.Dataset_name,
            'row_count': dataset.Row_count
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'设置current数据集失败: {str(e)}')
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
@bp.route('/get-model-history/<string:data_type>', methods=['GET'])
def get_model_history(data_type):
    """
    API endpoint returning model-training history as time series.
    @param data_type: dataset type ('reduce' or 'reflux')
    @return: JSON response with time-ordered model performance data
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # All auto-generated datasets of this type, ordered by upload time
        datasets = session.query(Datasets).filter(
            Datasets.Dataset_type == data_type,
            Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
        ).order_by(Datasets.Uploaded_at).all()
        history_data = []
        for dataset in datasets:
            # Find the auto-trained model that corresponds to this dataset
            model = session.query(Models).filter(
                Models.DatasetID == dataset.Dataset_ID,
                Models.Model_name.like(f'auto_trained_{data_type}_%')
            ).first()
            if model and model.Performance_score is not None:
                # Use the DB timestamp as-is (keeps the created_at timezone)
                created_at = model.Created_at.isoformat() if model.Created_at else None
                history_data.append({
                    'dataset_id': dataset.Dataset_ID,
                    'row_count': dataset.Row_count,
                    'model_id': model.ModelID,
                    'model_name': model.Model_name,
                    'performance_score': float(model.Performance_score),
                    'timestamp': created_at
                })
        # Sort by timestamp (None timestamps sort first via '')
        history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')
        # Split the series per metric so the frontend can chart them directly
        response_data = {
            'data_type': data_type,
            'timestamps': [item['timestamp'] for item in history_data],
            'row_counts': [item['row_count'] for item in history_data],
            'performance_scores': [item['performance_score'] for item in history_data],
            'model_details': history_data  # full records for the frontend
        }
        return jsonify(response_data), 200
    except Exception as e:
        logger.error(f'获取模型历史数据失败: {str(e)}', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
@bp.route('/batch-delete-datasets', methods=['POST'])
def batch_delete_datasets():
    """
    API endpoint for deleting several datasets in one call.
    @body_param dataset_ids: list of dataset IDs to delete
    @return: JSON response
    """
    try:
        data = request.get_json()
        dataset_ids = data.get('dataset_ids', [])
        if not dataset_ids:
            return jsonify({'error': '未提供数据集ID列表'}), 400
        results = {
            'success': [],
            'failed': [],
            'protected': []  # datasets still referenced by models
        }
        for dataset_id in dataset_ids:
            try:
                # Delegate to the single-dataset delete endpoint
                response = delete_dataset_endpoint(dataset_id)
                # Responses are (flask.Response, status_code) tuples
                if response[1] == 200:
                    results['success'].append(dataset_id)
                elif response[1] == 400 and 'models' in response[0].json:
                    # The dataset is protected by one or more models
                    results['protected'].append({
                        'id': dataset_id,
                        'models': response[0].json['models']
                    })
                else:
                    results['failed'].append({
                        'id': dataset_id,
                        'reason': response[0].json.get('error', '删除失败')
                    })
            except Exception as e:
                logger.error(f'删除数据集 {dataset_id} 失败: {str(e)}')
                results['failed'].append({
                    'id': dataset_id,
                    'reason': str(e)
                })
        # Build the summary message
        message = f"成功删除 {len(results['success'])} 个数据集"
        if results['protected']:
            message += f", {len(results['protected'])} 个数据集被保护"
        if results['failed']:
            message += f", {len(results['failed'])} 个数据集删除失败"
        return jsonify({
            'message': message,
            'results': results
        }), 200
    except Exception as e:
        logger.error(f'批量删除数据集失败: {str(e)}')
        return jsonify({'error': str(e)}), 500
@bp.route('/batch-delete-models', methods=['POST'])
def batch_delete_models():
    """
    API endpoint for deleting several models in one call.
    @body_param model_ids: list of model IDs to delete
    @query_param delete_datasets: whether to also delete associated
        datasets (default False)
    @return: JSON response
    """
    try:
        data = request.get_json()
        model_ids = data.get('model_ids', [])
        delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'
        if not model_ids:
            return jsonify({'error': '未提供模型ID列表'}), 400
        results = {
            'success': [],
            'failed': [],
            'datasets_deleted': []  # dataset IDs removed when delete_datasets is true
        }
        for model_id in model_ids:
            try:
                # Delegate to the single-model delete helper
                response = delete_model(model_id, delete_dataset=delete_datasets)
                # Responses are (flask.Response, status_code) tuples
                if response[1] == 200:
                    results['success'].append(model_id)
                    # Record the dataset ID if an associated dataset was removed
                    if 'dataset_info' in response[0].json:
                        results['datasets_deleted'].append(
                            response[0].json['dataset_info']['dataset_id']
                        )
                else:
                    results['failed'].append({
                        'id': model_id,
                        'reason': response[0].json.get('error', '删除失败')
                    })
            except Exception as e:
                logger.error(f'删除模型 {model_id} 失败: {str(e)}')
                results['failed'].append({
                    'id': model_id,
                    'reason': str(e)
                })
        # Build the summary message
        message = f"成功删除 {len(results['success'])} 个模型"
        if results['datasets_deleted']:
            message += f", {len(results['datasets_deleted'])} 个关联数据集"
        if results['failed']:
            message += f", {len(results['failed'])} 个模型删除失败"
        return jsonify({
            'message': message,
            'results': results
        }), 200
    except Exception as e:
        logger.error(f'批量删除模型失败: {str(e)}')
        return jsonify({'error': str(e)}), 500
  805. @bp.route('/kriging_interpolation', methods=['POST'])
  806. def kriging_interpolation():
  807. try:
  808. data = request.get_json()
  809. required = ['file_name', 'emission_column', 'points']
  810. if not all(k in data for k in required):
  811. return jsonify({"error": "Missing parameters"}), 400
  812. # 添加坐标顺序验证
  813. points = data['points']
  814. if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points):
  815. return jsonify({"error": "Invalid points format"}), 400
  816. result = create_kriging(
  817. data['file_name'],
  818. data['emission_column'],
  819. data['points']
  820. )
  821. return jsonify(result)
  822. except Exception as e:
  823. return jsonify({"error": str(e)}), 500
  824. # 显示切换模型
  825. @bp.route('/models', methods=['GET'])
  826. def get_models():
  827. session = None
  828. try:
  829. # 创建 session
  830. Session = sessionmaker(bind=db.engine)
  831. session = Session()
  832. # 查询所有模型
  833. models = session.query(Models).all()
  834. logger.debug(f"Models found: {models}") # 打印查询的模型数据
  835. if not models:
  836. return jsonify({'message': 'No models found'}), 404
  837. # 将模型数据转换为字典列表
  838. models_list = [
  839. {
  840. 'ModelID': model.ModelID,
  841. 'ModelName': model.Model_name,
  842. 'ModelType': model.Model_type,
  843. 'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  844. 'Description': model.Description,
  845. 'DatasetID': model.DatasetID,
  846. 'ModelFilePath': model.ModelFilePath,
  847. 'DataType': model.Data_type,
  848. 'PerformanceScore': model.Performance_score,
  849. 'MAE': model.MAE,
  850. 'RMSE': model.RMSE
  851. }
  852. for model in models
  853. ]
  854. return jsonify(models_list), 200
  855. except Exception as e:
  856. return jsonify({'error': str(e)}), 400
  857. finally:
  858. if session:
  859. session.close()
  860. # 定义提供数据库列表,用于展示表格的 API 接口
  861. @bp.route('/table', methods=['POST'])
  862. def get_table():
  863. data = request.get_json()
  864. table_name = data.get('table')
  865. if not table_name:
  866. return jsonify({'error': '需要表名'}), 400
  867. try:
  868. # 创建 sessionmaker 实例
  869. Session = sessionmaker(bind=db.engine)
  870. session = Session()
  871. # 动态获取表的元数据
  872. metadata = MetaData()
  873. table = Table(table_name, metadata, autoload_with=db.engine)
  874. # 从数据库中查询所有记录
  875. query = select(table)
  876. result = session.execute(query).fetchall()
  877. # 将结果转换为列表字典形式
  878. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  879. # 获取列名
  880. headers = [column.name for column in table.columns]
  881. return jsonify(rows=rows, headers=headers), 200
  882. except Exception as e:
  883. return jsonify({'error': str(e)}), 400
  884. finally:
  885. # 关闭 session
  886. session.close()
  887. @bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
  888. def get_model_scatter_data(model_id):
  889. """
  890. 获取指定模型的散点图数据(真实值vs预测值)
  891. @param model_id: 模型ID
  892. @return: JSON响应,包含散点图数据
  893. """
  894. Session = sessionmaker(bind=db.engine)
  895. session = Session()
  896. try:
  897. # 查询模型信息
  898. model = session.query(Models).filter_by(ModelID=model_id).first()
  899. if not model:
  900. return jsonify({'error': '未找到指定模型'}), 404
  901. # 加载模型
  902. with open(model.ModelFilePath, 'rb') as f:
  903. ML_model = pickle.load(f)
  904. # 根据数据类型加载测试数据
  905. if model.Data_type == 'reflux':
  906. X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
  907. Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
  908. elif model.Data_type == 'reduce':
  909. X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
  910. Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
  911. else:
  912. return jsonify({'error': '不支持的数据类型'}), 400
  913. # 获取预测值
  914. y_pred = ML_model.predict(X_test)
  915. # 生成散点图数据
  916. scatter_data = [
  917. [float(true), float(pred)]
  918. for true, pred in zip(Y_test.iloc[:, 0], y_pred)
  919. ]
  920. # 计算R²分数
  921. r2 = r2_score(Y_test, y_pred)
  922. # 获取数据范围,用于绘制对角线
  923. y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
  924. y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))
  925. return jsonify({
  926. 'scatter_data': scatter_data,
  927. 'r2_score': float(r2),
  928. 'y_range': [float(y_min), float(y_max)],
  929. 'model_name': model.Model_name,
  930. 'model_type': model.Model_type
  931. }), 200
  932. except Exception as e:
  933. logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True)
  934. return jsonify({'error': f'获取数据失败: {str(e)}'}), 500
  935. finally:
  936. session.close()