routes.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077
  1. import sqlite3
  2. from flask import Blueprint, request, jsonify,current_app
  3. from werkzeug.security import generate_password_hash
  4. from .model import predict, train_and_save_model, calculate_model_score
  5. import pandas as pd
  6. from . import db # 从 app 包导入 db 实例
  7. from sqlalchemy.engine.reflection import Inspector
  8. from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
  9. import os
  10. from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
  11. clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
  12. predict_to_Q, Q_to_t_ha, create_kriging
  13. from sqlalchemy.orm import sessionmaker
  14. import logging
  15. from sqlalchemy import text, func, MetaData, Table, select
  16. from .tasks import train_model_task
  17. from datetime import datetime
  18. # 配置日志
  19. logging.basicConfig(level=logging.DEBUG)
  20. logger = logging.getLogger(__name__)
  21. # 创建蓝图 (Blueprint),用于分离路由
  22. bp = Blueprint('routes', __name__)
  23. # 封装数据库连接函数
  24. def get_db_connection():
  25. return sqlite3.connect('software_intro.db')
  26. # 密码加密
  27. def hash_password(password):
  28. return generate_password_hash(password)
  29. def get_db():
  30. """ 获取数据库连接 """
  31. return sqlite3.connect(current_app.config['DATABASE'])
  32. # 添加一个新的辅助函数来检查数据集大小并触发训练
  33. def check_and_trigger_training(session, dataset_type, dataset_df):
  34. """
  35. 检查当前数据集大小是否跨越新的阈值点并触发训练
  36. Args:
  37. session: 数据库会话
  38. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  39. dataset_df: 数据集 DataFrame
  40. Returns:
  41. tuple: (是否触发训练, 任务ID)
  42. """
  43. try:
  44. # 根据数据集类型选择表
  45. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  46. # 获取当前记录数
  47. current_count = session.query(func.count()).select_from(table).scalar()
  48. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  49. new_records = len(dataset_df) # 需要从上层函数传入
  50. # 计算新增数据前的记录数
  51. previous_count = current_count - new_records
  52. # 设置阈值
  53. THRESHOLD = current_app.config['THRESHOLD']
  54. # 计算上一个阈值点(基于新增前的数据量)
  55. last_threshold = previous_count // THRESHOLD * THRESHOLD
  56. # 计算当前所在阈值点
  57. current_threshold = current_count // THRESHOLD * THRESHOLD
  58. # 检查是否跨越了新的阈值点
  59. if current_threshold > last_threshold and current_count >= THRESHOLD:
  60. # 触发异步训练任务
  61. task = train_model_task.delay(
  62. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  63. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  64. model_description=f'Auto trained model at {current_threshold} records threshold',
  65. data_type=dataset_type
  66. )
  67. return True, task.id
  68. return False, None
  69. except Exception as e:
  70. logging.error(f"检查并触发训练失败: {str(e)}")
  71. return False, None
  72. @bp.route('/upload-dataset', methods=['POST'])
  73. def upload_dataset():
  74. # 创建 session
  75. Session = sessionmaker(bind=db.engine)
  76. session = Session()
  77. try:
  78. if 'file' not in request.files:
  79. return jsonify({'error': 'No file part'}), 400
  80. file = request.files['file']
  81. if file.filename == '' or not allowed_file(file.filename):
  82. return jsonify({'error': 'No selected file or invalid file type'}), 400
  83. dataset_name = request.form.get('dataset_name')
  84. dataset_description = request.form.get('dataset_description', 'No description provided')
  85. dataset_type = request.form.get('dataset_type')
  86. if not dataset_type:
  87. return jsonify({'error': 'Dataset type is required'}), 400
  88. new_dataset = Datasets(
  89. Dataset_name=dataset_name,
  90. Dataset_description=dataset_description,
  91. Row_count=0,
  92. Status='Datasets_upgraded',
  93. Dataset_type=dataset_type,
  94. Uploaded_at=datetime.now()
  95. )
  96. session.add(new_dataset)
  97. session.commit()
  98. unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
  99. upload_folder = current_app.config['UPLOAD_FOLDER']
  100. file_path = os.path.join(upload_folder, unique_filename)
  101. file.save(file_path)
  102. dataset_df = pd.read_excel(file_path)
  103. new_dataset.Row_count = len(dataset_df)
  104. new_dataset.Status = 'excel_file_saved success'
  105. session.commit()
  106. # 处理列名
  107. dataset_df = clean_column_names(dataset_df)
  108. dataset_df = rename_columns_for_model(dataset_df, dataset_type)
  109. column_types = infer_column_types(dataset_df)
  110. dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
  111. insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
  112. # 根据 dataset_type 决定插入到哪个已有表
  113. if dataset_type == 'reduce':
  114. insert_data_into_existing_table(session, dataset_df, CurrentReduce)
  115. elif dataset_type == 'reflux':
  116. insert_data_into_existing_table(session, dataset_df, CurrentReflux)
  117. session.commit()
  118. # 在完成数据插入后,检查是否需要触发训练
  119. training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
  120. response_data = {
  121. 'message': f'Dataset {dataset_name} uploaded successfully!',
  122. 'dataset_id': new_dataset.Dataset_ID,
  123. 'filename': unique_filename,
  124. 'training_triggered': training_triggered
  125. }
  126. if training_triggered:
  127. response_data['task_id'] = task_id
  128. response_data['message'] += ' Auto-training has been triggered.'
  129. return jsonify(response_data), 201
  130. except Exception as e:
  131. session.rollback()
  132. logging.error('Failed to process the dataset upload:', exc_info=True)
  133. return jsonify({'error': str(e)}), 500
  134. finally:
  135. # 确保 session 总是被关闭
  136. if session:
  137. session.close()
  138. @bp.route('/train-and-save-model', methods=['POST'])
  139. def train_and_save_model_endpoint():
  140. # 创建 sessionmaker 实例
  141. Session = sessionmaker(bind=db.engine)
  142. session = Session()
  143. data = request.get_json()
  144. # 从请求中解析参数
  145. model_type = data.get('model_type')
  146. model_name = data.get('model_name')
  147. model_description = data.get('model_description')
  148. data_type = data.get('data_type')
  149. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  150. try:
  151. # 调用训练和保存模型的函数
  152. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  153. model_id = result[1] if result else None
  154. # 计算模型评分
  155. if model_id:
  156. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  157. if model_info:
  158. score = calculate_model_score(model_info)
  159. # 更新模型评分
  160. model_info.Performance_score = score
  161. session.commit()
  162. result = {'model_id': model_id, 'model_score': score}
  163. # 返回成功响应
  164. return jsonify({
  165. 'message': 'Model trained and saved successfully',
  166. 'result': result
  167. }), 200
  168. except Exception as e:
  169. session.rollback()
  170. logging.error('Failed to process the model training:', exc_info=True)
  171. return jsonify({
  172. 'error': 'Failed to train and save model',
  173. 'message': str(e)
  174. }), 500
  175. finally:
  176. session.close()
  177. @bp.route('/predict', methods=['POST'])
  178. def predict_route():
  179. # 创建 sessionmaker 实例
  180. Session = sessionmaker(bind=db.engine)
  181. session = Session()
  182. try:
  183. data = request.get_json()
  184. model_id = data.get('model_id') # 提取模型名称
  185. parameters = data.get('parameters', {}) # 提取所有变量
  186. # 根据model_id获取模型Data_type
  187. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  188. if not model_info:
  189. return jsonify({'error': 'Model not found'}), 404
  190. data_type = model_info.Data_type
  191. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  192. # 如果为reduce,则不需要传入target_ph
  193. if data_type == 'reduce':
  194. # 获取传入的init_ph、target_ph参数
  195. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  196. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  197. # 从输入数据中删除'target_pH'列
  198. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  199. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  200. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  201. if data_type == 'reduce':
  202. predictions = predictions[0]
  203. # 将预测结果转换为Q
  204. Q = predict_to_Q(predictions, init_ph, target_ph)
  205. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  206. print(predictions)
  207. return jsonify({'result': predictions}), 200
  208. except Exception as e:
  209. logging.error('Failed to predict:', exc_info=True)
  210. return jsonify({'error': str(e)}), 400
  211. # 为指定模型计算评分Performance_score,需要提供model_id
  212. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  213. def score_model(model_id):
  214. # 创建 sessionmaker 实例
  215. Session = sessionmaker(bind=db.engine)
  216. session = Session()
  217. try:
  218. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  219. if not model_info:
  220. return jsonify({'error': 'Model not found'}), 404
  221. # 计算模型评分
  222. score = calculate_model_score(model_info)
  223. # 更新模型记录中的评分
  224. model_info.Performance_score = score
  225. session.commit()
  226. return jsonify({'message': 'Model scored successfully', 'score': score}), 200
  227. except Exception as e:
  228. logging.error('Failed to process the dataset upload:', exc_info=True)
  229. return jsonify({'error': str(e)}), 400
  230. finally:
  231. session.close()
  232. @bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
  233. def delete_dataset_endpoint(dataset_id):
  234. """
  235. 删除数据集的API接口
  236. @param dataset_id: 要删除的数据集ID
  237. @return: JSON响应
  238. """
  239. # 创建 sessionmaker 实例
  240. Session = sessionmaker(bind=db.engine)
  241. session = Session()
  242. try:
  243. # 查询数据集
  244. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  245. if not dataset:
  246. return jsonify({'error': '未找到数据集'}), 404
  247. # 检查是否有模型使用了该数据集
  248. models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
  249. if models_using_dataset:
  250. models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
  251. return jsonify({
  252. 'error': '无法删除数据集,因为以下模型正在使用它',
  253. 'models': models_info
  254. }), 400
  255. # 删除Excel文件
  256. filename = f"dataset_{dataset.Dataset_ID}.xlsx"
  257. file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
  258. if os.path.exists(file_path):
  259. try:
  260. os.remove(file_path)
  261. except OSError as e:
  262. logger.error(f'删除文件失败: {str(e)}')
  263. return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
  264. # 删除数据表
  265. table_name = f"dataset_{dataset.Dataset_ID}"
  266. session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
  267. # 删除数据集记录
  268. session.delete(dataset)
  269. session.commit()
  270. return jsonify({
  271. 'message': '数据集删除成功',
  272. 'deleted_files': [filename]
  273. }), 200
  274. except Exception as e:
  275. session.rollback()
  276. logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
  277. return jsonify({'error': str(e)}), 500
  278. finally:
  279. session.close()
  280. @bp.route('/tables', methods=['GET'])
  281. def list_tables():
  282. engine = db.engine # 使用 db 实例的 engine
  283. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  284. table_names = inspector.get_table_names() # 获取所有表名
  285. return jsonify(table_names) # 以 JSON 形式返回表名列表
  286. @bp.route('/models/<int:model_id>', methods=['GET'])
  287. def get_model(model_id):
  288. """
  289. 获取单个模型信息的API接口
  290. @param model_id: 模型ID
  291. @return: JSON响应
  292. """
  293. Session = sessionmaker(bind=db.engine)
  294. session = Session()
  295. try:
  296. model = session.query(Models).filter_by(ModelID=model_id).first()
  297. if model:
  298. return jsonify({
  299. 'ModelID': model.ModelID,
  300. 'Model_name': model.Model_name,
  301. 'Model_type': model.Model_type,
  302. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  303. 'Description': model.Description,
  304. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  305. 'MAE': float(model.MAE) if model.MAE else None,
  306. 'RMSE': float(model.RMSE) if model.RMSE else None,
  307. 'Data_type': model.Data_type
  308. })
  309. else:
  310. return jsonify({'message': '未找到模型'}), 404
  311. except Exception as e:
  312. logger.error(f'获取模型信息失败: {str(e)}')
  313. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  314. finally:
  315. session.close()
  316. @bp.route('/model-parameters', methods=['GET'])
  317. def get_all_model_parameters():
  318. """
  319. 获取所有模型参数的API接口
  320. @return: JSON响应
  321. """
  322. Session = sessionmaker(bind=db.engine)
  323. session = Session()
  324. try:
  325. parameters = session.query(ModelParameters).all()
  326. if parameters:
  327. result = [
  328. {
  329. 'ParamID': param.ParamID,
  330. 'ModelID': param.ModelID,
  331. 'ParamName': param.ParamName,
  332. 'ParamValue': param.ParamValue
  333. }
  334. for param in parameters
  335. ]
  336. return jsonify(result)
  337. else:
  338. return jsonify({'message': '未找到任何参数'}), 404
  339. except Exception as e:
  340. logger.error(f'获取所有模型参数失败: {str(e)}')
  341. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  342. finally:
  343. session.close()
  344. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  345. def get_model_parameters(model_id):
  346. try:
  347. model = Models.query.filter_by(ModelID=model_id).first()
  348. if model:
  349. # 获取该模型的所有参数
  350. parameters = [
  351. {
  352. 'ParamID': param.ParamID,
  353. 'ParamName': param.ParamName,
  354. 'ParamValue': param.ParamValue
  355. }
  356. for param in model.parameters
  357. ]
  358. # 返回模型参数信息
  359. return jsonify({
  360. 'ModelID': model.ModelID,
  361. 'ModelName': model.ModelName,
  362. 'ModelType': model.ModelType,
  363. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  364. 'Description': model.Description,
  365. 'Parameters': parameters
  366. })
  367. else:
  368. return jsonify({'message': 'Model not found'}), 404
  369. except Exception as e:
  370. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  371. @bp.route('/train-model-async', methods=['POST'])
  372. def train_model_async():
  373. """
  374. 异步训练模型的API接口
  375. """
  376. try:
  377. data = request.get_json()
  378. # 从请求中获取参数
  379. model_type = data.get('model_type')
  380. model_name = data.get('model_name')
  381. model_description = data.get('model_description')
  382. data_type = data.get('data_type')
  383. dataset_id = data.get('dataset_id', None)
  384. # 验证必要参数
  385. if not all([model_type, model_name, data_type]):
  386. return jsonify({
  387. 'error': 'Missing required parameters'
  388. }), 400
  389. # 如果提供了dataset_id,验证数据集是否存在
  390. if dataset_id:
  391. Session = sessionmaker(bind=db.engine)
  392. session = Session()
  393. try:
  394. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  395. if not dataset:
  396. return jsonify({
  397. 'error': f'Dataset with ID {dataset_id} not found'
  398. }), 404
  399. finally:
  400. session.close()
  401. # 启动异步任务
  402. task = train_model_task.delay(
  403. model_type=model_type,
  404. model_name=model_name,
  405. model_description=model_description,
  406. data_type=data_type,
  407. dataset_id=dataset_id
  408. )
  409. # 返回任务ID
  410. return jsonify({
  411. 'task_id': task.id,
  412. 'message': 'Model training started'
  413. }), 202
  414. except Exception as e:
  415. logging.error('Failed to start async training task:', exc_info=True)
  416. return jsonify({
  417. 'error': str(e)
  418. }), 500
  419. @bp.route('/task-status/<task_id>', methods=['GET'])
  420. def get_task_status(task_id):
  421. """
  422. 获取异步任务状态的API接口
  423. """
  424. try:
  425. task = train_model_task.AsyncResult(task_id)
  426. if task.state == 'PENDING':
  427. response = {
  428. 'state': task.state,
  429. 'status': 'Task is waiting for execution'
  430. }
  431. elif task.state == 'FAILURE':
  432. response = {
  433. 'state': task.state,
  434. 'status': 'Task failed',
  435. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  436. }
  437. elif task.state == 'SUCCESS':
  438. response = {
  439. 'state': task.state,
  440. 'status': 'Task completed successfully',
  441. 'result': task.get()
  442. }
  443. else:
  444. response = {
  445. 'state': task.state,
  446. 'status': 'Task is in progress'
  447. }
  448. return jsonify(response), 200
  449. except Exception as e:
  450. return jsonify({
  451. 'error': str(e)
  452. }), 500
  453. @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
  454. def delete_model_route(model_id):
  455. # 将URL参数转换为布尔值
  456. delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'
  457. # 调用原始函数
  458. return delete_model(model_id, delete_dataset=delete_dataset_param)
  459. def delete_model(model_id, delete_dataset=False):
  460. """
  461. 删除指定模型的API接口
  462. @param model_id: 要删除的模型ID
  463. @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
  464. @return: JSON响应
  465. """
  466. Session = sessionmaker(bind=db.engine)
  467. session = Session()
  468. try:
  469. # 查询模型信息
  470. model = session.query(Models).filter_by(ModelID=model_id).first()
  471. if not model:
  472. return jsonify({'error': '未找到指定模型'}), 404
  473. dataset_id = model.DatasetID
  474. # 1. 先删除模型记录
  475. session.delete(model)
  476. session.commit()
  477. # 2. 删除模型文件
  478. model_file = f"rf_model_{model_id}.pkl"
  479. model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
  480. if os.path.exists(model_path):
  481. try:
  482. os.remove(model_path)
  483. except OSError as e:
  484. # 如果删除文件失败,回滚数据库操作
  485. session.rollback()
  486. logger.error(f'删除模型文件失败: {str(e)}')
  487. return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
  488. # 3. 如果需要删除关联的数据集
  489. if delete_dataset and dataset_id:
  490. try:
  491. dataset_response = delete_dataset_endpoint(dataset_id)
  492. if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
  493. # 如果删除数据集失败,回滚之前的操作
  494. session.rollback()
  495. return jsonify({
  496. 'error': '删除关联数据集失败',
  497. 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
  498. }), 500
  499. except Exception as e:
  500. session.rollback()
  501. logger.error(f'删除关联数据集失败: {str(e)}')
  502. return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
  503. response_data = {
  504. 'message': '模型删除成功',
  505. 'deleted_files': [model_file]
  506. }
  507. if delete_dataset:
  508. response_data['dataset_info'] = {
  509. 'dataset_id': dataset_id,
  510. 'message': '关联数据集已删除'
  511. }
  512. return jsonify(response_data), 200
  513. except Exception as e:
  514. session.rollback()
  515. logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
  516. return jsonify({'error': str(e)}), 500
  517. finally:
  518. session.close()
  519. # 添加一个新的API端点来清空指定数据集
  520. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  521. def clear_dataset(data_type):
  522. """
  523. 清空指定类型的数据集并递增计数
  524. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  525. @return: JSON响应
  526. """
  527. # 创建 sessionmaker 实例
  528. Session = sessionmaker(bind=db.engine)
  529. session = Session()
  530. try:
  531. # 根据数据集类型选择表
  532. if data_type == 'reduce':
  533. table = CurrentReduce
  534. table_name = 'current_reduce'
  535. elif data_type == 'reflux':
  536. table = CurrentReflux
  537. table_name = 'current_reflux'
  538. else:
  539. return jsonify({'error': '无效的数据集类型'}), 400
  540. # 清空表内容
  541. session.query(table).delete()
  542. # 重置自增主键计数器
  543. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  544. session.commit()
  545. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  546. except Exception as e:
  547. session.rollback()
  548. return jsonify({'error': str(e)}), 500
  549. finally:
  550. session.close()
  551. @bp.route('/update-threshold', methods=['POST'])
  552. def update_threshold():
  553. """
  554. 更新训练阈值的API接口
  555. @body_param threshold: 新的阈值值(整数)
  556. @return: JSON响应
  557. """
  558. try:
  559. data = request.get_json()
  560. new_threshold = data.get('threshold')
  561. # 验证新阈值
  562. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  563. return jsonify({
  564. 'error': '无效的阈值值,必须为正数'
  565. }), 400
  566. # 更新当前应用的阈值配置
  567. current_app.config['THRESHOLD'] = int(new_threshold)
  568. return jsonify({
  569. 'success': True,
  570. 'message': f'阈值已更新为 {new_threshold}',
  571. 'new_threshold': new_threshold
  572. })
  573. except Exception as e:
  574. logging.error(f"更新阈值失败: {str(e)}")
  575. return jsonify({
  576. 'error': f'更新阈值失败: {str(e)}'
  577. }), 500
  578. @bp.route('/get-threshold', methods=['GET'])
  579. def get_threshold():
  580. """
  581. 获取当前训练阈值的API接口
  582. @return: JSON响应
  583. """
  584. try:
  585. current_threshold = current_app.config['THRESHOLD']
  586. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  587. return jsonify({
  588. 'current_threshold': current_threshold,
  589. 'default_threshold': default_threshold
  590. })
  591. except Exception as e:
  592. logging.error(f"获取阈值失败: {str(e)}")
  593. return jsonify({
  594. 'error': f'获取阈值失败: {str(e)}'
  595. }), 500
  596. @bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
  597. def set_current_dataset(data_type, dataset_id):
  598. """
  599. 将指定数据集设置为current数据集
  600. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  601. @param dataset_id: 要设置为current的数据集ID
  602. @return: JSON响应
  603. """
  604. Session = sessionmaker(bind=db.engine)
  605. session = Session()
  606. try:
  607. # 验证数据集存在且类型匹配
  608. dataset = session.query(Datasets)\
  609. .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
  610. .first()
  611. if not dataset:
  612. return jsonify({
  613. 'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
  614. }), 404
  615. # 根据数据类型选择表
  616. if data_type == 'reduce':
  617. table = CurrentReduce
  618. table_name = 'current_reduce'
  619. elif data_type == 'reflux':
  620. table = CurrentReflux
  621. table_name = 'current_reflux'
  622. else:
  623. return jsonify({'error': '无效的数据集类型'}), 400
  624. # 清空current表
  625. session.query(table).delete()
  626. # 重置自增主键计数器
  627. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  628. # 从指定数据集复制数据到current表
  629. dataset_table_name = f"dataset_{dataset_id}"
  630. copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
  631. session.execute(copy_sql)
  632. session.commit()
  633. return jsonify({
  634. 'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
  635. 'dataset_id': dataset_id,
  636. 'dataset_name': dataset.Dataset_name,
  637. 'row_count': dataset.Row_count
  638. }), 200
  639. except Exception as e:
  640. session.rollback()
  641. logger.error(f'设置current数据集失败: {str(e)}')
  642. return jsonify({'error': str(e)}), 500
  643. finally:
  644. session.close()
  645. @bp.route('/get-model-history/<string:data_type>', methods=['GET'])
  646. def get_model_history(data_type):
  647. """
  648. 获取模型训练历史数据的API接口
  649. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  650. @return: JSON响应,包含时间序列的模型性能数据
  651. """
  652. Session = sessionmaker(bind=db.engine)
  653. session = Session()
  654. try:
  655. # 查询所有自动生成的数据集,按时间排序
  656. datasets = session.query(Datasets).filter(
  657. Datasets.Dataset_type == data_type,
  658. Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
  659. ).order_by(Datasets.Uploaded_at).all()
  660. history_data = []
  661. for dataset in datasets:
  662. # 查找对应的自动训练模型
  663. model = session.query(Models).filter(
  664. Models.DatasetID == dataset.Dataset_ID,
  665. Models.Model_name.like(f'auto_trained_{data_type}_%')
  666. ).first()
  667. if model and model.Performance_score is not None:
  668. # 直接使用数据库中的时间,不进行格式化(保持与created_at相同的时区)
  669. created_at = model.Created_at.isoformat() if model.Created_at else None
  670. history_data.append({
  671. 'dataset_id': dataset.Dataset_ID,
  672. 'row_count': dataset.Row_count,
  673. 'model_id': model.ModelID,
  674. 'model_name': model.Model_name,
  675. 'performance_score': float(model.Performance_score),
  676. 'timestamp': created_at
  677. })
  678. # 按时间戳排序
  679. history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')
  680. # 构建返回数据,分离各个指标序列便于前端绘图
  681. response_data = {
  682. 'data_type': data_type,
  683. 'timestamps': [item['timestamp'] for item in history_data],
  684. 'row_counts': [item['row_count'] for item in history_data],
  685. 'performance_scores': [item['performance_score'] for item in history_data],
  686. 'model_details': history_data # 保留完整数据供前端使用
  687. }
  688. return jsonify(response_data), 200
  689. except Exception as e:
  690. logger.error(f'获取模型历史数据失败: {str(e)}', exc_info=True)
  691. return jsonify({'error': str(e)}), 500
  692. finally:
  693. session.close()
  694. @bp.route('/batch-delete-datasets', methods=['POST'])
  695. def batch_delete_datasets():
  696. """
  697. 批量删除数据集的API接口
  698. @body_param dataset_ids: 要删除的数据集ID列表
  699. @return: JSON响应
  700. """
  701. try:
  702. data = request.get_json()
  703. dataset_ids = data.get('dataset_ids', [])
  704. if not dataset_ids:
  705. return jsonify({'error': '未提供数据集ID列表'}), 400
  706. results = {
  707. 'success': [],
  708. 'failed': [],
  709. 'protected': [] # 被模型使用的数据集
  710. }
  711. for dataset_id in dataset_ids:
  712. try:
  713. # 调用单个删除接口
  714. response = delete_dataset_endpoint(dataset_id)
  715. # 解析响应
  716. if response[1] == 200:
  717. results['success'].append(dataset_id)
  718. elif response[1] == 400 and 'models' in response[0].json:
  719. # 数据集被模型保护
  720. results['protected'].append({
  721. 'id': dataset_id,
  722. 'models': response[0].json['models']
  723. })
  724. else:
  725. results['failed'].append({
  726. 'id': dataset_id,
  727. 'reason': response[0].json.get('error', '删除失败')
  728. })
  729. except Exception as e:
  730. logger.error(f'删除数据集 {dataset_id} 失败: {str(e)}')
  731. results['failed'].append({
  732. 'id': dataset_id,
  733. 'reason': str(e)
  734. })
  735. # 构建响应消息
  736. message = f"成功删除 {len(results['success'])} 个数据集"
  737. if results['protected']:
  738. message += f", {len(results['protected'])} 个数据集被保护"
  739. if results['failed']:
  740. message += f", {len(results['failed'])} 个数据集删除失败"
  741. return jsonify({
  742. 'message': message,
  743. 'results': results
  744. }), 200
  745. except Exception as e:
  746. logger.error(f'批量删除数据集失败: {str(e)}')
  747. return jsonify({'error': str(e)}), 500
  748. @bp.route('/batch-delete-models', methods=['POST'])
  749. def batch_delete_models():
  750. """
  751. 批量删除模型的API接口
  752. @body_param model_ids: 要删除的模型ID列表
  753. @query_param delete_datasets: 布尔值,是否同时删除关联的数据集,默认为False
  754. @return: JSON响应
  755. """
  756. try:
  757. data = request.get_json()
  758. model_ids = data.get('model_ids', [])
  759. delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'
  760. if not model_ids:
  761. return jsonify({'error': '未提供模型ID列表'}), 400
  762. results = {
  763. 'success': [],
  764. 'failed': [],
  765. 'datasets_deleted': [] # 如果delete_datasets为true,记录被删除的数据集
  766. }
  767. for model_id in model_ids:
  768. try:
  769. # 调用单个删除接口
  770. response = delete_model(model_id, delete_dataset=delete_datasets)
  771. # 解析响应
  772. if response[1] == 200:
  773. results['success'].append(model_id)
  774. # 如果删除了关联数据集,记录数据集ID
  775. if 'dataset_info' in response[0].json:
  776. results['datasets_deleted'].append(
  777. response[0].json['dataset_info']['dataset_id']
  778. )
  779. else:
  780. results['failed'].append({
  781. 'id': model_id,
  782. 'reason': response[0].json.get('error', '删除失败')
  783. })
  784. except Exception as e:
  785. logger.error(f'删除模型 {model_id} 失败: {str(e)}')
  786. results['failed'].append({
  787. 'id': model_id,
  788. 'reason': str(e)
  789. })
  790. # 构建响应消息
  791. message = f"成功删除 {len(results['success'])} 个模型"
  792. if results['datasets_deleted']:
  793. message += f", {len(results['datasets_deleted'])} 个关联数据集"
  794. if results['failed']:
  795. message += f", {len(results['failed'])} 个模型删除失败"
  796. return jsonify({
  797. 'message': message,
  798. 'results': results
  799. }), 200
  800. except Exception as e:
  801. logger.error(f'批量删除模型失败: {str(e)}')
  802. return jsonify({'error': str(e)}), 500
  803. @bp.route('/kriging_interpolation', methods=['POST'])
  804. def kriging_interpolation():
  805. try:
  806. data = request.get_json()
  807. required = ['file_name', 'emission_column', 'points']
  808. if not all(k in data for k in required):
  809. return jsonify({"error": "Missing parameters"}), 400
  810. # 添加坐标顺序验证
  811. points = data['points']
  812. if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points):
  813. return jsonify({"error": "Invalid points format"}), 400
  814. result = create_kriging(
  815. data['file_name'],
  816. data['emission_column'],
  817. data['points']
  818. )
  819. return jsonify(result)
  820. except Exception as e:
  821. return jsonify({"error": str(e)}), 500
  822. # 显示切换模型
  823. @bp.route('/models', methods=['GET'])
  824. def get_models():
  825. session = None
  826. try:
  827. # 创建 session
  828. Session = sessionmaker(bind=db.engine)
  829. session = Session()
  830. # 查询所有模型
  831. models = session.query(Models).all()
  832. logger.debug(f"Models found: {models}") # 打印查询的模型数据
  833. if not models:
  834. return jsonify({'message': 'No models found'}), 404
  835. # 将模型数据转换为字典列表
  836. models_list = [
  837. {
  838. 'ModelID': model.ModelID,
  839. 'ModelName': model.Model_name,
  840. 'ModelType': model.Model_type,
  841. 'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  842. 'Description': model.Description,
  843. 'DatasetID': model.DatasetID,
  844. 'ModelFilePath': model.ModelFilePath,
  845. 'DataType': model.Data_type,
  846. 'PerformanceScore': model.Performance_score
  847. }
  848. for model in models
  849. ]
  850. return jsonify(models_list), 200
  851. except Exception as e:
  852. return jsonify({'error': str(e)}), 400
  853. finally:
  854. if session:
  855. session.close()
  856. # 定义提供数据库列表,用于展示表格的 API 接口
  857. @bp.route('/table', methods=['POST'])
  858. def get_table():
  859. data = request.get_json()
  860. table_name = data.get('table')
  861. if not table_name:
  862. return jsonify({'error': '需要表名'}), 400
  863. try:
  864. # 创建 sessionmaker 实例
  865. Session = sessionmaker(bind=db.engine)
  866. session = Session()
  867. # 动态获取表的元数据
  868. metadata = MetaData()
  869. table = Table(table_name, metadata, autoload_with=db.engine)
  870. # 从数据库中查询所有记录
  871. query = select(table)
  872. result = session.execute(query).fetchall()
  873. # 将结果转换为列表字典形式
  874. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  875. # 获取列名
  876. headers = [column.name for column in table.columns]
  877. return jsonify(rows=rows, headers=headers), 200
  878. except Exception as e:
  879. return jsonify({'error': str(e)}), 400
  880. finally:
  881. # 关闭 session
  882. session.close()