# routes.py
import sqlite3
from io import BytesIO
import pickle
from flask import Blueprint, request, jsonify, current_app, send_file
from werkzeug.security import check_password_hash, generate_password_hash
from werkzeug.utils import secure_filename
from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
import pandas as pd
from . import db  # Import the db instance from the app package
from sqlalchemy.engine.reflection import Inspector
from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
import os
from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
    clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
    predict_to_Q, Q_to_t_ha, create_kriging
from sqlalchemy.orm import sessionmaker
import logging
from sqlalchemy import text, select, MetaData, Table, func
from .tasks import train_model_task
from datetime import datetime
from sklearn.metrics import r2_score

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Create a Blueprint to keep the routes separate from app setup
bp = Blueprint('routes', __name__)


# Password hashing
def hash_password(password):
    return generate_password_hash(password)


def get_db():
    """Get a raw SQLite connection."""
    return sqlite3.connect(current_app.config['DATABASE'])


# Helper that checks the dataset size and triggers training when a new threshold is crossed
def check_and_trigger_training(session, dataset_type, dataset_df):
    """
    Check whether the current dataset size has crossed a new threshold point and, if so, trigger training.

    Args:
        session: database session
        dataset_type: dataset type ('reduce' or 'reflux')
        dataset_df: the newly added DataFrame

    Returns:
        tuple: (whether training was triggered, task ID)
    """
    try:
        # Pick the table for this dataset type
        table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
        # Current number of records
        current_count = session.query(func.count()).select_from(table).scalar()
        # Number of newly added records (length of the uploaded DataFrame, passed in by the caller)
        new_records = len(dataset_df)
        # Record count before this upload
        previous_count = current_count - new_records
        # Pick the threshold for this dataset type
        if dataset_type == 'reduce':
            THRESHOLD = current_app.config['THRESHOLD_REDUCE']
        else:  # reflux
            THRESHOLD = current_app.config['THRESHOLD_REFLUX']
        # Last threshold point reached before the upload
        last_threshold = previous_count // THRESHOLD * THRESHOLD
        # Threshold point for the current count
        current_threshold = current_count // THRESHOLD * THRESHOLD
        # Trigger training only if a new threshold point was crossed
        if current_threshold > last_threshold and current_count >= THRESHOLD:
            # Kick off the asynchronous training task
            task = train_model_task.delay(
                model_type=current_app.config['DEFAULT_MODEL_TYPE'],
                model_name=f'auto_trained_{dataset_type}_{current_threshold}',
                model_description=f'Auto trained model at {current_threshold} records threshold',
                data_type=dataset_type
            )
            return True, task.id
        return False, None
    except Exception as e:
        logging.error(f"Failed to check and trigger training: {str(e)}")
        return False, None
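
# Illustrative walk-through of the threshold logic above (THRESHOLD_REDUCE = 100 is an
# assumed value, not one from this project): with 95 existing rows, uploading 10 rows
# gives previous_count = 95 and current_count = 105, so last_threshold = 0 and
# current_threshold = 100; a new threshold point was crossed and a task named
# 'auto_trained_reduce_100' is queued.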


@bp.route('/upload-dataset', methods=['POST'])
def upload_dataset():
    # Create a session
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '' or not allowed_file(file.filename):
            return jsonify({'error': 'No selected file or invalid file type'}), 400

        dataset_name = request.form.get('dataset_name')
        dataset_description = request.form.get('dataset_description', 'No description provided')
        dataset_type = request.form.get('dataset_type')
        if not dataset_type:
            return jsonify({'error': 'Dataset type is required'}), 400

        new_dataset = Datasets(
            Dataset_name=dataset_name,
            Dataset_description=dataset_description,
            Row_count=0,
            Status='Datasets_upgraded',
            Dataset_type=dataset_type,
            Uploaded_at=datetime.now()
        )
        session.add(new_dataset)
        session.commit()

        unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
        upload_folder = current_app.config['UPLOAD_FOLDER']
        file_path = os.path.join(upload_folder, unique_filename)
        file.save(file_path)

        dataset_df = pd.read_excel(file_path)
        new_dataset.Row_count = len(dataset_df)
        new_dataset.Status = 'excel_file_saved success'
        session.commit()

        # Normalize column names
        dataset_df = clean_column_names(dataset_df)
        dataset_df = rename_columns_for_model(dataset_df, dataset_type)

        column_types = infer_column_types(dataset_df)
        dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
        insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)

        # Drop duplicates within the uploaded dataset itself
        original_count = len(dataset_df)
        dataset_df = dataset_df.drop_duplicates()
        duplicates_in_file = original_count - len(dataset_df)

        # Check for duplicates against existing data
        duplicates_with_existing = 0
        if dataset_type in ['reduce', 'reflux']:
            # Resolve the table name
            table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
            # Load the existing data from that table
            existing_data = pd.read_sql_table(table_name, session.bind)
            if 'id' in existing_data.columns:
                existing_data = existing_data.drop('id', axis=1)
            # Columns to compare on
            compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
            original_df_len = len(dataset_df)
            # Use concat + duplicated to find rows already present
            all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
            duplicates_mask = all_data.duplicated(keep='first')
            duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
            # Keep only the non-duplicate rows
            dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
            logger.info(f"Original rows: {original_df_len}, duplicates of existing data: {duplicates_with_existing}, kept: {len(dataset_df)}")

        # Check for overlap with the test set
        test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)
        # Remove any rows that overlap with the test set
        if test_overlap_count > 0:
            # Boolean mask marking rows that are NOT in the overlap indices
            mask = ~dataset_df.index.isin(test_overlap_indices)
            dataset_df = dataset_df[mask]
            logger.warning(f"Removed {test_overlap_count} rows overlapping with the test set")

        # Insert into the matching existing table based on dataset_type
        if dataset_type == 'reduce':
            insert_data_into_existing_table(session, dataset_df, CurrentReduce)
        elif dataset_type == 'reflux':
            insert_data_into_existing_table(session, dataset_df, CurrentReflux)
        session.commit()

        # After the data is inserted, check whether training should be triggered
        training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)

        response_data = {
            'message': f'Dataset {dataset_name} uploaded successfully!',
            'dataset_id': new_dataset.Dataset_ID,
            'filename': unique_filename,
            'training_triggered': training_triggered,
            'data_stats': {
                'original_count': original_count,
                'duplicates_in_file': duplicates_in_file,
                'duplicates_with_existing': duplicates_with_existing,
                'test_overlap_count': test_overlap_count,
                'final_count': len(dataset_df)
            }
        }
        if training_triggered:
            response_data['task_id'] = task_id
            response_data['message'] += ' Automatic training has been triggered.'
        # Append dedup details to the message
        if duplicates_with_existing > 0:
            response_data['message'] += f' Removed {duplicates_with_existing} rows duplicating existing data.'
        # Append test-set overlap details to the message
        if test_overlap_count > 0:
            response_data['message'] += f' Removed {test_overlap_count} rows overlapping with the test set.'
        return jsonify(response_data), 201
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the dataset upload:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        # Make sure the session is always closed
        if session:
            session.close()
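
# Illustrative request (the path assumes the blueprint is registered without a URL prefix;
# host, port, and file name are placeholders, not values from this project):
#   curl -X POST http://localhost:5000/upload-dataset \
#        -F "file=@my_samples.xlsx" \
#        -F "dataset_name=spring_survey" \
#        -F "dataset_type=reduce"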


@bp.route('/train-and-save-model', methods=['POST'])
def train_and_save_model_endpoint():
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()

    data = request.get_json()
    # Parse the parameters from the request
    model_type = data.get('model_type')
    model_name = data.get('model_name')
    model_description = data.get('model_description')
    data_type = data.get('data_type')
    dataset_id = data.get('dataset_id', None)  # Defaults to None when not provided

    try:
        # Train and save the model
        result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
        model_id = result[1] if result else None

        # Compute the model's scores
        if model_id:
            model_info = session.query(Models).filter(Models.ModelID == model_id).first()
            if model_info:
                # Compute several scoring metrics
                score_metrics = calculate_model_score(model_info)
                # Update the model's score
                model_info.Performance_score = score_metrics['r2']
                # Persist the additional metrics
                model_info.MAE = score_metrics['mae']
                model_info.RMSE = score_metrics['rmse']
                # CV_score is already set inside train_and_save_model; not updated here
                session.commit()
                result = {
                    'model_id': model_id,
                    'model_score': score_metrics['r2'],
                    'mae': score_metrics['mae'],
                    'rmse': score_metrics['rmse'],
                    'cv_score': result[3]
                }

        # Return a success response
        return jsonify({
            'message': 'Model trained and saved successfully',
            'result': result
        }), 200
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the model training:', exc_info=True)
        return jsonify({
            'error': 'Failed to train and save model',
            'message': str(e)
        }), 500
    finally:
        session.close()
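
# Illustrative request body (all values are placeholders; 'model_type' must be one that
# the train_and_save_model helper in .model understands):
#   POST /train-and-save-model
#   {"model_type": "RandomForest", "model_name": "rf_reduce_v1",
#    "model_description": "baseline", "data_type": "reduce", "dataset_id": 12}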


@bp.route('/predict', methods=['POST'])
def predict_route():
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        data = request.get_json()
        model_id = data.get('model_id')  # Model identifier
        parameters = data.get('parameters', {})  # All input variables

        # Look up the model's Data_type from its ID
        model_info = session.query(Models).filter(Models.ModelID == model_id).first()
        if not model_info:
            return jsonify({'error': 'Model not found'}), 404
        data_type = model_info.Data_type

        input_data = pd.DataFrame([parameters])  # Convert the parameters to a DataFrame

        # For 'reduce' models, target_pH is not a model input
        if data_type == 'reduce':
            # Read the init_pH / target_pH parameters
            init_ph = float(parameters.get('init_pH', 0.0))  # Default 0.0 guards against None
            target_ph = float(parameters.get('target_pH', 0.0))  # Default 0.0 guards against None
            # Drop the 'target_pH' column; errors='ignore' tolerates a missing column
            input_data = input_data.drop('target_pH', axis=1, errors='ignore')

        input_data_rename = rename_columns_for_model_predict(input_data, data_type)  # Rename columns to match the model's fields
        predictions = predict(session, input_data_rename, model_id)  # Run the prediction

        if data_type == 'reduce':
            predictions = predictions[0]
            # Convert the prediction to Q, then Q to t/ha
            Q = predict_to_Q(predictions, init_ph, target_ph)
            predictions = Q_to_t_ha(Q)

        logger.debug(f'Prediction result: {predictions}')
        return jsonify({'result': predictions}), 200
    except Exception as e:
        logging.error('Failed to predict:', exc_info=True)
        return jsonify({'error': str(e)}), 400
    finally:
        session.close()
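
# Illustrative request (parameter names other than init_pH/target_pH depend on the chosen
# model's feature set; the ones here are placeholders):
#   POST /predict
#   {"model_id": 3, "parameters": {"init_pH": 5.2, "target_pH": 6.5, "OM": 20.1}}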


# Compute scoring metrics for a given model; requires a model_id
@bp.route('/score-model/<int:model_id>', methods=['POST'])
def score_model(model_id):
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        model_info = session.query(Models).filter(Models.ModelID == model_id).first()
        if not model_info:
            return jsonify({'error': 'Model not found'}), 404

        # Compute the model's scores
        score_metrics = calculate_model_score(model_info)
        # Update the stored scores (cross-validation score excluded)
        model_info.Performance_score = score_metrics['r2']
        model_info.MAE = score_metrics['mae']
        model_info.RMSE = score_metrics['rmse']
        session.commit()
        return jsonify({
            'message': 'Model scored successfully',
            'r2_score': score_metrics['r2'],
            'mae': score_metrics['mae'],
            'rmse': score_metrics['rmse'],
        }), 200
    except Exception as e:
        logging.error('Failed to process model scoring:', exc_info=True)
        return jsonify({'error': str(e)}), 400
    finally:
        session.close()


@bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
def delete_dataset_endpoint(dataset_id):
    """
    Delete a dataset.

    @param dataset_id: ID of the dataset to delete
    @return: JSON response
    """
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the dataset
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if not dataset:
            return jsonify({'error': 'Dataset not found'}), 404

        # Refuse deletion if any model still uses this dataset
        models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
        if models_using_dataset:
            models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
            return jsonify({
                'error': 'Cannot delete the dataset because the following models are using it',
                'models': models_info
            }), 400

        # Delete the Excel file
        filename = f"dataset_{dataset.Dataset_ID}.xlsx"
        file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f'Failed to delete file: {str(e)}')
                return jsonify({'error': f'Failed to delete file: {str(e)}'}), 500

        # Drop the dataset's dynamic table
        table_name = f"dataset_{dataset.Dataset_ID}"
        session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))

        # Delete the dataset record
        session.delete(dataset)
        session.commit()
        return jsonify({
            'message': 'Dataset deleted successfully',
            'deleted_files': [filename]
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'Failed to delete dataset {dataset_id}:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()


@bp.route('/tables', methods=['GET'])
def list_tables():
    engine = db.engine  # Use the db instance's engine
    inspector = Inspector.from_engine(engine)  # Create an Inspector
    table_names = inspector.get_table_names()  # All table names
    return jsonify(table_names)  # Return the list as JSON


@bp.route('/models/<int:model_id>', methods=['GET'])
def get_model(model_id):
    """
    Fetch a single model's information.

    @param model_id: model ID
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if model:
            return jsonify({
                'ModelID': model.ModelID,
                'Model_name': model.Model_name,
                'Model_type': model.Model_type,
                'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'Description': model.Description,
                'Performance_score': float(model.Performance_score) if model.Performance_score is not None else None,
                'MAE': float(model.MAE) if model.MAE is not None else None,
                'CV_score': float(model.CV_score) if model.CV_score is not None else None,
                'RMSE': float(model.RMSE) if model.RMSE is not None else None,
                'Data_type': model.Data_type
            })
        else:
            return jsonify({'message': 'Model not found'}), 404
    except Exception as e:
        logger.error(f'Failed to fetch model info: {str(e)}')
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
    finally:
        session.close()


@bp.route('/models', methods=['GET'])
def get_all_models():
    """
    Fetch information for all models.

    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        models = session.query(Models).all()
        if models:
            result = [
                {
                    'ModelID': model.ModelID,
                    'Model_name': model.Model_name,
                    'Model_type': model.Model_type,
                    'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'Description': model.Description,
                    'Performance_score': float(model.Performance_score) if model.Performance_score is not None else None,
                    'MAE': float(model.MAE) if model.MAE is not None else None,
                    'CV_score': float(model.CV_score) if model.CV_score is not None else None,
                    'RMSE': float(model.RMSE) if model.RMSE is not None else None,
                    'Data_type': model.Data_type
                }
                for model in models
            ]
            return jsonify(result)
        else:
            return jsonify({'message': 'No models found'}), 404
    except Exception as e:
        logger.error(f'Failed to fetch all model info: {str(e)}')
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
    finally:
        session.close()


@bp.route('/model-parameters', methods=['GET'])
def get_all_model_parameters():
    """
    Fetch all model parameters.

    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        parameters = session.query(ModelParameters).all()
        if parameters:
            result = [
                {
                    'ParamID': param.ParamID,
                    'ModelID': param.ModelID,
                    'ParamName': param.ParamName,
                    'ParamValue': param.ParamValue
                }
                for param in parameters
            ]
            return jsonify(result)
        else:
            return jsonify({'message': 'No parameters found'}), 404
    except Exception as e:
        logger.error(f'Failed to fetch all model parameters: {str(e)}')
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
    finally:
        session.close()


@bp.route('/models/<int:model_id>/parameters', methods=['GET'])
def get_model_parameters(model_id):
    try:
        model = Models.query.filter_by(ModelID=model_id).first()
        if model:
            # Collect all parameters belonging to this model
            parameters = [
                {
                    'ParamID': param.ParamID,
                    'ParamName': param.ParamName,
                    'ParamValue': param.ParamValue
                }
                for param in model.parameters
            ]
            # Return the model together with its parameters
            # (attribute access follows the Models column names used elsewhere in this module)
            return jsonify({
                'ModelID': model.ModelID,
                'ModelName': model.Model_name,
                'ModelType': model.Model_type,
                'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'Description': model.Description,
                'Parameters': parameters
            })
        else:
            return jsonify({'message': 'Model not found'}), 404
    except Exception as e:
        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500


@bp.route('/train-model-async', methods=['POST'])
def train_model_async():
    """
    Train a model asynchronously.
    """
    try:
        data = request.get_json()
        # Read the parameters from the request
        model_type = data.get('model_type')
        model_name = data.get('model_name')
        model_description = data.get('model_description')
        data_type = data.get('data_type')
        dataset_id = data.get('dataset_id', None)

        # Validate the required parameters
        if not all([model_type, model_name, data_type]):
            return jsonify({
                'error': 'Missing required parameters'
            }), 400

        # If a dataset_id was provided, verify that the dataset exists
        if dataset_id:
            Session = sessionmaker(bind=db.engine)
            session = Session()
            try:
                dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
                if not dataset:
                    return jsonify({
                        'error': f'Dataset with ID {dataset_id} not found'
                    }), 404
            finally:
                session.close()

        # Start the asynchronous task
        task = train_model_task.delay(
            model_type=model_type,
            model_name=model_name,
            model_description=model_description,
            data_type=data_type,
            dataset_id=dataset_id
        )

        # Return the task ID
        return jsonify({
            'task_id': task.id,
            'message': 'Model training started'
        }), 202
    except Exception as e:
        logging.error('Failed to start async training task:', exc_info=True)
        return jsonify({
            'error': str(e)
        }), 500
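
# Illustrative usage (values are placeholders): POST /train-model-async with the same JSON
# body as /train-and-save-model returns 202 and {"task_id": "..."}; poll
# GET /task-status/<task_id> until the state reaches SUCCESS or FAILURE.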


@bp.route('/task-status/<task_id>', methods=['GET'])
def get_task_status(task_id):
    """
    Query the status of an asynchronous task.
    """
    try:
        task = train_model_task.AsyncResult(task_id)
        if task.state == 'PENDING':
            response = {
                'state': task.state,
                'status': 'Task is waiting for execution'
            }
        elif task.state == 'FAILURE':
            response = {
                'state': task.state,
                'status': 'Task failed',
                'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
            }
        elif task.state == 'SUCCESS':
            response = {
                'state': task.state,
                'status': 'Task completed successfully',
                'result': task.get()
            }
        else:
            response = {
                'state': task.state,
                'status': 'Task is in progress'
            }
        return jsonify(response), 200
    except Exception as e:
        return jsonify({
            'error': str(e)
        }), 500


@bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
def delete_model_route(model_id):
    # Convert the URL parameter to a boolean
    delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'
    # Delegate to the underlying function
    return delete_model(model_id, delete_dataset=delete_dataset_param)


def delete_model(model_id, delete_dataset=False):
    """
    Delete the specified model.

    @param model_id: ID of the model to delete
    @query_param delete_dataset: boolean; also delete the associated dataset (default False)
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the model
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if not model:
            return jsonify({'error': 'Model not found'}), 404

        # Capture these before the record is deleted so they remain available afterwards
        dataset_id = model.DatasetID
        model_path = model.ModelFilePath

        # 1. Delete the model record (committed only after the file is handled,
        #    so a failed file deletion still rolls the record deletion back)
        session.delete(model)

        # 2. Delete the model file
        try:
            if os.path.exists(model_path):
                os.remove(model_path)
            else:
                logger.warning(f'Model file not found: {model_path}')
        except OSError as e:
            # File deletion failed: roll back the record deletion
            session.rollback()
            logger.error(f'Failed to delete the model file: {str(e)}')
            return jsonify({'error': f'Failed to delete the model file: {str(e)}'}), 500
        session.commit()

        # 3. Optionally delete the associated dataset
        if delete_dataset and dataset_id:
            try:
                dataset_response = delete_dataset_endpoint(dataset_id)
                if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
                    # The model deletion is already committed; only the dataset failure is reported
                    return jsonify({
                        'error': 'Failed to delete the associated dataset',
                        'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
                    }), 500
            except Exception as e:
                logger.error(f'Failed to delete the associated dataset: {str(e)}')
                return jsonify({'error': f'Failed to delete the associated dataset: {str(e)}'}), 500

        response_data = {
            'message': 'Model deleted successfully',
            'deleted_files': [model_path]
        }
        if delete_dataset:
            response_data['dataset_info'] = {
                'dataset_id': dataset_id,
                'message': 'Associated dataset deleted'
            }
        return jsonify(response_data), 200
    except Exception as e:
        session.rollback()
        logger.error(f'Failed to delete model {model_id}:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
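
# Illustrative usage (host/port are placeholders):
#   curl -X DELETE "http://localhost:5000/delete-model/7?delete_dataset=true"
# deletes model 7 and its model file, and, because delete_dataset=true, also the dataset
# it was trained on (unless another model still references that dataset).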


# Endpoint for clearing a current dataset
@bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
def clear_dataset(data_type):
    """
    Clear the dataset of the given type and reset its auto-increment counter.

    @param data_type: dataset type ('reduce' or 'reflux')
    @return: JSON response
    """
    # Create a sessionmaker instance
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Pick the table for this dataset type
        if data_type == 'reduce':
            table = CurrentReduce
            table_name = 'current_reduce'
        elif data_type == 'reflux':
            table = CurrentReflux
            table_name = 'current_reflux'
        else:
            return jsonify({'error': 'Invalid dataset type'}), 400

        # Delete all rows
        session.query(table).delete()
        # Reset the SQLite auto-increment counter
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
        session.commit()
        return jsonify({'message': f'The {data_type} dataset has been cleared and its counter reset'}), 200
    except Exception as e:
        session.rollback()
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()


@bp.route('/update-threshold', methods=['POST'])
def update_threshold():
    """
    Update the training threshold.

    @body_param threshold: the new threshold (a positive number)
    @body_param data_type: data type ('reduce' or 'reflux')
    @return: JSON response
    """
    try:
        data = request.get_json()
        new_threshold = data.get('threshold')
        data_type = data.get('data_type')

        # Validate the new threshold
        if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
            return jsonify({
                'error': 'Invalid threshold: must be a positive number'
            }), 400

        # Validate the data type
        if data_type not in ['reduce', 'reflux']:
            return jsonify({
                'error': 'Invalid data type: must be "reduce" or "reflux"'
            }), 400

        # Update the running app's threshold configuration
        if data_type == 'reduce':
            current_app.config['THRESHOLD_REDUCE'] = int(new_threshold)
        else:  # reflux
            current_app.config['THRESHOLD_REFLUX'] = int(new_threshold)

        return jsonify({
            'success': True,
            'message': f'The {data_type} threshold has been updated to {new_threshold}',
            'data_type': data_type,
            'new_threshold': new_threshold
        })
    except Exception as e:
        logging.error(f"Failed to update the threshold: {str(e)}")
        return jsonify({
            'error': f'Failed to update the threshold: {str(e)}'
        }), 500
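
# Illustrative request/response pair (values are placeholders). Note that the update only
# mutates current_app.config for the running process; it is not persisted across restarts.
#   POST /update-threshold   {"threshold": 200, "data_type": "reflux"}
#   -> {"success": true, "message": "The reflux threshold has been updated to 200", ...}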


@bp.route('/get-threshold', methods=['GET'])
def get_threshold():
    """
    Fetch the current training threshold(s).

    @query_param data_type: optional data type ('reduce' or 'reflux')
    @return: JSON response
    """
    try:
        data_type = request.args.get('data_type')
        if data_type and data_type not in ['reduce', 'reflux']:
            return jsonify({
                'error': 'Invalid data type: must be "reduce" or "reflux"'
            }), 400

        response = {}
        if data_type == 'reduce' or data_type is None:
            response['reduce'] = {
                'current_threshold': current_app.config['THRESHOLD_REDUCE'],
                'default_threshold': current_app.config['DEFAULT_THRESHOLD_REDUCE']
            }
        if data_type == 'reflux' or data_type is None:
            response['reflux'] = {
                'current_threshold': current_app.config['THRESHOLD_REFLUX'],
                'default_threshold': current_app.config['DEFAULT_THRESHOLD_REFLUX']
            }
        return jsonify(response)
    except Exception as e:
        logging.error(f"Failed to fetch the threshold: {str(e)}")
        return jsonify({
            'error': f'Failed to fetch the threshold: {str(e)}'
        }), 500


@bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
def set_current_dataset(data_type, dataset_id):
    """
    Make the specified dataset the current dataset.

    @param data_type: dataset type ('reduce' or 'reflux')
    @param dataset_id: ID of the dataset to promote to current
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Verify the dataset exists and its type matches
        dataset = session.query(Datasets)\
            .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
            .first()
        if not dataset:
            return jsonify({
                'error': f'No dataset with ID {dataset_id} and type {data_type} was found'
            }), 404

        # Pick the table for this data type
        if data_type == 'reduce':
            table = CurrentReduce
            table_name = 'current_reduce'
        elif data_type == 'reflux':
            table = CurrentReflux
            table_name = 'current_reflux'
        else:
            return jsonify({'error': 'Invalid dataset type'}), 400

        # Clear the current table
        session.query(table).delete()
        # Reset the SQLite auto-increment counter
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))

        # Copy the chosen dataset's rows into the current table
        dataset_table_name = f"dataset_{dataset_id}"
        copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
        session.execute(copy_sql)
        session.commit()

        return jsonify({
            'message': f'The current {data_type} dataset is now dataset ID {dataset_id}',
            'dataset_id': dataset_id,
            'dataset_name': dataset.Dataset_name,
            'row_count': dataset.Row_count
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'Failed to set the current dataset: {str(e)}')
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
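
# Illustrative usage (host/port are placeholders):
#   curl -X POST http://localhost:5000/set-current-dataset/reduce/12
# empties current_reduce, resets its rowid counter, and copies every row of the
# dataset_12 table into it with a single INSERT ... SELECT.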


@bp.route('/get-model-history/<string:data_type>', methods=['GET'])
def get_model_history(data_type):
    """
    Fetch the model-training history for a data type.

    @param data_type: dataset type ('reduce' or 'reflux')
    @return: JSON response containing time-series model performance data
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # All auto-generated datasets of this type, ordered by upload time
        datasets = session.query(Datasets).filter(
            Datasets.Dataset_type == data_type,
            Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
        ).order_by(Datasets.Uploaded_at).all()

        history_data = []
        for dataset in datasets:
            # Find the matching auto-trained model
            model = session.query(Models).filter(
                Models.DatasetID == dataset.Dataset_ID,
                Models.Model_name.like(f'auto_trained_{data_type}_%')
            ).first()
            if model and model.Performance_score is not None:
                # Use the database timestamp as-is (keeps the same timezone as created_at)
                created_at = model.Created_at.isoformat() if model.Created_at else None
                history_data.append({
                    'dataset_id': dataset.Dataset_ID,
                    'row_count': dataset.Row_count,
                    'model_id': model.ModelID,
                    'model_name': model.Model_name,
                    'performance_score': float(model.Performance_score),
                    'timestamp': created_at
                })

        # Sort by timestamp
        history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')

        # Split the metrics into parallel series so the frontend can plot them directly
        response_data = {
            'data_type': data_type,
            'timestamps': [item['timestamp'] for item in history_data],
            'row_counts': [item['row_count'] for item in history_data],
            'performance_scores': [item['performance_score'] for item in history_data],
            'model_details': history_data  # Full records for frontend use
        }
        return jsonify(response_data), 200
    except Exception as e:
        logger.error(f'Failed to fetch model history: {str(e)}', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()


@bp.route('/batch-delete-datasets', methods=['POST'])
def batch_delete_datasets():
    """
    Delete several datasets in one request.

    @body_param dataset_ids: list of dataset IDs to delete
    @return: JSON response
    """
    try:
        data = request.get_json()
        dataset_ids = data.get('dataset_ids', [])
        if not dataset_ids:
            return jsonify({'error': 'No dataset ID list provided'}), 400

        results = {
            'success': [],
            'failed': [],
            'protected': []  # Datasets still referenced by models
        }
        for dataset_id in dataset_ids:
            try:
                # Reuse the single-dataset delete endpoint
                response = delete_dataset_endpoint(dataset_id)
                # Interpret its response
                if response[1] == 200:
                    results['success'].append(dataset_id)
                elif response[1] == 400 and 'models' in response[0].json:
                    # The dataset is protected by one or more models
                    results['protected'].append({
                        'id': dataset_id,
                        'models': response[0].json['models']
                    })
                else:
                    results['failed'].append({
                        'id': dataset_id,
                        'reason': response[0].json.get('error', 'Deletion failed')
                    })
            except Exception as e:
                logger.error(f'Failed to delete dataset {dataset_id}: {str(e)}')
                results['failed'].append({
                    'id': dataset_id,
                    'reason': str(e)
                })

        # Build the response message
        message = f"Deleted {len(results['success'])} dataset(s)"
        if results['protected']:
            message += f", {len(results['protected'])} dataset(s) protected"
        if results['failed']:
            message += f", {len(results['failed'])} deletion(s) failed"
        return jsonify({
            'message': message,
            'results': results
        }), 200
    except Exception as e:
        logger.error(f'Batch dataset deletion failed: {str(e)}')
        return jsonify({'error': str(e)}), 500
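
# Illustrative request (IDs are placeholders):
#   POST /batch-delete-datasets   {"dataset_ids": [3, 4, 5]}
# The response buckets each ID into "success", "protected" (still referenced by a model),
# or "failed", so the whole batch never aborts on a single bad ID.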


@bp.route('/batch-delete-models', methods=['POST'])
def batch_delete_models():
    """
    Delete several models in one request.

    @body_param model_ids: list of model IDs to delete
    @query_param delete_datasets: boolean; also delete the associated datasets (default False)
    @return: JSON response
    """
    try:
        data = request.get_json()
        model_ids = data.get('model_ids', [])
        delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'
        if not model_ids:
            return jsonify({'error': 'No model ID list provided'}), 400

        results = {
            'success': [],
            'failed': [],
            'datasets_deleted': []  # Datasets removed when delete_datasets is true
        }
        for model_id in model_ids:
            try:
                # Reuse the single-model delete function
                response = delete_model(model_id, delete_dataset=delete_datasets)
                # Interpret its response
                if response[1] == 200:
                    results['success'].append(model_id)
                    # Record the dataset ID if an associated dataset was deleted
                    if 'dataset_info' in response[0].json:
                        results['datasets_deleted'].append(
                            response[0].json['dataset_info']['dataset_id']
                        )
                else:
                    results['failed'].append({
                        'id': model_id,
                        'reason': response[0].json.get('error', 'Deletion failed')
                    })
            except Exception as e:
                logger.error(f'Failed to delete model {model_id}: {str(e)}')
                results['failed'].append({
                    'id': model_id,
                    'reason': str(e)
                })

        # Build the response message
        message = f"Deleted {len(results['success'])} model(s)"
        if results['datasets_deleted']:
            message += f", {len(results['datasets_deleted'])} associated dataset(s)"
        if results['failed']:
            message += f", {len(results['failed'])} deletion(s) failed"
        return jsonify({
            'message': message,
            'results': results
        }), 200
    except Exception as e:
        logger.error(f'Batch model deletion failed: {str(e)}')
        return jsonify({'error': str(e)}), 500
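
# Illustrative request (IDs are placeholders; note that delete_datasets travels in the
# query string while the IDs travel in the JSON body):
#   POST /batch-delete-models?delete_datasets=true   {"model_ids": [7, 8]}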


@bp.route('/kriging_interpolation', methods=['POST'])
def kriging_interpolation():
    try:
        data = request.get_json()
        required = ['file_name', 'emission_column', 'points']
        if not all(k in data for k in required):
            return jsonify({"error": "Missing parameters"}), 400

        # Validate the point format (each point must be an [x, y] pair of numbers)
        points = data['points']
        if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) and isinstance(pt[1], (int, float)) for pt in points):
            return jsonify({"error": "Invalid points format"}), 400

        result = create_kriging(
            data['file_name'],
            data['emission_column'],
            data['points']
        )
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
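
# Illustrative request body (file and column names are placeholders; the exact format
# create_kriging expects is defined in .utils):
#   POST /kriging_interpolation
#   {"file_name": "emissions.xlsx", "emission_column": "CO2",
#    "points": [[116.3, 39.9], [117.2, 39.1]]}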


@bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
def get_model_scatter_data(model_id):
    """
    Fetch scatter-plot data (true vs. predicted values) for a model.

    @param model_id: model ID
    @return: JSON response containing the scatter data
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the model
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if not model:
            return jsonify({'error': 'Model not found'}), 404

        # Load the trained model from disk
        with open(model.ModelFilePath, 'rb') as f:
            ML_model = pickle.load(f)

        # Load the test data matching the model's data type
        if model.Data_type == 'reflux':
            X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
            Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
        elif model.Data_type == 'reduce':
            X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
            Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
        else:
            return jsonify({'error': 'Unsupported data type'}), 400

        # Predict on the test set
        y_pred = ML_model.predict(X_test)

        # Build the scatter-plot pairs
        scatter_data = [
            [float(true), float(pred)]
            for true, pred in zip(Y_test.iloc[:, 0], y_pred)
        ]

        # R² score on the test set
        r2 = r2_score(Y_test, y_pred)

        # Value range, used by the frontend to draw the diagonal reference line
        y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
        y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))

        return jsonify({
            'scatter_data': scatter_data,
            'r2_score': float(r2),
            'y_range': [float(y_min), float(y_max)],
            'model_name': model.Model_name,
            'model_type': model.Model_type
        }), 200
    except Exception as e:
        logger.error(f'Failed to fetch model scatter data: {str(e)}', exc_info=True)
        return jsonify({'error': f'Failed to fetch data: {str(e)}'}), 500
    finally:
        session.close()
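
# Illustrative response shape (numbers are placeholders):
#   GET /model-scatter-data/3
#   -> {"scatter_data": [[5.1, 5.0], [6.2, 6.4]], "r2_score": 0.93,
#       "y_range": [5.0, 6.4], "model_name": "...", "model_type": "..."}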