# routes.py — Flask route definitions for the app blueprint
import logging
import os
import re
import sqlite3

import pandas as pd
from flask import Blueprint, request, jsonify, current_app
from sqlalchemy import text, select, MetaData, Table, func, inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from werkzeug.security import generate_password_hash, check_password_hash

from . import db  # db instance from the app package
from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
from .model import predict, train_and_save_model, calculate_model_score
from .tasks import train_model_task
from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
    clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
    predict_to_Q, Q_to_t_ha
  16. # 配置日志
  17. logging.basicConfig(level=logging.DEBUG)
  18. logger = logging.getLogger(__name__)
  19. # 创建蓝图 (Blueprint),用于分离路由
  20. bp = Blueprint('routes', __name__)
# Password hashing helper
def hash_password(password):
    """Return a salted hash of the given plaintext password."""
    return generate_password_hash(password)
  24. def get_db():
  25. """ 获取数据库连接 """
  26. return sqlite3.connect(app.config['DATABASE'])
  27. # 添加一个新的辅助函数来检查数据集大小并触发训练
  28. def check_and_trigger_training(session, dataset_type, dataset_df):
  29. """
  30. 检查当前数据集大小是否跨越新的阈值点并触发训练
  31. Args:
  32. session: 数据库会话
  33. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  34. dataset_df: 数据集 DataFrame
  35. Returns:
  36. tuple: (是否触发训练, 任务ID)
  37. """
  38. try:
  39. # 根据数据集类型选择表
  40. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  41. # 获取当前记录数
  42. current_count = session.query(func.count()).select_from(table).scalar()
  43. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  44. new_records = len(dataset_df) # 需要从上层函数传入
  45. # 计算新增数据前的记录数
  46. previous_count = current_count - new_records
  47. # 设置阈值
  48. THRESHOLD = current_app.config['THRESHOLD']
  49. # 计算上一个阈值点(基于新增前的数据量)
  50. last_threshold = previous_count // THRESHOLD * THRESHOLD
  51. # 计算当前所在阈值点
  52. current_threshold = current_count // THRESHOLD * THRESHOLD
  53. # 检查是否跨越了新的阈值点
  54. if current_threshold > last_threshold and current_count >= THRESHOLD:
  55. # 触发异步训练任务
  56. task = train_model_task.delay(
  57. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  58. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  59. model_description=f'Auto trained model at {current_threshold} records threshold',
  60. data_type=dataset_type
  61. )
  62. return True, task.id
  63. return False, None
  64. except Exception as e:
  65. logging.error(f"检查并触发训练失败: {str(e)}")
  66. return False, None
  67. @bp.route('/upload-dataset', methods=['POST'])
  68. def upload_dataset():
  69. # 创建 session
  70. Session = sessionmaker(bind=db.engine)
  71. session = Session()
  72. try:
  73. if 'file' not in request.files:
  74. return jsonify({'error': 'No file part'}), 400
  75. file = request.files['file']
  76. if file.filename == '' or not allowed_file(file.filename):
  77. return jsonify({'error': 'No selected file or invalid file type'}), 400
  78. dataset_name = request.form.get('dataset_name')
  79. dataset_description = request.form.get('dataset_description', 'No description provided')
  80. dataset_type = request.form.get('dataset_type')
  81. if not dataset_type:
  82. return jsonify({'error': 'Dataset type is required'}), 400
  83. new_dataset = Datasets(
  84. Dataset_name=dataset_name,
  85. Dataset_description=dataset_description,
  86. Row_count=0,
  87. Status='Datasets_upgraded',
  88. Dataset_type=dataset_type
  89. )
  90. session.add(new_dataset)
  91. session.commit()
  92. unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
  93. upload_folder = current_app.config['UPLOAD_FOLDER']
  94. file_path = os.path.join(upload_folder, unique_filename)
  95. file.save(file_path)
  96. dataset_df = pd.read_excel(file_path)
  97. new_dataset.Row_count = len(dataset_df)
  98. new_dataset.Status = 'excel_file_saved success'
  99. session.commit()
  100. # 处理列名
  101. dataset_df = clean_column_names(dataset_df)
  102. dataset_df = rename_columns_for_model(dataset_df, dataset_type)
  103. column_types = infer_column_types(dataset_df)
  104. dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
  105. insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
  106. # 根据 dataset_type 决定插入到哪个已有表
  107. if dataset_type == 'reduce':
  108. insert_data_into_existing_table(session, dataset_df, CurrentReduce)
  109. elif dataset_type == 'reflux':
  110. insert_data_into_existing_table(session, dataset_df, CurrentReflux)
  111. session.commit()
  112. # 在完成数据插入后,检查是否需要触发训练
  113. training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
  114. response_data = {
  115. 'message': f'Dataset {dataset_name} uploaded successfully!',
  116. 'dataset_id': new_dataset.Dataset_ID,
  117. 'filename': unique_filename,
  118. 'training_triggered': training_triggered
  119. }
  120. if training_triggered:
  121. response_data['task_id'] = task_id
  122. response_data['message'] += ' Auto-training has been triggered.'
  123. return jsonify(response_data), 201
  124. except Exception as e:
  125. session.rollback()
  126. logging.error('Failed to process the dataset upload:', exc_info=True)
  127. return jsonify({'error': str(e)}), 500
  128. finally:
  129. # 确保 session 总是被关闭
  130. if session:
  131. session.close()
  132. @bp.route('/train-and-save-model', methods=['POST'])
  133. def train_and_save_model_endpoint():
  134. # 创建 sessionmaker 实例
  135. Session = sessionmaker(bind=db.engine)
  136. session = Session()
  137. data = request.get_json()
  138. # 从请求中解析参数
  139. model_type = data.get('model_type')
  140. model_name = data.get('model_name')
  141. model_description = data.get('model_description')
  142. data_type = data.get('data_type')
  143. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  144. try:
  145. # 调用训练和保存模型的函数
  146. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  147. model_id = result[1] if result else None
  148. # 计算模型评分
  149. if model_id:
  150. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  151. if model_info:
  152. score = calculate_model_score(model_info)
  153. # 更新模型评分
  154. model_info.Performance_score = score
  155. session.commit()
  156. result = {'model_id': model_id, 'model_score': score}
  157. # 返回成功响应
  158. return jsonify({
  159. 'message': 'Model trained and saved successfully',
  160. 'result': result
  161. }), 200
  162. except Exception as e:
  163. session.rollback()
  164. logging.error('Failed to process the model training:', exc_info=True)
  165. return jsonify({
  166. 'error': 'Failed to train and save model',
  167. 'message': str(e)
  168. }), 500
  169. finally:
  170. session.close()
  171. @bp.route('/predict', methods=['POST'])
  172. def predict_route():
  173. # 创建 sessionmaker 实例
  174. Session = sessionmaker(bind=db.engine)
  175. session = Session()
  176. try:
  177. data = request.get_json()
  178. model_id = data.get('model_id') # 提取模型名称
  179. parameters = data.get('parameters', {}) # 提取所有变量
  180. # 根据model_id获取模型Data_type
  181. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  182. if not model_info:
  183. return jsonify({'error': 'Model not found'}), 404
  184. data_type = model_info.Data_type
  185. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  186. # 如果为reduce,则不需要传入target_ph
  187. if data_type == 'reduce':
  188. # 获取传入的init_ph、target_ph参数
  189. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  190. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  191. # 从输入数据中删除'target_pH'列
  192. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  193. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  194. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  195. if data_type == 'reduce':
  196. predictions = predictions[0]
  197. # 将预测结果转换为Q
  198. Q = predict_to_Q(predictions, init_ph, target_ph)
  199. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  200. print(predictions)
  201. return jsonify({'result': predictions}), 200
  202. except Exception as e:
  203. logging.error('Failed to predict:', exc_info=True)
  204. return jsonify({'error': str(e)}), 400
  205. # 为指定模型计算评分Performance_score,需要提供model_id
  206. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  207. def score_model(model_id):
  208. # 创建 sessionmaker 实例
  209. Session = sessionmaker(bind=db.engine)
  210. session = Session()
  211. try:
  212. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  213. if not model_info:
  214. return jsonify({'error': 'Model not found'}), 404
  215. # 计算模型评分
  216. score = calculate_model_score(model_info)
  217. # 更新模型记录中的评分
  218. model_info.Performance_score = score
  219. session.commit()
  220. return jsonify({'message': 'Model scored successfully', 'score': score}), 200
  221. except Exception as e:
  222. logging.error('Failed to process the dataset upload:', exc_info=True)
  223. return jsonify({'error': str(e)}), 400
  224. finally:
  225. session.close()
  226. @bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
  227. def delete_dataset_endpoint(dataset_id):
  228. """
  229. 删除数据集的API接口
  230. @param dataset_id: 要删除的数据集ID
  231. @return: JSON响应
  232. """
  233. # 创建 sessionmaker 实例
  234. Session = sessionmaker(bind=db.engine)
  235. session = Session()
  236. try:
  237. # 查询数据集
  238. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  239. if not dataset:
  240. return jsonify({'error': '未找到数据集'}), 404
  241. # 检查是否有模型使用了该数据集
  242. models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
  243. if models_using_dataset:
  244. models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
  245. return jsonify({
  246. 'error': '无法删除数据集,因为以下模型正在使用它',
  247. 'models': models_info
  248. }), 400
  249. # 删除Excel文件
  250. filename = f"dataset_{dataset.Dataset_ID}.xlsx"
  251. file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
  252. if os.path.exists(file_path):
  253. try:
  254. os.remove(file_path)
  255. except OSError as e:
  256. logger.error(f'删除文件失败: {str(e)}')
  257. return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
  258. # 删除数据表
  259. table_name = f"dataset_{dataset.Dataset_ID}"
  260. session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
  261. # 删除数据集记录
  262. session.delete(dataset)
  263. session.commit()
  264. return jsonify({
  265. 'message': '数据集删除成功',
  266. 'deleted_files': [filename]
  267. }), 200
  268. except Exception as e:
  269. session.rollback()
  270. logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
  271. return jsonify({'error': str(e)}), 500
  272. finally:
  273. session.close()
  274. @bp.route('/tables', methods=['GET'])
  275. def list_tables():
  276. engine = db.engine # 使用 db 实例的 engine
  277. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  278. table_names = inspector.get_table_names() # 获取所有表名
  279. return jsonify(table_names) # 以 JSON 形式返回表名列表
  280. @bp.route('/models/<int:model_id>', methods=['GET'])
  281. def get_model(model_id):
  282. try:
  283. model = Models.query.filter_by(ModelID=model_id).first()
  284. if model:
  285. return jsonify({
  286. 'ModelID': model.ModelID,
  287. 'ModelName': model.ModelName,
  288. 'ModelType': model.ModelType,
  289. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  290. 'Description': model.Description
  291. })
  292. else:
  293. return jsonify({'message': 'Model not found'}), 404
  294. except Exception as e:
  295. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  296. @bp.route('/models', methods=['GET'])
  297. def get_all_models():
  298. try:
  299. models = Models.query.all() # 获取所有模型数据
  300. if models:
  301. result = [
  302. {
  303. 'ModelID': model.ModelID,
  304. 'ModelName': model.ModelName,
  305. 'ModelType': model.ModelType,
  306. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  307. 'Description': model.Description
  308. }
  309. for model in models
  310. ]
  311. return jsonify(result)
  312. else:
  313. return jsonify({'message': 'No models found'}), 404
  314. except Exception as e:
  315. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  316. @bp.route('/model-parameters', methods=['GET'])
  317. def get_all_model_parameters():
  318. try:
  319. parameters = ModelParameters.query.all() # 获取所有参数数据
  320. if parameters:
  321. result = [
  322. {
  323. 'ParamID': param.ParamID,
  324. 'ModelID': param.ModelID,
  325. 'ParamName': param.ParamName,
  326. 'ParamValue': param.ParamValue
  327. }
  328. for param in parameters
  329. ]
  330. return jsonify(result)
  331. else:
  332. return jsonify({'message': 'No parameters found'}), 404
  333. except Exception as e:
  334. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  335. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  336. def get_model_parameters(model_id):
  337. try:
  338. model = Models.query.filter_by(ModelID=model_id).first()
  339. if model:
  340. # 获取该模型的所有参数
  341. parameters = [
  342. {
  343. 'ParamID': param.ParamID,
  344. 'ParamName': param.ParamName,
  345. 'ParamValue': param.ParamValue
  346. }
  347. for param in model.parameters
  348. ]
  349. # 返回模型参数信息
  350. return jsonify({
  351. 'ModelID': model.ModelID,
  352. 'ModelName': model.ModelName,
  353. 'ModelType': model.ModelType,
  354. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  355. 'Description': model.Description,
  356. 'Parameters': parameters
  357. })
  358. else:
  359. return jsonify({'message': 'Model not found'}), 404
  360. except Exception as e:
  361. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  362. # 定义添加数据库记录的 API 接口
  363. @bp.route('/add_item', methods=['POST'])
  364. def add_item():
  365. """
  366. 接收 JSON 格式的请求体,包含表名和要插入的数据。
  367. 尝试将数据插入到指定的表中,并进行字段查重。
  368. :return:
  369. """
  370. try:
  371. # 确保请求体是 JSON 格式
  372. data = request.get_json()
  373. if not data:
  374. raise ValueError("No JSON data provided")
  375. table_name = data.get('table')
  376. item_data = data.get('item')
  377. if not table_name or not item_data:
  378. return jsonify({'error': 'Missing table name or item data'}), 400
  379. # 定义各个表的字段查重规则
  380. duplicate_check_rules = {
  381. 'users': ['email', 'username'],
  382. 'products': ['product_code'],
  383. 'current_reduce': [ 'Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'],
  384. 'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'],
  385. # 其他表和规则
  386. }
  387. # 获取该表的查重字段
  388. duplicate_columns = duplicate_check_rules.get(table_name)
  389. if not duplicate_columns:
  390. return jsonify({'error': 'No duplicate check rule for this table'}), 400
  391. # 动态构建查询条件,逐一检查是否有重复数据
  392. condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns])
  393. duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1"
  394. result = db.session.execute(text(duplicate_query), item_data).fetchone()
  395. if result:
  396. return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409
  397. # 动态构建 SQL 语句,进行插入操作
  398. columns = ', '.join(item_data.keys())
  399. placeholders = ', '.join([f":{key}" for key in item_data.keys()])
  400. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  401. # 直接执行插入操作,无需显式的事务管理
  402. db.session.execute(text(sql), item_data)
  403. # 提交事务
  404. db.session.commit()
  405. # 返回成功响应
  406. return jsonify({'success': True, 'message': 'Item added successfully'}), 201
  407. except ValueError as e:
  408. return jsonify({'error': str(e)}), 400
  409. except KeyError as e:
  410. return jsonify({'error': f'Missing data field: {e}'}), 400
  411. except sqlite3.IntegrityError as e:
  412. return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409
  413. except sqlite3.Error as e:
  414. return jsonify({'error': '数据库错误', 'details': str(e)}), 500
  415. @bp.route('/delete_item', methods=['POST'])
  416. def delete_item():
  417. """
  418. 删除数据库记录的 API 接口
  419. """
  420. data = request.get_json()
  421. table_name = data.get('table')
  422. condition = data.get('condition')
  423. # 检查表名和条件是否提供
  424. if not table_name or not condition:
  425. return jsonify({
  426. "success": False,
  427. "message": "缺少表名或条件参数"
  428. }), 400
  429. # 尝试从条件字符串中解析键和值
  430. try:
  431. key, value = condition.split('=')
  432. key = key.strip() # 去除多余的空格
  433. value = value.strip().strip("'\"") # 去除多余的空格和引号
  434. except ValueError:
  435. return jsonify({
  436. "success": False,
  437. "message": "条件格式错误,应为 'key=value'"
  438. }), 400
  439. # 准备 SQL 删除语句
  440. sql = f"DELETE FROM {table_name} WHERE {key} = :value"
  441. try:
  442. # 使用 SQLAlchemy 执行删除
  443. with db.session.begin():
  444. result = db.session.execute(text(sql), {"value": value})
  445. # 检查是否有记录被删除
  446. if result.rowcount == 0:
  447. return jsonify({
  448. "success": False,
  449. "message": "未找到符合条件的记录"
  450. }), 404
  451. return jsonify({
  452. "success": True,
  453. "message": "记录删除成功"
  454. }), 200
  455. except Exception as e:
  456. return jsonify({
  457. "success": False,
  458. "message": f"删除失败: {e}"
  459. }), 500
  460. # 定义修改数据库记录的 API 接口
  461. @bp.route('/update_item', methods=['PUT'])
  462. def update_record():
  463. """
  464. 接收 JSON 格式的请求体,包含表名和更新的数据。
  465. 尝试更新指定的记录。
  466. """
  467. data = request.get_json()
  468. # 检查必要的数据是否提供
  469. if not data or 'table' not in data or 'item' not in data:
  470. return jsonify({
  471. "success": False,
  472. "message": "请求数据不完整"
  473. }), 400
  474. table_name = data['table']
  475. item = data['item']
  476. # 假设 item 的第一个键是 ID
  477. id_key = next(iter(item.keys())) # 获取第一个键
  478. record_id = item.get(id_key)
  479. if not record_id:
  480. return jsonify({
  481. "success": False,
  482. "message": "缺少记录 ID"
  483. }), 400
  484. # 获取更新的字段和值
  485. updates = {key: value for key, value in item.items() if key != id_key}
  486. if not updates:
  487. return jsonify({
  488. "success": False,
  489. "message": "没有提供需要更新的字段"
  490. }), 400
  491. # 动态构建 SQL
  492. set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()])
  493. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value"
  494. # 添加 ID 到参数
  495. updates['id_value'] = record_id
  496. try:
  497. # 使用 SQLAlchemy 执行更新
  498. with db.session.begin():
  499. result = db.session.execute(text(sql), updates)
  500. # 检查是否有更新的记录
  501. if result.rowcount == 0:
  502. return jsonify({
  503. "success": False,
  504. "message": "未找到要更新的记录"
  505. }), 404
  506. return jsonify({
  507. "success": True,
  508. "message": "数据更新成功"
  509. }), 200
  510. except Exception as e:
  511. # 捕获所有异常并返回
  512. return jsonify({
  513. "success": False,
  514. "message": f"更新失败: {str(e)}"
  515. }), 500
  516. # 定义查询数据库记录的 API 接口
  517. @bp.route('/search/record', methods=['GET'])
  518. def sql_search():
  519. """
  520. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  521. 尝试查询指定 ID 的记录并返回结果。
  522. :return:
  523. """
  524. try:
  525. data = request.get_json()
  526. # 表名
  527. sql_table = data['table']
  528. # 要搜索的 ID
  529. Id = data['id']
  530. # 连接到数据库
  531. cur = db.cursor()
  532. # 构造查询语句
  533. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  534. # 执行查询
  535. cur.execute(sql, (Id,))
  536. # 获取查询结果
  537. rows = cur.fetchall()
  538. column_names = [desc[0] for desc in cur.description]
  539. # 检查是否有结果
  540. if not rows:
  541. return jsonify({'error': '未查找到对应数据。'}), 400
  542. # 构造响应数据
  543. results = []
  544. for row in rows:
  545. result = {column_names[i]: row[i] for i in range(len(row))}
  546. results.append(result)
  547. # 关闭游标和数据库连接
  548. cur.close()
  549. db.close()
  550. # 返回 JSON 响应
  551. return jsonify(results), 200
  552. except sqlite3.Error as e:
  553. # 如果发生数据库错误,返回错误信息
  554. return jsonify({'error': str(e)}), 400
  555. except KeyError as e:
  556. # 如果请求数据中缺少必要的键,返回错误信息
  557. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  558. # 定义提供数据库列表,用于展示表格的 API 接口
  559. @bp.route('/table', methods=['POST'])
  560. def get_table():
  561. data = request.get_json()
  562. table_name = data.get('table')
  563. if not table_name:
  564. return jsonify({'error': '需要表名'}), 400
  565. try:
  566. # 创建 sessionmaker 实例
  567. Session = sessionmaker(bind=db.engine)
  568. session = Session()
  569. # 动态获取表的元数据
  570. metadata = MetaData()
  571. table = Table(table_name, metadata, autoload_with=db.engine)
  572. # 从数据库中查询所有记录
  573. query = select(table)
  574. result = session.execute(query).fetchall()
  575. # 将结果转换为列表字典形式
  576. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  577. # 获取列名
  578. headers = [column.name for column in table.columns]
  579. return jsonify(rows=rows, headers=headers), 200
  580. except Exception as e:
  581. return jsonify({'error': str(e)}), 400
  582. finally:
  583. # 关闭 session
  584. session.close()
  585. @bp.route('/train-model-async', methods=['POST'])
  586. def train_model_async():
  587. """
  588. 异步训练模型的API接口
  589. """
  590. try:
  591. data = request.get_json()
  592. # 从请求中获取参数
  593. model_type = data.get('model_type')
  594. model_name = data.get('model_name')
  595. model_description = data.get('model_description')
  596. data_type = data.get('data_type')
  597. dataset_id = data.get('dataset_id', None)
  598. # 验证必要参数
  599. if not all([model_type, model_name, data_type]):
  600. return jsonify({
  601. 'error': 'Missing required parameters'
  602. }), 400
  603. # 如果提供了dataset_id,验证数据集是否存在
  604. if dataset_id:
  605. Session = sessionmaker(bind=db.engine)
  606. session = Session()
  607. try:
  608. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  609. if not dataset:
  610. return jsonify({
  611. 'error': f'Dataset with ID {dataset_id} not found'
  612. }), 404
  613. finally:
  614. session.close()
  615. # 启动异步任务
  616. task = train_model_task.delay(
  617. model_type=model_type,
  618. model_name=model_name,
  619. model_description=model_description,
  620. data_type=data_type,
  621. dataset_id=dataset_id
  622. )
  623. # 返回任务ID
  624. return jsonify({
  625. 'task_id': task.id,
  626. 'message': 'Model training started'
  627. }), 202
  628. except Exception as e:
  629. logging.error('Failed to start async training task:', exc_info=True)
  630. return jsonify({
  631. 'error': str(e)
  632. }), 500
  633. @bp.route('/task-status/<task_id>', methods=['GET'])
  634. def get_task_status(task_id):
  635. """
  636. 获取异步任务状态的API接口
  637. """
  638. try:
  639. task = train_model_task.AsyncResult(task_id)
  640. if task.state == 'PENDING':
  641. response = {
  642. 'state': task.state,
  643. 'status': 'Task is waiting for execution'
  644. }
  645. elif task.state == 'FAILURE':
  646. response = {
  647. 'state': task.state,
  648. 'status': 'Task failed',
  649. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  650. }
  651. elif task.state == 'SUCCESS':
  652. response = {
  653. 'state': task.state,
  654. 'status': 'Task completed successfully',
  655. 'result': task.get()
  656. }
  657. else:
  658. response = {
  659. 'state': task.state,
  660. 'status': 'Task is in progress'
  661. }
  662. return jsonify(response), 200
  663. except Exception as e:
  664. return jsonify({
  665. 'error': str(e)
  666. }), 500
  667. @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
  668. def delete_model(model_id):
  669. """
  670. 删除指定模型的API接口
  671. @param model_id: 要删除的模型ID
  672. @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
  673. @return: JSON响应
  674. """
  675. Session = sessionmaker(bind=db.engine)
  676. session = Session()
  677. try:
  678. # 查询模型信息
  679. model = session.query(Models).filter_by(ModelID=model_id).first()
  680. if not model:
  681. return jsonify({'error': '未找到指定模型'}), 404
  682. dataset_id = model.DatasetID
  683. # 1. 先删除模型记录
  684. session.delete(model)
  685. session.commit()
  686. # 2. 删除模型文件
  687. model_file = f"rf_model_{model_id}.pkl"
  688. model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
  689. if os.path.exists(model_path):
  690. try:
  691. os.remove(model_path)
  692. except OSError as e:
  693. # 如果删除文件失败,回滚数据库操作
  694. session.rollback()
  695. logger.error(f'删除模型文件失败: {str(e)}')
  696. return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
  697. # 3. 如果需要删除关联的数据集
  698. delete_dataset = request.args.get('delete_dataset', 'false').lower() == 'true'
  699. if delete_dataset and dataset_id:
  700. try:
  701. dataset_response = delete_dataset_endpoint(dataset_id)
  702. if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
  703. # 如果删除数据集失败,回滚之前的操作
  704. session.rollback()
  705. return jsonify({
  706. 'error': '删除关联数据集失败',
  707. 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
  708. }), 500
  709. except Exception as e:
  710. session.rollback()
  711. logger.error(f'删除关联数据集失败: {str(e)}')
  712. return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
  713. response_data = {
  714. 'message': '模型删除成功',
  715. 'deleted_files': [model_file]
  716. }
  717. if delete_dataset:
  718. response_data['dataset_info'] = {
  719. 'dataset_id': dataset_id,
  720. 'message': '关联数据集已删除'
  721. }
  722. return jsonify(response_data), 200
  723. except Exception as e:
  724. session.rollback()
  725. logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
  726. return jsonify({'error': str(e)}), 500
  727. finally:
  728. session.close()
  729. # 添加一个新的API端点来清空指定数据集
  730. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  731. def clear_dataset(data_type):
  732. """
  733. 清空指定类型的数据集并递增计数
  734. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  735. @return: JSON响应
  736. """
  737. # 创建 sessionmaker 实例
  738. Session = sessionmaker(bind=db.engine)
  739. session = Session()
  740. try:
  741. # 根据数据集类型选择表
  742. if data_type == 'reduce':
  743. table = CurrentReduce
  744. table_name = 'current_reduce'
  745. elif data_type == 'reflux':
  746. table = CurrentReflux
  747. table_name = 'current_reflux'
  748. else:
  749. return jsonify({'error': '无效的数据集类型'}), 400
  750. # 清空表内容
  751. session.query(table).delete()
  752. # 重置自增主键计数器
  753. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  754. session.commit()
  755. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  756. except Exception as e:
  757. session.rollback()
  758. return jsonify({'error': str(e)}), 500
  759. finally:
  760. session.close()
  761. @bp.route('/login', methods=['POST'])
  762. def login_user():
  763. # 获取前端传来的数据
  764. data = request.get_json()
  765. name = data.get('name') # 用户名
  766. password = data.get('password') # 密码
  767. logger.info(f"Login request received: name={name}")
  768. # 检查用户名和密码是否为空
  769. if not name or not password:
  770. logger.warning("用户名和密码不能为空")
  771. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  772. try:
  773. # 查询数据库验证用户名
  774. query = "SELECT * FROM users WHERE name = :name"
  775. conn = get_db()
  776. user = conn.execute(query, {"name": name}).fetchone()
  777. if not user:
  778. logger.warning(f"用户名 '{name}' 不存在")
  779. return jsonify({"success": False, "message": "用户名不存在"}), 400
  780. # 获取数据库中存储的密码(假设密码是哈希存储的)
  781. stored_password = user[2] # 假设密码存储在数据库的第三列
  782. user_id = user[0] # 假设 id 存储在数据库的第一列
  783. # 校验密码是否正确
  784. if check_password_hash(stored_password, password):
  785. logger.info(f"User '{name}' logged in successfully.")
  786. return jsonify({
  787. "success": True,
  788. "message": "登录成功",
  789. "userId": user_id # 返回用户 ID
  790. })
  791. else:
  792. logger.warning(f"Invalid password for user '{name}'")
  793. return jsonify({"success": False, "message": "用户名或密码错误"}), 400
  794. except Exception as e:
  795. # 记录错误日志并返回错误信息
  796. logger.error(f"Error during login: {e}", exc_info=True)
  797. return jsonify({"success": False, "message": "登录失败"}), 500
  798. # 更新用户信息接口
  799. @bp.route('/update_user', methods=['POST'])
  800. def update_user():
  801. # 获取前端传来的数据
  802. data = request.get_json()
  803. # 打印收到的请求数据
  804. app.logger.info(f"Received data: {data}")
  805. user_id = data.get('userId') # 用户ID
  806. name = data.get('name') # 用户名
  807. old_password = data.get('oldPassword') # 旧密码
  808. new_password = data.get('newPassword') # 新密码
  809. logger.info(f"Update request received: user_id={user_id}, name={name}")
  810. # 校验传入的用户名和密码是否为空
  811. if not name or not old_password:
  812. logger.warning("用户名和旧密码不能为空")
  813. return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400
  814. # 新密码和旧密码不能相同
  815. if new_password and old_password == new_password:
  816. logger.warning(f"新密码与旧密码相同:{name}")
  817. return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400
  818. try:
  819. # 查询数据库验证用户ID
  820. query = "SELECT * FROM users WHERE id = :user_id"
  821. conn = get_db()
  822. user = conn.execute(query, {"user_id": user_id}).fetchone()
  823. if not user:
  824. logger.warning(f"用户ID '{user_id}' 不存在")
  825. return jsonify({"success": False, "message": "用户不存在"}), 400
  826. # 获取数据库中存储的密码(假设密码是哈希存储的)
  827. stored_password = user[2] # 假设密码存储在数据库的第三列
  828. # 校验旧密码是否正确
  829. if not check_password_hash(stored_password, old_password):
  830. logger.warning(f"旧密码错误:{name}")
  831. return jsonify({"success": False, "message": "旧密码错误"}), 400
  832. # 如果新密码非空,则更新新密码
  833. if new_password:
  834. hashed_new_password = hash_password(new_password)
  835. update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
  836. conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
  837. conn.commit()
  838. logger.info(f"User ID '{user_id}' password updated successfully.")
  839. # 如果用户名发生更改,则更新用户名
  840. if name != user[1]:
  841. update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
  842. conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
  843. conn.commit()
  844. logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
  845. return jsonify({"success": True, "message": "用户信息更新成功"})
  846. except Exception as e:
  847. # 记录错误日志并返回错误信息
  848. logger.error(f"Error updating user: {e}", exc_info=True)
  849. return jsonify({"success": False, "message": "更新失败"}), 500
  850. # 注册用户
  851. @bp.route('/register', methods=['POST'])
  852. def register_user():
  853. # 获取前端传来的数据
  854. data = request.get_json()
  855. name = data.get('name') # 用户名
  856. password = data.get('password') # 密码
  857. logger.info(f"Register request received: name={name}")
  858. # 检查用户名和密码是否为空
  859. if not name or not password:
  860. logger.warning("用户名和密码不能为空")
  861. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  862. # 动态获取数据库表的列名
  863. columns = get_column_names('users')
  864. logger.info(f"Database columns for 'users' table: {columns}")
  865. # 检查前端传来的数据是否包含数据库表中所有的必填字段
  866. for column in ['name', 'password']:
  867. if column not in columns:
  868. logger.error(f"缺少必填字段:{column}")
  869. return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400
  870. # 对密码进行哈希处理
  871. hashed_password = hash_password(password)
  872. logger.info(f"Password hashed for user: {name}")
  873. # 插入到数据库
  874. try:
  875. # 检查用户是否已经存在
  876. query = "SELECT * FROM users WHERE name = :name"
  877. conn = get_db()
  878. user = conn.execute(query, {"name": name}).fetchone()
  879. if user:
  880. logger.warning(f"用户名 '{name}' 已存在")
  881. return jsonify({"success": False, "message": "用户名已存在"}), 400
  882. # 向数据库插入数据
  883. query = "INSERT INTO users (name, password) VALUES (:name, :password)"
  884. conn.execute(query, {"name": name, "password": hashed_password})
  885. conn.commit()
  886. logger.info(f"User '{name}' registered successfully.")
  887. return jsonify({"success": True, "message": "注册成功"})
  888. except Exception as e:
  889. # 记录错误日志并返回错误信息
  890. logger.error(f"Error registering user: {e}", exc_info=True)
  891. return jsonify({"success": False, "message": "注册失败"}), 500
  892. def get_column_names(table_name):
  893. """
  894. 动态获取数据库表的列名。
  895. """
  896. try:
  897. conn = get_db()
  898. query = f"PRAGMA table_info({table_name});"
  899. result = conn.execute(query).fetchall()
  900. conn.close()
  901. return [row[1] for row in result] # 第二列是列名
  902. except Exception as e:
  903. logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
  904. return []
  905. # 导出数据
  906. @bp.route('/export_data', methods=['GET'])
  907. def export_data():
  908. table_name = request.args.get('table')
  909. file_format = request.args.get('format', 'excel').lower()
  910. if not table_name:
  911. return jsonify({'error': '缺少表名参数'}), 400
  912. if not table_name.isidentifier():
  913. return jsonify({'error': '无效的表名'}), 400
  914. try:
  915. conn = get_db()
  916. query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"
  917. table_exists = conn.execute(query, (table_name,)).fetchone()
  918. if not table_exists:
  919. return jsonify({'error': f"表 {table_name} 不存在"}), 404
  920. query = f"SELECT * FROM {table_name};"
  921. df = pd.read_sql(query, conn)
  922. output = BytesIO()
  923. if file_format == 'csv':
  924. df.to_csv(output, index=False, encoding='utf-8')
  925. output.seek(0)
  926. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv')
  927. elif file_format == 'excel':
  928. df.to_excel(output, index=False, engine='openpyxl')
  929. output.seek(0)
  930. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx',
  931. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  932. else:
  933. return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400
  934. except Exception as e:
  935. logger.error(f"Error in export_data: {e}", exc_info=True)
  936. return jsonify({'error': str(e)}), 500
  937. # 导入数据接口
  938. @bp.route('/import_data', methods=['POST'])
  939. def import_data():
  940. logger.debug("Import data endpoint accessed.")
  941. if 'file' not in request.files:
  942. logger.error("No file in request.")
  943. return jsonify({'success': False, 'message': '文件缺失'}), 400
  944. file = request.files['file']
  945. table_name = request.form.get('table')
  946. if not table_name:
  947. logger.error("Missing table name parameter.")
  948. return jsonify({'success': False, 'message': '缺少表名参数'}), 400
  949. if file.filename == '':
  950. logger.error("No file selected.")
  951. return jsonify({'success': False, 'message': '未选择文件'}), 400
  952. try:
  953. # 保存文件到临时路径
  954. temp_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
  955. file.save(temp_path)
  956. logger.debug(f"File saved to temporary path: {temp_path}")
  957. # 根据文件类型读取文件
  958. if file.filename.endswith('.xlsx'):
  959. df = pd.read_excel(temp_path)
  960. elif file.filename.endswith('.csv'):
  961. df = pd.read_csv(temp_path)
  962. else:
  963. logger.error("Unsupported file format.")
  964. return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400
  965. # 获取数据库列名
  966. db_columns = get_column_names(table_name)
  967. if 'id' in db_columns:
  968. db_columns.remove('id') # 假设 id 列是自增的,不需要处理
  969. if not set(db_columns).issubset(set(df.columns)):
  970. logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}")
  971. return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400
  972. # 清洗数据并删除空值行
  973. df_cleaned = df[db_columns].dropna()
  974. # 统一数据类型,避免 int 和 float 合并问题
  975. df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce')
  976. # 获取现有的数据
  977. conn = get_db()
  978. with conn:
  979. existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn)
  980. # 查找重复数据
  981. duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner')
  982. # 如果有重复数据,删除它们
  983. df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)]
  984. logger.warning(f"Duplicate data detected and removed: {duplicates}")
  985. # 获取导入前后的数据量
  986. total_data = len(df_cleaned) + len(duplicates)
  987. new_data = len(df_cleaned)
  988. duplicate_data = len(duplicates)
  989. # 导入不重复的数据
  990. df_cleaned.to_sql(table_name, conn, if_exists='append', index=False)
  991. logger.debug(f"Imported {new_data} new records into the database.")
  992. # 删除临时文件
  993. os.remove(temp_path)
  994. logger.debug(f"Temporary file removed: {temp_path}")
  995. # 返回结果
  996. return jsonify({
  997. 'success': True,
  998. 'message': '数据导入成功',
  999. 'total_data': total_data,
  1000. 'new_data': new_data,
  1001. 'duplicate_data': duplicate_data
  1002. }), 200
  1003. except Exception as e:
  1004. logger.error(f"Import failed: {e}", exc_info=True)
  1005. return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500
  1006. # 模板下载接口
  1007. @bp.route('/download_template', methods=['GET'])
  1008. def download_template():
  1009. """
  1010. 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。
  1011. """
  1012. table_name = request.args.get('table')
  1013. if not table_name:
  1014. return jsonify({'error': '表名参数缺失'}), 400
  1015. columns = get_column_names(table_name)
  1016. if not columns:
  1017. return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404
  1018. # 不包括 ID 列
  1019. if 'id' in columns:
  1020. columns.remove('id')
  1021. df = pd.DataFrame(columns=columns)
  1022. file_format = request.args.get('format', 'excel').lower()
  1023. try:
  1024. if file_format == 'csv':
  1025. output = BytesIO()
  1026. df.to_csv(output, index=False, encoding='utf-8')
  1027. output.seek(0)
  1028. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv',
  1029. mimetype='text/csv')
  1030. else:
  1031. output = BytesIO()
  1032. df.to_excel(output, index=False, engine='openpyxl')
  1033. output.seek(0)
  1034. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx',
  1035. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1036. except Exception as e:
  1037. logger.error(f"Failed to generate template: {e}", exc_info=True)
  1038. return jsonify({'error': '生成模板文件失败'}), 500
  1039. @bp.route('/update-threshold', methods=['POST'])
  1040. def update_threshold():
  1041. """
  1042. 更新训练阈值的API接口
  1043. @body_param threshold: 新的阈值值(整数)
  1044. @return: JSON响应
  1045. """
  1046. try:
  1047. data = request.get_json()
  1048. new_threshold = data.get('threshold')
  1049. # 验证新阈值
  1050. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  1051. return jsonify({
  1052. 'error': '无效的阈值值,必须为正数'
  1053. }), 400
  1054. # 更新当前应用的阈值配置
  1055. current_app.config['THRESHOLD'] = int(new_threshold)
  1056. return jsonify({
  1057. 'success': True,
  1058. 'message': f'阈值已更新为 {new_threshold}',
  1059. 'new_threshold': new_threshold
  1060. })
  1061. except Exception as e:
  1062. logging.error(f"更新阈值失败: {str(e)}")
  1063. return jsonify({
  1064. 'error': f'更新阈值失败: {str(e)}'
  1065. }), 500
  1066. @bp.route('/get-threshold', methods=['GET'])
  1067. def get_threshold():
  1068. """
  1069. 获取当前训练阈值的API接口
  1070. @return: JSON响应
  1071. """
  1072. try:
  1073. current_threshold = current_app.config['THRESHOLD']
  1074. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  1075. return jsonify({
  1076. 'current_threshold': current_threshold,
  1077. 'default_threshold': default_threshold
  1078. })
  1079. except Exception as e:
  1080. logging.error(f"获取阈值失败: {str(e)}")
  1081. return jsonify({
  1082. 'error': f'获取阈值失败: {str(e)}'
  1083. }), 500
  1084. @bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
  1085. def set_current_dataset(data_type, dataset_id):
  1086. """
  1087. 将指定数据集设置为current数据集
  1088. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1089. @param dataset_id: 要设置为current的数据集ID
  1090. @return: JSON响应
  1091. """
  1092. Session = sessionmaker(bind=db.engine)
  1093. session = Session()
  1094. try:
  1095. # 验证数据集存在且类型匹配
  1096. dataset = session.query(Datasets)\
  1097. .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
  1098. .first()
  1099. if not dataset:
  1100. return jsonify({
  1101. 'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
  1102. }), 404
  1103. # 根据数据类型选择表
  1104. if data_type == 'reduce':
  1105. table = CurrentReduce
  1106. table_name = 'current_reduce'
  1107. elif data_type == 'reflux':
  1108. table = CurrentReflux
  1109. table_name = 'current_reflux'
  1110. else:
  1111. return jsonify({'error': '无效的数据集类型'}), 400
  1112. # 清空current表
  1113. session.query(table).delete()
  1114. # 重置自增主键计数器
  1115. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  1116. # 从指定数据集复制数据到current表
  1117. dataset_table_name = f"dataset_{dataset_id}"
  1118. copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
  1119. session.execute(copy_sql)
  1120. session.commit()
  1121. return jsonify({
  1122. 'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
  1123. 'dataset_id': dataset_id,
  1124. 'dataset_name': dataset.Dataset_name,
  1125. 'row_count': dataset.Row_count
  1126. }), 200
  1127. except Exception as e:
  1128. session.rollback()
  1129. logger.error(f'设置current数据集失败: {str(e)}')
  1130. return jsonify({'error': str(e)}), 500
  1131. finally:
  1132. session.close()