routes.py 65 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882
  1. import sqlite3
  2. import uuid
  3. from io import BytesIO
  4. from flask import Blueprint, request, jsonify, current_app, send_file, session
  5. from werkzeug.security import check_password_hash, generate_password_hash
  6. from werkzeug.utils import secure_filename, send_from_directory
  7. from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
  8. import pandas as pd
  9. from . import db # 从 app 包导入 db 实例
  10. from sqlalchemy.engine.reflection import Inspector
  11. from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
  12. import os
  13. from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
  14. clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
  15. predict_to_Q, Q_to_t_ha, create_kriging
  16. from sqlalchemy.orm import sessionmaker
  17. import logging
  18. from sqlalchemy import text, select, MetaData, Table, func
  19. from .tasks import train_model_task
  20. from datetime import datetime
  21. # 配置日志
  22. logging.basicConfig(level=logging.DEBUG)
  23. logger = logging.getLogger(__name__)
  24. # 创建蓝图 (Blueprint),用于分离路由
  25. bp = Blueprint('routes', __name__)
  26. # 密码加密
  27. def hash_password(password):
  28. return generate_password_hash(password)
  29. def get_db():
  30. """ 获取数据库连接 """
  31. return sqlite3.connect(current_app.config['DATABASE'])
  32. # 添加一个新的辅助函数来检查数据集大小并触发训练
  33. def check_and_trigger_training(session, dataset_type, dataset_df):
  34. """
  35. 检查当前数据集大小是否跨越新的阈值点并触发训练
  36. Args:
  37. session: 数据库会话
  38. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  39. dataset_df: 数据集 DataFrame
  40. Returns:
  41. tuple: (是否触发训练, 任务ID)
  42. """
  43. try:
  44. # 根据数据集类型选择表
  45. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  46. # 获取当前记录数
  47. current_count = session.query(func.count()).select_from(table).scalar()
  48. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  49. new_records = len(dataset_df) # 需要从上层函数传入
  50. # 计算新增数据前的记录数
  51. previous_count = current_count - new_records
  52. # 设置阈值
  53. THRESHOLD = current_app.config['THRESHOLD']
  54. # 计算上一个阈值点(基于新增前的数据量)
  55. last_threshold = previous_count // THRESHOLD * THRESHOLD
  56. # 计算当前所在阈值点
  57. current_threshold = current_count // THRESHOLD * THRESHOLD
  58. # 检查是否跨越了新的阈值点
  59. if current_threshold > last_threshold and current_count >= THRESHOLD:
  60. # 触发异步训练任务
  61. task = train_model_task.delay(
  62. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  63. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  64. model_description=f'Auto trained model at {current_threshold} records threshold',
  65. data_type=dataset_type
  66. )
  67. return True, task.id
  68. return False, None
  69. except Exception as e:
  70. logging.error(f"检查并触发训练失败: {str(e)}")
  71. return False, None
  72. @bp.route('/upload-dataset', methods=['POST'])
  73. def upload_dataset():
  74. # 创建 session
  75. Session = sessionmaker(bind=db.engine)
  76. session = Session()
  77. try:
  78. if 'file' not in request.files:
  79. return jsonify({'error': 'No file part'}), 400
  80. file = request.files['file']
  81. if file.filename == '' or not allowed_file(file.filename):
  82. return jsonify({'error': 'No selected file or invalid file type'}), 400
  83. dataset_name = request.form.get('dataset_name')
  84. dataset_description = request.form.get('dataset_description', 'No description provided')
  85. dataset_type = request.form.get('dataset_type')
  86. if not dataset_type:
  87. return jsonify({'error': 'Dataset type is required'}), 400
  88. new_dataset = Datasets(
  89. Dataset_name=dataset_name,
  90. Dataset_description=dataset_description,
  91. Row_count=0,
  92. Status='Datasets_upgraded',
  93. Dataset_type=dataset_type,
  94. Uploaded_at=datetime.now()
  95. )
  96. session.add(new_dataset)
  97. session.commit()
  98. unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
  99. upload_folder = current_app.config['UPLOAD_FOLDER']
  100. file_path = os.path.join(upload_folder, unique_filename)
  101. file.save(file_path)
  102. dataset_df = pd.read_excel(file_path)
  103. new_dataset.Row_count = len(dataset_df)
  104. new_dataset.Status = 'excel_file_saved success'
  105. session.commit()
  106. # 处理列名
  107. dataset_df = clean_column_names(dataset_df)
  108. dataset_df = rename_columns_for_model(dataset_df, dataset_type)
  109. column_types = infer_column_types(dataset_df)
  110. dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
  111. insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
  112. # 去除上传数据集内部的重复项
  113. original_count = len(dataset_df)
  114. dataset_df = dataset_df.drop_duplicates()
  115. duplicates_in_file = original_count - len(dataset_df)
  116. # 检查与现有数据的重复
  117. duplicates_with_existing = 0
  118. if dataset_type in ['reduce', 'reflux']:
  119. # 确定表名
  120. table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
  121. # 从表加载现有数据
  122. existing_data = pd.read_sql_table(table_name, session.bind)
  123. if 'id' in existing_data.columns:
  124. existing_data = existing_data.drop('id', axis=1)
  125. # 确定用于比较的列
  126. compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
  127. # 计算重复行数
  128. original_df_len = len(dataset_df)
  129. # 使用concat和drop_duplicates找出非重复行
  130. all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
  131. duplicates_mask = all_data.duplicated(keep='first')
  132. duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
  133. # 保留非重复行
  134. dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
  135. logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
  136. # 检查与测试集的重叠
  137. test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)
  138. # 如果有与测试集重叠的数据,从数据集中移除
  139. if test_overlap_count > 0:
  140. # 创建一个布尔掩码,标记不在重叠索引中的行
  141. mask = ~dataset_df.index.isin(test_overlap_indices)
  142. # 应用掩码,只保留不重叠的行
  143. dataset_df = dataset_df[mask]
  144. logger.warning(f"移除了 {test_overlap_count} 行与测试集重叠的数据")
  145. # 根据 dataset_type 决定插入到哪个已有表
  146. if dataset_type == 'reduce':
  147. insert_data_into_existing_table(session, dataset_df, CurrentReduce)
  148. elif dataset_type == 'reflux':
  149. insert_data_into_existing_table(session, dataset_df, CurrentReflux)
  150. session.commit()
  151. # 在完成数据插入后,检查是否需要触发训练
  152. training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
  153. response_data = {
  154. 'message': f'数据集 {dataset_name} 上传成功!',
  155. 'dataset_id': new_dataset.Dataset_ID,
  156. 'filename': unique_filename,
  157. 'training_triggered': training_triggered,
  158. 'data_stats': {
  159. 'original_count': original_count,
  160. 'duplicates_in_file': duplicates_in_file,
  161. 'duplicates_with_existing': duplicates_with_existing,
  162. 'test_overlap_count': test_overlap_count,
  163. 'final_count': len(dataset_df)
  164. }
  165. }
  166. if training_triggered:
  167. response_data['task_id'] = task_id
  168. response_data['message'] += ' 自动训练已触发。'
  169. # 添加去重信息到消息中
  170. if duplicates_with_existing > 0:
  171. response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
  172. # 添加测试集重叠信息到消息中
  173. if test_overlap_count > 0:
  174. response_data['message'] += f' 已移除 {test_overlap_count} 个与测试集重叠的项。'
  175. return jsonify(response_data), 201
  176. except Exception as e:
  177. session.rollback()
  178. logging.error('Failed to process the dataset upload:', exc_info=True)
  179. return jsonify({'error': str(e)}), 500
  180. finally:
  181. # 确保 session 总是被关闭
  182. if session:
  183. session.close()
  184. @bp.route('/train-and-save-model', methods=['POST'])
  185. def train_and_save_model_endpoint():
  186. # 创建 sessionmaker 实例
  187. Session = sessionmaker(bind=db.engine)
  188. session = Session()
  189. data = request.get_json()
  190. # 从请求中解析参数
  191. model_type = data.get('model_type')
  192. model_name = data.get('model_name')
  193. model_description = data.get('model_description')
  194. data_type = data.get('data_type')
  195. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  196. try:
  197. # 调用训练和保存模型的函数
  198. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  199. model_id = result[1] if result else None
  200. # 计算模型评分
  201. if model_id:
  202. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  203. if model_info:
  204. # 计算多种评分指标
  205. score_metrics = calculate_model_score(model_info)
  206. # 更新模型评分
  207. model_info.Performance_score = score_metrics['r2']
  208. # 添加新的评分指标到数据库
  209. model_info.MAE = score_metrics['mae']
  210. model_info.RMSE = score_metrics['rmse']
  211. # CV_score 已在 train_and_save_model 中设置,此处不再更新
  212. session.commit()
  213. result = {
  214. 'model_id': model_id,
  215. 'model_score': score_metrics['r2'],
  216. 'mae': score_metrics['mae'],
  217. 'rmse': score_metrics['rmse'],
  218. 'cv_score': result[3]
  219. }
  220. # 返回成功响应
  221. return jsonify({
  222. 'message': 'Model trained and saved successfully',
  223. 'result': result
  224. }), 200
  225. except Exception as e:
  226. session.rollback()
  227. logging.error('Failed to process the model training:', exc_info=True)
  228. return jsonify({
  229. 'error': 'Failed to train and save model',
  230. 'message': str(e)
  231. }), 500
  232. finally:
  233. session.close()
  234. @bp.route('/predict', methods=['POST'])
  235. def predict_route():
  236. # 创建 sessionmaker 实例
  237. Session = sessionmaker(bind=db.engine)
  238. session = Session()
  239. try:
  240. data = request.get_json()
  241. model_id = data.get('model_id') # 提取模型名称
  242. parameters = data.get('parameters', {}) # 提取所有变量
  243. # 根据model_id获取模型Data_type
  244. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  245. if not model_info:
  246. return jsonify({'error': 'Model not found'}), 404
  247. data_type = model_info.Data_type
  248. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  249. # 如果为reduce,则不需要传入target_ph
  250. if data_type == 'reduce':
  251. # 获取传入的init_ph、target_ph参数
  252. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  253. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  254. # 从输入数据中删除'target_pH'列
  255. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  256. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  257. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  258. if data_type == 'reduce':
  259. predictions = predictions[0]
  260. # 将预测结果转换为Q
  261. Q = predict_to_Q(predictions, init_ph, target_ph)
  262. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  263. print(predictions)
  264. return jsonify({'result': predictions}), 200
  265. except Exception as e:
  266. logging.error('Failed to predict:', exc_info=True)
  267. return jsonify({'error': str(e)}), 400
  268. # 为指定模型计算指标评分,需要提供model_id
  269. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  270. def score_model(model_id):
  271. # 创建 sessionmaker 实例
  272. Session = sessionmaker(bind=db.engine)
  273. session = Session()
  274. try:
  275. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  276. if not model_info:
  277. return jsonify({'error': 'Model not found'}), 404
  278. # 计算模型评分
  279. score_metrics = calculate_model_score(model_info)
  280. # 更新模型记录中的评分(不包括交叉验证得分)
  281. model_info.Performance_score = score_metrics['r2']
  282. model_info.MAE = score_metrics['mae']
  283. model_info.RMSE = score_metrics['rmse']
  284. session.commit()
  285. return jsonify({
  286. 'message': 'Model scored successfully',
  287. 'r2_score': score_metrics['r2'],
  288. 'mae': score_metrics['mae'],
  289. 'rmse': score_metrics['rmse'],
  290. }), 200
  291. except Exception as e:
  292. logging.error('Failed to process model scoring:', exc_info=True)
  293. return jsonify({'error': str(e)}), 400
  294. finally:
  295. session.close()
  296. @bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
  297. def delete_dataset_endpoint(dataset_id):
  298. """
  299. 删除数据集的API接口
  300. @param dataset_id: 要删除的数据集ID
  301. @return: JSON响应
  302. """
  303. # 创建 sessionmaker 实例
  304. Session = sessionmaker(bind=db.engine)
  305. session = Session()
  306. try:
  307. # 查询数据集
  308. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  309. if not dataset:
  310. return jsonify({'error': '未找到数据集'}), 404
  311. # 检查是否有模型使用了该数据集
  312. models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
  313. if models_using_dataset:
  314. models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
  315. return jsonify({
  316. 'error': '无法删除数据集,因为以下模型正在使用它',
  317. 'models': models_info
  318. }), 400
  319. # 删除Excel文件
  320. filename = f"dataset_{dataset.Dataset_ID}.xlsx"
  321. file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
  322. if os.path.exists(file_path):
  323. try:
  324. os.remove(file_path)
  325. except OSError as e:
  326. logger.error(f'删除文件失败: {str(e)}')
  327. return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
  328. # 删除数据表
  329. table_name = f"dataset_{dataset.Dataset_ID}"
  330. session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
  331. # 删除数据集记录
  332. session.delete(dataset)
  333. session.commit()
  334. return jsonify({
  335. 'message': '数据集删除成功',
  336. 'deleted_files': [filename]
  337. }), 200
  338. except Exception as e:
  339. session.rollback()
  340. logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
  341. return jsonify({'error': str(e)}), 500
  342. finally:
  343. session.close()
  344. @bp.route('/tables', methods=['GET'])
  345. def list_tables():
  346. engine = db.engine # 使用 db 实例的 engine
  347. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  348. table_names = inspector.get_table_names() # 获取所有表名
  349. return jsonify(table_names) # 以 JSON 形式返回表名列表
  350. @bp.route('/models/<int:model_id>', methods=['GET'])
  351. def get_model(model_id):
  352. """
  353. 获取单个模型信息的API接口
  354. @param model_id: 模型ID
  355. @return: JSON响应
  356. """
  357. Session = sessionmaker(bind=db.engine)
  358. session = Session()
  359. try:
  360. model = session.query(Models).filter_by(ModelID=model_id).first()
  361. if model:
  362. return jsonify({
  363. 'ModelID': model.ModelID,
  364. 'Model_name': model.Model_name,
  365. 'Model_type': model.Model_type,
  366. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  367. 'Description': model.Description,
  368. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  369. 'Data_type': model.Data_type
  370. })
  371. else:
  372. return jsonify({'message': '未找到模型'}), 404
  373. except Exception as e:
  374. logger.error(f'获取模型信息失败: {str(e)}')
  375. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  376. finally:
  377. session.close()
  378. # 显示切换模型
  379. @bp.route('/models', methods=['GET'])
  380. def get_models():
  381. session = None
  382. try:
  383. # 创建 session
  384. Session = sessionmaker(bind=db.engine)
  385. session = Session()
  386. # 查询所有模型
  387. models = session.query(Models).all()
  388. logger.debug(f"Models found: {models}") # 打印查询的模型数据
  389. if not models:
  390. return jsonify({'message': 'No models found'}), 404
  391. # 将模型数据转换为字典列表
  392. models_list = [
  393. {
  394. 'ModelID': model.ModelID,
  395. 'ModelName': model.Model_name,
  396. 'ModelType': model.Model_type,
  397. 'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  398. 'Description': model.Description,
  399. 'DatasetID': model.DatasetID,
  400. 'ModelFilePath': model.ModelFilePath,
  401. 'DataType': model.Data_type,
  402. 'PerformanceScore': model.Performance_score
  403. }
  404. for model in models
  405. ]
  406. return jsonify(models_list), 200
  407. except Exception as e:
  408. return jsonify({'error': str(e)}), 400
  409. finally:
  410. if session:
  411. session.close()
  412. @bp.route('/model-parameters', methods=['GET'])
  413. def get_all_model_parameters():
  414. """
  415. 获取所有模型参数的API接口
  416. @return: JSON响应
  417. """
  418. Session = sessionmaker(bind=db.engine)
  419. session = Session()
  420. try:
  421. parameters = session.query(ModelParameters).all()
  422. if parameters:
  423. result = [
  424. {
  425. 'ParamID': param.ParamID,
  426. 'ModelID': param.ModelID,
  427. 'ParamName': param.ParamName,
  428. 'ParamValue': param.ParamValue
  429. }
  430. for param in parameters
  431. ]
  432. return jsonify(result)
  433. else:
  434. return jsonify({'message': '未找到任何参数'}), 404
  435. except Exception as e:
  436. logger.error(f'获取所有模型参数失败: {str(e)}')
  437. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  438. finally:
  439. session.close()
  440. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  441. def get_model_parameters(model_id):
  442. try:
  443. model = Models.query.filter_by(ModelID=model_id).first()
  444. if model:
  445. # 获取该模型的所有参数
  446. parameters = [
  447. {
  448. 'ParamID': param.ParamID,
  449. 'ParamName': param.ParamName,
  450. 'ParamValue': param.ParamValue
  451. }
  452. for param in model.parameters
  453. ]
  454. # 返回模型参数信息
  455. return jsonify({
  456. 'ModelID': model.ModelID,
  457. 'ModelName': model.ModelName,
  458. 'ModelType': model.ModelType,
  459. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  460. 'Description': model.Description,
  461. 'Parameters': parameters
  462. })
  463. else:
  464. return jsonify({'message': 'Model not found'}), 404
  465. except Exception as e:
  466. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  467. # 定义添加数据库记录的 API 接口
  468. @bp.route('/add_item', methods=['POST'])
  469. def add_item():
  470. """
  471. 接收 JSON 格式的请求体,包含表名和要插入的数据。
  472. 尝试将数据插入到指定的表中,并进行字段查重。
  473. :return:
  474. """
  475. try:
  476. # 确保请求体是 JSON 格式
  477. data = request.get_json()
  478. if not data:
  479. raise ValueError("No JSON data provided")
  480. table_name = data.get('table')
  481. item_data = data.get('item')
  482. if not table_name or not item_data:
  483. return jsonify({'error': 'Missing table name or item data'}), 400
  484. # 定义各个表的字段查重规则
  485. duplicate_check_rules = {
  486. 'users': ['email', 'username'],
  487. 'products': ['product_code'],
  488. 'current_reduce': ['Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'],
  489. 'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'],
  490. # 其他表和规则
  491. }
  492. # 获取该表的查重字段
  493. duplicate_columns = duplicate_check_rules.get(table_name)
  494. if not duplicate_columns:
  495. return jsonify({'error': 'No duplicate check rule for this table'}), 400
  496. # 动态构建查询条件,逐一检查是否有重复数据
  497. condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns])
  498. duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1"
  499. result = db.session.execute(text(duplicate_query), item_data).fetchone()
  500. if result:
  501. return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409
  502. # 动态构建 SQL 语句,进行插入操作
  503. columns = ', '.join(item_data.keys())
  504. placeholders = ', '.join([f":{key}" for key in item_data.keys()])
  505. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  506. # 直接执行插入操作,无需显式的事务管理
  507. db.session.execute(text(sql), item_data)
  508. # 提交事务
  509. db.session.commit()
  510. # 返回成功响应
  511. return jsonify({'success': True, 'message': 'Item added successfully'}), 201
  512. except ValueError as e:
  513. return jsonify({'error': str(e)}), 400
  514. except KeyError as e:
  515. return jsonify({'error': f'Missing data field: {e}'}), 400
  516. except sqlite3.IntegrityError as e:
  517. return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409
  518. except sqlite3.Error as e:
  519. return jsonify({'error': '数据库错误', 'details': str(e)}), 500
  520. @bp.route('/delete_item', methods=['POST'])
  521. def delete_item():
  522. """
  523. 删除数据库记录的 API 接口
  524. """
  525. data = request.get_json()
  526. table_name = data.get('table')
  527. condition = data.get('condition')
  528. # 检查表名和条件是否提供
  529. if not table_name or not condition:
  530. return jsonify({
  531. "success": False,
  532. "message": "缺少表名或条件参数"
  533. }), 400
  534. # 尝试从条件字符串中解析键和值
  535. try:
  536. key, value = condition.split('=')
  537. key = key.strip() # 去除多余的空格
  538. value = value.strip().strip("'\"") # 去除多余的空格和引号
  539. except ValueError:
  540. return jsonify({
  541. "success": False,
  542. "message": "条件格式错误,应为 'key=value'"
  543. }), 400
  544. # 准备 SQL 删除语句
  545. sql = f"DELETE FROM {table_name} WHERE {key} = :value"
  546. try:
  547. # 使用 SQLAlchemy 执行删除
  548. with db.session.begin():
  549. result = db.session.execute(text(sql), {"value": value})
  550. # 检查是否有记录被删除
  551. if result.rowcount == 0:
  552. return jsonify({
  553. "success": False,
  554. "message": "未找到符合条件的记录"
  555. }), 404
  556. return jsonify({
  557. "success": True,
  558. "message": "记录删除成功"
  559. }), 200
  560. except Exception as e:
  561. return jsonify({
  562. "success": False,
  563. "message": f"删除失败: {e}"
  564. }), 500
  565. # 定义修改数据库记录的 API 接口
  566. @bp.route('/update_item', methods=['PUT'])
  567. def update_record():
  568. """
  569. 接收 JSON 格式的请求体,包含表名和更新的数据。
  570. 尝试更新指定的记录。
  571. """
  572. data = request.get_json()
  573. # 检查必要的数据是否提供
  574. if not data or 'table' not in data or 'item' not in data:
  575. return jsonify({
  576. "success": False,
  577. "message": "请求数据不完整"
  578. }), 400
  579. table_name = data['table']
  580. item = data['item']
  581. # 假设 item 的第一个键是 ID
  582. id_key = next(iter(item.keys())) # 获取第一个键
  583. record_id = item.get(id_key)
  584. if not record_id:
  585. return jsonify({
  586. "success": False,
  587. "message": "缺少记录 ID"
  588. }), 400
  589. # 获取更新的字段和值
  590. updates = {key: value for key, value in item.items() if key != id_key}
  591. if not updates:
  592. return jsonify({
  593. "success": False,
  594. "message": "没有提供需要更新的字段"
  595. }), 400
  596. # 动态构建 SQL
  597. set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()])
  598. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value"
  599. # 添加 ID 到参数
  600. updates['id_value'] = record_id
  601. try:
  602. # 使用 SQLAlchemy 执行更新
  603. with db.session.begin():
  604. result = db.session.execute(text(sql), updates)
  605. # 检查是否有更新的记录
  606. if result.rowcount == 0:
  607. return jsonify({
  608. "success": False,
  609. "message": "未找到要更新的记录"
  610. }), 404
  611. return jsonify({
  612. "success": True,
  613. "message": "数据更新成功"
  614. }), 200
  615. except Exception as e:
  616. # 捕获所有异常并返回
  617. return jsonify({
  618. "success": False,
  619. "message": f"更新失败: {str(e)}"
  620. }), 500
  621. # 定义查询数据库记录的 API 接口
  622. @bp.route('/search/record', methods=['GET'])
  623. def sql_search():
  624. """
  625. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  626. 尝试查询指定 ID 的记录并返回结果。
  627. :return:
  628. """
  629. try:
  630. data = request.get_json()
  631. # 表名
  632. sql_table = data['table']
  633. # 要搜索的 ID
  634. Id = data['id']
  635. # 连接到数据库
  636. cur = db.cursor()
  637. # 构造查询语句
  638. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  639. # 执行查询
  640. cur.execute(sql, (Id,))
  641. # 获取查询结果
  642. rows = cur.fetchall()
  643. column_names = [desc[0] for desc in cur.description]
  644. # 检查是否有结果
  645. if not rows:
  646. return jsonify({'error': '未查找到对应数据。'}), 400
  647. # 构造响应数据
  648. results = []
  649. for row in rows:
  650. result = {column_names[i]: row[i] for i in range(len(row))}
  651. results.append(result)
  652. # 关闭游标和数据库连接
  653. cur.close()
  654. db.close()
  655. # 返回 JSON 响应
  656. return jsonify(results), 200
  657. except sqlite3.Error as e:
  658. # 如果发生数据库错误,返回错误信息
  659. return jsonify({'error': str(e)}), 400
  660. except KeyError as e:
  661. # 如果请求数据中缺少必要的键,返回错误信息
  662. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  663. # 定义提供数据库列表,用于展示表格的 API 接口
  664. @bp.route('/table', methods=['POST'])
  665. def get_table():
  666. data = request.get_json()
  667. table_name = data.get('table')
  668. if not table_name:
  669. return jsonify({'error': '需要表名'}), 400
  670. try:
  671. # 创建 sessionmaker 实例
  672. Session = sessionmaker(bind=db.engine)
  673. session = Session()
  674. # 动态获取表的元数据
  675. metadata = MetaData()
  676. table = Table(table_name, metadata, autoload_with=db.engine)
  677. # 从数据库中查询所有记录
  678. query = select(table)
  679. result = session.execute(query).fetchall()
  680. # 将结果转换为列表字典形式
  681. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  682. # 获取列名
  683. headers = [column.name for column in table.columns]
  684. return jsonify(rows=rows, headers=headers), 200
  685. except Exception as e:
  686. return jsonify({'error': str(e)}), 400
  687. finally:
  688. # 关闭 session
  689. session.close()
  690. @bp.route('/train-model-async', methods=['POST'])
  691. def train_model_async():
  692. """
  693. 异步训练模型的API接口
  694. """
  695. try:
  696. data = request.get_json()
  697. # 从请求中获取参数
  698. model_type = data.get('model_type')
  699. model_name = data.get('model_name')
  700. model_description = data.get('model_description')
  701. data_type = data.get('data_type')
  702. dataset_id = data.get('dataset_id', None)
  703. # 验证必要参数
  704. if not all([model_type, model_name, data_type]):
  705. return jsonify({
  706. 'error': 'Missing required parameters'
  707. }), 400
  708. # 如果提供了dataset_id,验证数据集是否存在
  709. if dataset_id:
  710. Session = sessionmaker(bind=db.engine)
  711. session = Session()
  712. try:
  713. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  714. if not dataset:
  715. return jsonify({
  716. 'error': f'Dataset with ID {dataset_id} not found'
  717. }), 404
  718. finally:
  719. session.close()
  720. # 启动异步任务
  721. task = train_model_task.delay(
  722. model_type=model_type,
  723. model_name=model_name,
  724. model_description=model_description,
  725. data_type=data_type,
  726. dataset_id=dataset_id
  727. )
  728. # 返回任务ID
  729. return jsonify({
  730. 'task_id': task.id,
  731. 'message': 'Model training started'
  732. }), 202
  733. except Exception as e:
  734. logging.error('Failed to start async training task:', exc_info=True)
  735. return jsonify({
  736. 'error': str(e)
  737. }), 500
  738. @bp.route('/task-status/<task_id>', methods=['GET'])
  739. def get_task_status(task_id):
  740. """
  741. 获取异步任务状态的API接口
  742. """
  743. try:
  744. task = train_model_task.AsyncResult(task_id)
  745. if task.state == 'PENDING':
  746. response = {
  747. 'state': task.state,
  748. 'status': 'Task is waiting for execution'
  749. }
  750. elif task.state == 'FAILURE':
  751. response = {
  752. 'state': task.state,
  753. 'status': 'Task failed',
  754. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  755. }
  756. elif task.state == 'SUCCESS':
  757. response = {
  758. 'state': task.state,
  759. 'status': 'Task completed successfully',
  760. 'result': task.get()
  761. }
  762. else:
  763. response = {
  764. 'state': task.state,
  765. 'status': 'Task is in progress'
  766. }
  767. return jsonify(response), 200
  768. except Exception as e:
  769. return jsonify({
  770. 'error': str(e)
  771. }), 500
  772. @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
  773. def delete_model_route(model_id):
  774. # 将URL参数转换为布尔值
  775. delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'
  776. # 调用原始函数
  777. return delete_model(model_id, delete_dataset=delete_dataset_param)
  778. def delete_model(model_id, delete_dataset=False):
  779. """
  780. 删除指定模型的API接口
  781. @param model_id: 要删除的模型ID
  782. @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
  783. @return: JSON响应
  784. """
  785. Session = sessionmaker(bind=db.engine)
  786. session = Session()
  787. try:
  788. # 查询模型信息
  789. model = session.query(Models).filter_by(ModelID=model_id).first()
  790. if not model:
  791. return jsonify({'error': '未找到指定模型'}), 404
  792. dataset_id = model.DatasetID
  793. # 1. 先删除模型记录
  794. session.delete(model)
  795. session.commit()
  796. # 2. 删除模型文件
  797. model_file = f"rf_model_{model_id}.pkl"
  798. model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
  799. if os.path.exists(model_path):
  800. try:
  801. os.remove(model_path)
  802. except OSError as e:
  803. # 如果删除文件失败,回滚数据库操作
  804. session.rollback()
  805. logger.error(f'删除模型文件失败: {str(e)}')
  806. return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
  807. # 3. 如果需要删除关联的数据集
  808. if delete_dataset and dataset_id:
  809. try:
  810. dataset_response = delete_dataset_endpoint(dataset_id)
  811. if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
  812. # 如果删除数据集失败,回滚之前的操作
  813. session.rollback()
  814. return jsonify({
  815. 'error': '删除关联数据集失败',
  816. 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
  817. }), 500
  818. except Exception as e:
  819. session.rollback()
  820. logger.error(f'删除关联数据集失败: {str(e)}')
  821. return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
  822. response_data = {
  823. 'message': '模型删除成功',
  824. 'deleted_files': [model_file]
  825. }
  826. if delete_dataset:
  827. response_data['dataset_info'] = {
  828. 'dataset_id': dataset_id,
  829. 'message': '关联数据集已删除'
  830. }
  831. return jsonify(response_data), 200
  832. except Exception as e:
  833. session.rollback()
  834. logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
  835. return jsonify({'error': str(e)}), 500
  836. finally:
  837. session.close()
  838. # 添加一个新的API端点来清空指定数据集
  839. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  840. def clear_dataset(data_type):
  841. """
  842. 清空指定类型的数据集并递增计数
  843. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  844. @return: JSON响应
  845. """
  846. # 创建 sessionmaker 实例
  847. Session = sessionmaker(bind=db.engine)
  848. session = Session()
  849. try:
  850. # 根据数据集类型选择表
  851. if data_type == 'reduce':
  852. table = CurrentReduce
  853. table_name = 'current_reduce'
  854. elif data_type == 'reflux':
  855. table = CurrentReflux
  856. table_name = 'current_reflux'
  857. else:
  858. return jsonify({'error': '无效的数据集类型'}), 400
  859. # 清空表内容
  860. session.query(table).delete()
  861. # 重置自增主键计数器
  862. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  863. session.commit()
  864. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  865. except Exception as e:
  866. session.rollback()
  867. return jsonify({'error': str(e)}), 500
  868. finally:
  869. session.close()
  870. # 修改了一下登录·接口
  871. @bp.route('/login', methods=['POST'])
  872. def login_user():
  873. data = request.get_json()
  874. logger.debug(f"Received login data: {data}") # 增加调试日志
  875. name = data.get('name')
  876. password = data.get('password')
  877. logger.info(f"Login request received for user: {name}")
  878. if not isinstance(name, str) or not isinstance(password, str):
  879. logger.warning("Username and password must be strings")
  880. return jsonify({"success": False, "message": "用户名和密码必须为字符串"}), 400
  881. if not name or not password:
  882. logger.warning("Username and password cannot be empty")
  883. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  884. query = "SELECT id, name, password FROM users WHERE name = ?"
  885. try:
  886. with get_db() as conn:
  887. user = conn.execute(query, (name,)).fetchone()
  888. if not user:
  889. logger.warning(f"User '{name}' does not exist")
  890. return jsonify({"success": False, "message": "用户名不存在"}), 400
  891. stored_password = user[2] # 假设 'password' 是第三个字段
  892. user_id = user[0] # 假设 'id' 是第一个字段
  893. if check_password_hash(stored_password, password):
  894. session['name'] = name
  895. logger.info(f"User '{name}' logged in successfully.")
  896. return jsonify({
  897. "success": True,
  898. "message": "登录成功",
  899. "userId": user_id,
  900. "name": name
  901. })
  902. else:
  903. logger.warning(f"Incorrect password for user '{name}'")
  904. return jsonify({"success": False, "message": "用户名或密码错误"}), 400
  905. except sqlite3.DatabaseError as db_err:
  906. logger.error(f"Database error during login process: {db_err}", exc_info=True)
  907. return jsonify({"success": False, "message": "数据库错误"}), 500
  908. except Exception as e:
  909. logger.error(f"Unexpected error during login process: {e}", exc_info=True)
  910. return jsonify({"success": False, "message": "登录失败"}), 500
  911. # 更新用户信息接口
  912. @bp.route('/update_user', methods=['POST'])
  913. def update_user():
  914. # 获取前端传来的数据
  915. data = request.get_json()
  916. # 打印收到的请求数据
  917. current_app.logger.info(f"Received data: {data}")
  918. user_id = data.get('userId') # 用户ID
  919. name = data.get('name') # 用户名
  920. old_password = data.get('oldPassword') # 旧密码
  921. new_password = data.get('newPassword') # 新密码
  922. logger.info(f"Update request received: user_id={user_id}, name={name}")
  923. # 校验传入的用户名和密码是否为空
  924. if not name or not old_password:
  925. logger.warning("用户名和旧密码不能为空")
  926. return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400
  927. # 新密码和旧密码不能相同
  928. if new_password and old_password == new_password:
  929. logger.warning(f"新密码与旧密码相同:{name}")
  930. return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400
  931. try:
  932. # 查询数据库验证用户ID
  933. query = "SELECT * FROM users WHERE id = :user_id"
  934. conn = get_db()
  935. user = conn.execute(query, {"user_id": user_id}).fetchone()
  936. if not user:
  937. logger.warning(f"用户ID '{user_id}' 不存在")
  938. return jsonify({"success": False, "message": "用户不存在"}), 400
  939. # 获取数据库中存储的密码(假设密码是哈希存储的)
  940. stored_password = user[2] # 假设密码存储在数据库的第三列
  941. # 校验旧密码是否正确
  942. if not check_password_hash(stored_password, old_password):
  943. logger.warning(f"旧密码错误:{name}")
  944. return jsonify({"success": False, "message": "旧密码错误"}), 400
  945. # 如果新密码非空,则更新新密码
  946. if new_password:
  947. hashed_new_password = hash_password(new_password)
  948. update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
  949. conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
  950. conn.commit()
  951. logger.info(f"User ID '{user_id}' password updated successfully.")
  952. # 如果用户名发生更改,则更新用户名
  953. if name != user[1]:
  954. update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
  955. conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
  956. conn.commit()
  957. logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
  958. return jsonify({"success": True, "message": "用户信息更新成功"})
  959. except Exception as e:
  960. # 记录错误日志并返回错误信息
  961. logger.error(f"Error updating user: {e}", exc_info=True)
  962. return jsonify({"success": False, "message": "更新失败"}), 500
  963. # 注册用户
  964. @bp.route('/register', methods=['POST'])
  965. def register_user():
  966. # 获取前端传来的数据
  967. data = request.get_json()
  968. name = data.get('name') # 用户名
  969. password = data.get('password') # 密码
  970. logger.info(f"Register request received: name={name}")
  971. # 检查用户名和密码是否为空
  972. if not name or not password:
  973. logger.warning("用户名和密码不能为空")
  974. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  975. # 动态获取数据库表的列名
  976. columns = get_column_names('users')
  977. logger.info(f"Database columns for 'users' table: {columns}")
  978. # 检查前端传来的数据是否包含数据库表中所有的必填字段
  979. for column in ['name', 'password']:
  980. if column not in columns:
  981. logger.error(f"缺少必填字段:{column}")
  982. return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400
  983. # 对密码进行哈希处理
  984. hashed_password = hash_password(password)
  985. logger.info(f"Password hashed for user: {name}")
  986. # 插入到数据库
  987. try:
  988. # 检查用户是否已经存在
  989. query = "SELECT * FROM users WHERE name = :name"
  990. conn = get_db()
  991. user = conn.execute(query, {"name": name}).fetchone()
  992. if user:
  993. logger.warning(f"用户名 '{name}' 已存在")
  994. return jsonify({"success": False, "message": "用户名已存在"}), 400
  995. # 向数据库插入数据
  996. query = "INSERT INTO users (name, password) VALUES (:name, :password)"
  997. conn.execute(query, {"name": name, "password": hashed_password})
  998. conn.commit()
  999. logger.info(f"User '{name}' registered successfully.")
  1000. return jsonify({"success": True, "message": "注册成功"})
  1001. except Exception as e:
  1002. # 记录错误日志并返回错误信息
  1003. logger.error(f"Error registering user: {e}", exc_info=True)
  1004. return jsonify({"success": False, "message": "注册失败"}), 500
  1005. def get_column_names(table_name):
  1006. """
  1007. 动态获取数据库表的列名。
  1008. """
  1009. try:
  1010. conn = get_db()
  1011. query = f"PRAGMA table_info({table_name});"
  1012. result = conn.execute(query).fetchall()
  1013. conn.close()
  1014. return [row[1] for row in result] # 第二列是列名
  1015. except Exception as e:
  1016. logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
  1017. return []
  1018. # 导出数据
  1019. @bp.route('/export_data', methods=['GET'])
  1020. def export_data():
  1021. table_name = request.args.get('table')
  1022. file_format = request.args.get('format', 'excel').lower()
  1023. if not table_name:
  1024. return jsonify({'error': '缺少表名参数'}), 400
  1025. if not table_name.isidentifier():
  1026. return jsonify({'error': '无效的表名'}), 400
  1027. try:
  1028. conn = get_db()
  1029. query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"
  1030. table_exists = conn.execute(query, (table_name,)).fetchone()
  1031. if not table_exists:
  1032. return jsonify({'error': f"表 {table_name} 不存在"}), 404
  1033. query = f"SELECT * FROM {table_name};"
  1034. df = pd.read_sql(query, conn)
  1035. output = BytesIO()
  1036. if file_format == 'csv':
  1037. df.to_csv(output, index=False, encoding='utf-8')
  1038. output.seek(0)
  1039. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv')
  1040. elif file_format == 'excel':
  1041. df.to_excel(output, index=False, engine='openpyxl')
  1042. output.seek(0)
  1043. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx',
  1044. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1045. else:
  1046. return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400
  1047. except Exception as e:
  1048. logger.error(f"Error in export_data: {e}", exc_info=True)
  1049. return jsonify({'error': str(e)}), 500
  1050. # 导入数据接口
  1051. @bp.route('/import_data', methods=['POST'])
  1052. def import_data():
  1053. logger.debug("Import data endpoint accessed.")
  1054. if 'file' not in request.files:
  1055. logger.error("No file in request.")
  1056. return jsonify({'success': False, 'message': '文件缺失'}), 400
  1057. file = request.files['file']
  1058. table_name = request.form.get('table')
  1059. if not table_name:
  1060. logger.error("Missing table name parameter.")
  1061. return jsonify({'success': False, 'message': '缺少表名参数'}), 400
  1062. if file.filename == '':
  1063. logger.error("No file selected.")
  1064. return jsonify({'success': False, 'message': '未选择文件'}), 400
  1065. try:
  1066. # 保存文件到临时路径
  1067. temp_path = os.path.join(current_app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
  1068. file.save(temp_path)
  1069. logger.debug(f"File saved to temporary path: {temp_path}")
  1070. # 根据文件类型读取文件
  1071. if file.filename.endswith('.xlsx'):
  1072. df = pd.read_excel(temp_path)
  1073. elif file.filename.endswith('.csv'):
  1074. df = pd.read_csv(temp_path)
  1075. else:
  1076. logger.error("Unsupported file format.")
  1077. return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400
  1078. # 获取数据库列名
  1079. db_columns = get_column_names(table_name)
  1080. if 'id' in db_columns:
  1081. db_columns.remove('id') # 假设 id 列是自增的,不需要处理
  1082. if not set(db_columns).issubset(set(df.columns)):
  1083. logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}")
  1084. return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400
  1085. # 清洗数据并删除空值行
  1086. df_cleaned = df[db_columns].dropna()
  1087. # 统一数据类型,避免 int 和 float 合并问题
  1088. df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce')
  1089. # 获取现有的数据
  1090. conn = get_db()
  1091. with conn:
  1092. existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn)
  1093. # 查找重复数据
  1094. duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner')
  1095. # 如果有重复数据,删除它们
  1096. df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)]
  1097. logger.warning(f"Duplicate data detected and removed: {duplicates}")
  1098. # 获取导入前后的数据量
  1099. total_data = len(df_cleaned) + len(duplicates)
  1100. new_data = len(df_cleaned)
  1101. duplicate_data = len(duplicates)
  1102. # 导入不重复的数据
  1103. df_cleaned.to_sql(table_name, conn, if_exists='append', index=False)
  1104. logger.debug(f"Imported {new_data} new records into the database.")
  1105. # 删除临时文件
  1106. os.remove(temp_path)
  1107. logger.debug(f"Temporary file removed: {temp_path}")
  1108. # 返回结果
  1109. return jsonify({
  1110. 'success': True,
  1111. 'message': '数据导入成功',
  1112. 'total_data': total_data,
  1113. 'new_data': new_data,
  1114. 'duplicate_data': duplicate_data
  1115. }), 200
  1116. except Exception as e:
  1117. logger.error(f"Import failed: {e}", exc_info=True)
  1118. return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500
  1119. # 模板下载接口
  1120. @bp.route('/download_template', methods=['GET'])
  1121. def download_template():
  1122. """
  1123. 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。
  1124. """
  1125. table_name = request.args.get('table')
  1126. if not table_name:
  1127. return jsonify({'error': '表名参数缺失'}), 400
  1128. columns = get_column_names(table_name)
  1129. if not columns:
  1130. return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404
  1131. # 不包括 ID 列
  1132. if 'id' in columns:
  1133. columns.remove('id')
  1134. df = pd.DataFrame(columns=columns)
  1135. file_format = request.args.get('format', 'excel').lower()
  1136. try:
  1137. if file_format == 'csv':
  1138. output = BytesIO()
  1139. df.to_csv(output, index=False, encoding='utf-8')
  1140. output.seek(0)
  1141. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv',
  1142. mimetype='text/csv')
  1143. else:
  1144. output = BytesIO()
  1145. df.to_excel(output, index=False, engine='openpyxl')
  1146. output.seek(0)
  1147. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx',
  1148. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1149. except Exception as e:
  1150. logger.error(f"Failed to generate template: {e}", exc_info=True)
  1151. return jsonify({'error': '生成模板文件失败'}), 500
  1152. @bp.route('/update-threshold', methods=['POST'])
  1153. def update_threshold():
  1154. """
  1155. 更新训练阈值的API接口
  1156. @body_param threshold: 新的阈值值(整数)
  1157. @return: JSON响应
  1158. """
  1159. try:
  1160. data = request.get_json()
  1161. new_threshold = data.get('threshold')
  1162. # 验证新阈值
  1163. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  1164. return jsonify({
  1165. 'error': '无效的阈值值,必须为正数'
  1166. }), 400
  1167. # 更新当前应用的阈值配置
  1168. current_app.config['THRESHOLD'] = int(new_threshold)
  1169. return jsonify({
  1170. 'success': True,
  1171. 'message': f'阈值已更新为 {new_threshold}',
  1172. 'new_threshold': new_threshold
  1173. })
  1174. except Exception as e:
  1175. logging.error(f"更新阈值失败: {str(e)}")
  1176. return jsonify({
  1177. 'error': f'更新阈值失败: {str(e)}'
  1178. }), 500
  1179. @bp.route('/get-threshold', methods=['GET'])
  1180. def get_threshold():
  1181. """
  1182. 获取当前训练阈值的API接口
  1183. @return: JSON响应
  1184. """
  1185. try:
  1186. current_threshold = current_app.config['THRESHOLD']
  1187. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  1188. return jsonify({
  1189. 'current_threshold': current_threshold,
  1190. 'default_threshold': default_threshold
  1191. })
  1192. except Exception as e:
  1193. logging.error(f"获取阈值失败: {str(e)}")
  1194. return jsonify({
  1195. 'error': f'获取阈值失败: {str(e)}'
  1196. }), 500
  1197. @bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
  1198. def set_current_dataset(data_type, dataset_id):
  1199. """
  1200. 将指定数据集设置为current数据集
  1201. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1202. @param dataset_id: 要设置为current的数据集ID
  1203. @return: JSON响应
  1204. """
  1205. Session = sessionmaker(bind=db.engine)
  1206. session = Session()
  1207. try:
  1208. # 验证数据集存在且类型匹配
  1209. dataset = session.query(Datasets)\
  1210. .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
  1211. .first()
  1212. if not dataset:
  1213. return jsonify({
  1214. 'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
  1215. }), 404
  1216. # 根据数据类型选择表
  1217. if data_type == 'reduce':
  1218. table = CurrentReduce
  1219. table_name = 'current_reduce'
  1220. elif data_type == 'reflux':
  1221. table = CurrentReflux
  1222. table_name = 'current_reflux'
  1223. else:
  1224. return jsonify({'error': '无效的数据集类型'}), 400
  1225. # 清空current表
  1226. session.query(table).delete()
  1227. # 重置自增主键计数器
  1228. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  1229. # 从指定数据集复制数据到current表
  1230. dataset_table_name = f"dataset_{dataset_id}"
  1231. copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
  1232. session.execute(copy_sql)
  1233. session.commit()
  1234. return jsonify({
  1235. 'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
  1236. 'dataset_id': dataset_id,
  1237. 'dataset_name': dataset.Dataset_name,
  1238. 'row_count': dataset.Row_count
  1239. }), 200
  1240. except Exception as e:
  1241. session.rollback()
  1242. logger.error(f'设置current数据集失败: {str(e)}')
  1243. return jsonify({'error': str(e)}), 500
  1244. finally:
  1245. session.close()
  1246. @bp.route('/get-model-history/<string:data_type>', methods=['GET'])
  1247. def get_model_history(data_type):
  1248. """
  1249. 获取模型训练历史数据的API接口
  1250. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1251. @return: JSON响应,包含时间序列的模型性能数据
  1252. """
  1253. Session = sessionmaker(bind=db.engine)
  1254. session = Session()
  1255. try:
  1256. # 查询所有自动生成的数据集,按时间排序
  1257. datasets = session.query(Datasets).filter(
  1258. Datasets.Dataset_type == data_type,
  1259. Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
  1260. ).order_by(Datasets.Uploaded_at).all()
  1261. history_data = []
  1262. for dataset in datasets:
  1263. # 查找对应的自动训练模型
  1264. model = session.query(Models).filter(
  1265. Models.DatasetID == dataset.Dataset_ID,
  1266. Models.Model_name.like(f'auto_trained_{data_type}_%')
  1267. ).first()
  1268. print(model)
  1269. print(dataset)
  1270. if model and model.Performance_score is not None:
  1271. # 直接使用数据库中的时间,不进行格式化(保持与created_at相同的时区)
  1272. created_at = model.Created_at.isoformat() if model.Created_at else None
  1273. history_data.append({
  1274. 'dataset_id': dataset.Dataset_ID,
  1275. 'row_count': dataset.Row_count,
  1276. 'model_id': model.ModelID,
  1277. 'model_name': model.Model_name,
  1278. 'performance_score': float(model.Performance_score),
  1279. 'timestamp': created_at
  1280. })
  1281. # 按时间戳排序
  1282. history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')
  1283. # 构建返回数据,分离各个指标序列便于前端绘图
  1284. response_data = {
  1285. 'data_type': data_type,
  1286. 'timestamps': [item['timestamp'] for item in history_data],
  1287. 'row_counts': [item['row_count'] for item in history_data],
  1288. 'performance_scores': [item['performance_score'] for item in history_data],
  1289. 'model_details': history_data # 保留完整数据供前端使用
  1290. }
  1291. print(response_data)
  1292. return jsonify(response_data), 200
  1293. except Exception as e:
  1294. logger.error(f'获取模型历史数据失败: {str(e)}', exc_info=True)
  1295. return jsonify({'error': str(e)}), 500
  1296. finally:
  1297. session.close()
  1298. @bp.route('/batch-delete-datasets', methods=['POST'])
  1299. def batch_delete_datasets():
  1300. """
  1301. 批量删除数据集的API接口
  1302. @body_param dataset_ids: 要删除的数据集ID列表
  1303. @return: JSON响应
  1304. """
  1305. try:
  1306. data = request.get_json()
  1307. dataset_ids = data.get('dataset_ids', [])
  1308. if not dataset_ids:
  1309. return jsonify({'error': '未提供数据集ID列表'}), 400
  1310. results = {
  1311. 'success': [],
  1312. 'failed': [],
  1313. 'protected': [] # 被模型使用的数据集
  1314. }
  1315. for dataset_id in dataset_ids:
  1316. try:
  1317. # 调用单个删除接口
  1318. response = delete_dataset_endpoint(dataset_id)
  1319. # 解析响应
  1320. if response[1] == 200:
  1321. results['success'].append(dataset_id)
  1322. elif response[1] == 400 and 'models' in response[0].json:
  1323. # 数据集被模型保护
  1324. results['protected'].append({
  1325. 'id': dataset_id,
  1326. 'models': response[0].json['models']
  1327. })
  1328. else:
  1329. results['failed'].append({
  1330. 'id': dataset_id,
  1331. 'reason': response[0].json.get('error', '删除失败')
  1332. })
  1333. except Exception as e:
  1334. logger.error(f'删除数据集 {dataset_id} 失败: {str(e)}')
  1335. results['failed'].append({
  1336. 'id': dataset_id,
  1337. 'reason': str(e)
  1338. })
  1339. # 构建响应消息
  1340. message = f"成功删除 {len(results['success'])} 个数据集"
  1341. if results['protected']:
  1342. message += f", {len(results['protected'])} 个数据集被保护"
  1343. if results['failed']:
  1344. message += f", {len(results['failed'])} 个数据集删除失败"
  1345. return jsonify({
  1346. 'message': message,
  1347. 'results': results
  1348. }), 200
  1349. except Exception as e:
  1350. logger.error(f'批量删除数据集失败: {str(e)}')
  1351. return jsonify({'error': str(e)}), 500
  1352. @bp.route('/batch-delete-models', methods=['POST'])
  1353. def batch_delete_models():
  1354. """
  1355. 批量删除模型的API接口
  1356. @body_param model_ids: 要删除的模型ID列表
  1357. @query_param delete_datasets: 布尔值,是否同时删除关联的数据集,默认为False
  1358. @return: JSON响应
  1359. """
  1360. try:
  1361. data = request.get_json()
  1362. model_ids = data.get('model_ids', [])
  1363. delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'
  1364. if not model_ids:
  1365. return jsonify({'error': '未提供模型ID列表'}), 400
  1366. results = {
  1367. 'success': [],
  1368. 'failed': [],
  1369. 'datasets_deleted': [] # 如果delete_datasets为true,记录被删除的数据集
  1370. }
  1371. for model_id in model_ids:
  1372. try:
  1373. # 调用单个删除接口
  1374. response = delete_model(model_id, delete_dataset=delete_datasets)
  1375. # 解析响应
  1376. if response[1] == 200:
  1377. results['success'].append(model_id)
  1378. # 如果删除了关联数据集,记录数据集ID
  1379. if 'dataset_info' in response[0].json:
  1380. results['datasets_deleted'].append(
  1381. response[0].json['dataset_info']['dataset_id']
  1382. )
  1383. else:
  1384. results['failed'].append({
  1385. 'id': model_id,
  1386. 'reason': response[0].json.get('error', '删除失败')
  1387. })
  1388. except Exception as e:
  1389. logger.error(f'删除模型 {model_id} 失败: {str(e)}')
  1390. results['failed'].append({
  1391. 'id': model_id,
  1392. 'reason': str(e)
  1393. })
  1394. # 构建响应消息
  1395. message = f"成功删除 {len(results['success'])} 个模型"
  1396. if results['datasets_deleted']:
  1397. message += f", {len(results['datasets_deleted'])} 个关联数据集"
  1398. if results['failed']:
  1399. message += f", {len(results['failed'])} 个模型删除失败"
  1400. return jsonify({
  1401. 'message': message,
  1402. 'results': results
  1403. }), 200
  1404. except Exception as e:
  1405. logger.error(f'批量删除模型失败: {str(e)}')
  1406. return jsonify({'error': str(e)}), 500
  1407. @bp.route('/kriging_interpolation', methods=['POST'])
  1408. def kriging_interpolation():
  1409. try:
  1410. data = request.get_json()
  1411. required = ['file_name', 'emission_column', 'points']
  1412. if not all(k in data for k in required):
  1413. return jsonify({"error": "Missing parameters"}), 400
  1414. # 添加坐标顺序验证
  1415. points = data['points']
  1416. if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points):
  1417. return jsonify({"error": "Invalid points format"}), 400
  1418. result = create_kriging(
  1419. data['file_name'],
  1420. data['emission_column'],
  1421. data['points']
  1422. )
  1423. return jsonify(result)
  1424. except Exception as e:
  1425. return jsonify({"error": str(e)}), 500
  1426. # 切换模型接口
  1427. @bp.route('/switch-model', methods=['POST'])
  1428. def switch_model():
  1429. session = None
  1430. try:
  1431. data = request.get_json()
  1432. model_id = data.get('model_id')
  1433. model_name = data.get('model_name')
  1434. # 创建 session
  1435. Session = sessionmaker(bind=db.engine)
  1436. session = Session()
  1437. # 查找模型
  1438. model = session.query(Models).filter_by(ModelID=model_id).first()
  1439. if not model:
  1440. return jsonify({'error': 'Model not found'}), 404
  1441. # 更新模型状态(或其他切换逻辑)
  1442. # 假设此处是更新模型的某些字段来进行切换
  1443. model.status = 'active' # 假设有一个字段记录模型状态
  1444. session.commit()
  1445. # 记录切换日志
  1446. logger.info(f'Model {model_name} (ID: {model_id}) switched successfully.')
  1447. return jsonify({'success': True, 'message': f'Model {model_name} switched successfully!'}), 200
  1448. except Exception as e:
  1449. logger.error('Failed to switch model:', exc_info=True)
  1450. return jsonify({'error': str(e)}), 400
  1451. finally:
  1452. if session:
  1453. session.close()
  1454. # 添加查看登录状态接口
  1455. @bp.route('/getInfo', methods=['GET'])
  1456. def getInfo_user():
  1457. name = session.get("name")
  1458. if name is not None:
  1459. return jsonify(name=name)
  1460. else:
  1461. return jsonify(msg="错误")
  1462. # 添加退出登录状态接口
  1463. @bp.route('/logout', methods=['GET', 'POST'])
  1464. def logout_user():
  1465. try:
  1466. session.clear()
  1467. return jsonify({"msg": "退出成功"}), 200
  1468. except Exception as e:
  1469. logger.error(f"Error during logout process: {e}", exc_info=True)
  1470. return jsonify({"msg": "退出失败"}), 500
  1471. # 获取软件介绍信息的路由
  1472. @bp.route('/software-intro/<int:id>', methods=['GET'])
  1473. def get_software_intro(id):
  1474. try:
  1475. conn = get_db()
  1476. cursor = conn.cursor()
  1477. cursor.execute('SELECT title, intro FROM software_intro WHERE id = ?', (id,))
  1478. result = cursor.fetchone()
  1479. conn.close()
  1480. if result:
  1481. title, intro = result
  1482. return jsonify({
  1483. 'title': title,
  1484. 'intro': intro
  1485. })
  1486. return jsonify({}), 404
  1487. except sqlite3.Error as e:
  1488. print(f"数据库错误: {e}")
  1489. return jsonify({"error": f"数据库错误: {str(e)}"}), 500
  1490. # 更新软件介绍信息的路由
  1491. @bp.route('/software-intro/<int:id>', methods=['PUT'])
  1492. def update_software_intro(id):
  1493. try:
  1494. data = request.get_json()
  1495. title = data.get('title')
  1496. intro = data.get('intro')
  1497. conn = get_db()
  1498. cursor = conn.cursor()
  1499. cursor.execute('UPDATE software_intro SET title =?, intro =? WHERE id = ?', (title, intro, id))
  1500. conn.commit()
  1501. conn.close()
  1502. return jsonify({'message': '软件介绍更新成功'})
  1503. except sqlite3.Error as e:
  1504. print(f"数据库错误: {e}")
  1505. return jsonify({"error": f"数据库错误: {str(e)}"}), 500
  1506. # 处理图片上传的路由
  1507. @bp.route('/upload-image', methods=['POST'])
  1508. def upload_image():
  1509. file = request.files['image']
  1510. if file:
  1511. filename = str(uuid.uuid4()) + '.' + file.filename.rsplit('.', 1)[1].lower()
  1512. file.save(os.path.join(bp.config['UPLOAD_FOLDER'], filename))
  1513. imageUrl = f'http://127.0.0.1:5000/uploads/{filename}'
  1514. return jsonify({'imageUrl': imageUrl})
  1515. return jsonify({'error': '未找到图片文件'}), 400
  1516. # 配置静态资源服务
  1517. @bp.route('/uploads/<path:filename>')
  1518. def serve_image(filename):
  1519. uploads_folder = os.path.join(bp.root_path, 'uploads')
  1520. return send_from_directory(uploads_folder, filename)