# routes.py
# Standard library
import logging
import os
import pickle
import re
import sqlite3
from datetime import datetime
from io import BytesIO

# Third-party
import pandas as pd
from flask import Blueprint, request, jsonify, current_app, send_file
from sklearn.metrics import r2_score
from sqlalchemy import text, select, MetaData, Table, func, inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from werkzeug.security import check_password_hash, generate_password_hash
from werkzeug.utils import secure_filename

# Local application
from . import db  # db instance exported by the app package
from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
from .tasks import train_model_task
from .utils import (
    create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict,
    clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table,
    insert_data_into_existing_table, predict_to_Q, Q_to_t_ha, create_kriging,
)
# Module-wide logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Blueprint used to keep the routes separate from app creation
bp = Blueprint('routes', __name__)
  27. # 密码加密
  28. def hash_password(password):
  29. return generate_password_hash(password)
  30. def get_db():
  31. """ 获取数据库连接 """
  32. return sqlite3.connect(current_app.config['DATABASE'])
  33. # 添加一个新的辅助函数来检查数据集大小并触发训练
  34. def check_and_trigger_training(session, dataset_type, dataset_df):
  35. """
  36. 检查当前数据集大小是否跨越新的阈值点并触发训练
  37. Args:
  38. session: 数据库会话
  39. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  40. dataset_df: 数据集 DataFrame
  41. Returns:
  42. tuple: (是否触发训练, 任务ID)
  43. """
  44. try:
  45. # 根据数据集类型选择表
  46. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  47. # 获取当前记录数
  48. current_count = session.query(func.count()).select_from(table).scalar()
  49. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  50. new_records = len(dataset_df) # 需要从上层函数传入
  51. # 计算新增数据前的记录数
  52. previous_count = current_count - new_records
  53. # 设置阈值
  54. THRESHOLD = current_app.config['THRESHOLD']
  55. # 计算上一个阈值点(基于新增前的数据量)
  56. last_threshold = previous_count // THRESHOLD * THRESHOLD
  57. # 计算当前所在阈值点
  58. current_threshold = current_count // THRESHOLD * THRESHOLD
  59. # 检查是否跨越了新的阈值点
  60. if current_threshold > last_threshold and current_count >= THRESHOLD:
  61. # 触发异步训练任务
  62. task = train_model_task.delay(
  63. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  64. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  65. model_description=f'Auto trained model at {current_threshold} records threshold',
  66. data_type=dataset_type
  67. )
  68. return True, task.id
  69. return False, None
  70. except Exception as e:
  71. logging.error(f"检查并触发训练失败: {str(e)}")
  72. return False, None
@bp.route('/upload-dataset', methods=['POST'])
def upload_dataset():
    """
    Upload an Excel dataset: save the file, mirror it into a per-dataset
    dynamic table, de-duplicate against existing data and the test set,
    append the remainder to the current_reduce/current_reflux table, and
    possibly trigger automatic training.

    Returns a JSON payload with upload stats; 201 on success.
    """
    # Create a dedicated session for this request
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '' or not allowed_file(file.filename):
            return jsonify({'error': 'No selected file or invalid file type'}), 400
        dataset_name = request.form.get('dataset_name')
        dataset_description = request.form.get('dataset_description', 'No description provided')
        dataset_type = request.form.get('dataset_type')
        if not dataset_type:
            return jsonify({'error': 'Dataset type is required'}), 400
        # Insert the dataset record first so we get an ID for the file name
        new_dataset = Datasets(
            Dataset_name=dataset_name,
            Dataset_description=dataset_description,
            Row_count=0,
            Status='Datasets_upgraded',
            Dataset_type=dataset_type,
            Uploaded_at=datetime.now()
        )
        session.add(new_dataset)
        session.commit()
        unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
        upload_folder = current_app.config['UPLOAD_FOLDER']
        file_path = os.path.join(upload_folder, unique_filename)
        file.save(file_path)
        dataset_df = pd.read_excel(file_path)
        new_dataset.Row_count = len(dataset_df)
        new_dataset.Status = 'excel_file_saved success'
        session.commit()
        # Normalize column names to the model's expected schema
        dataset_df = clean_column_names(dataset_df)
        dataset_df = rename_columns_for_model(dataset_df, dataset_type)
        column_types = infer_column_types(dataset_df)
        # NOTE(review): the dynamic table receives the raw (pre-dedup) rows —
        # presumably intentional as an as-uploaded archive; confirm.
        dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
        insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
        # Remove duplicates inside the uploaded file itself
        original_count = len(dataset_df)
        dataset_df = dataset_df.drop_duplicates()
        duplicates_in_file = original_count - len(dataset_df)
        # Remove rows that already exist in the current data table
        duplicates_with_existing = 0
        if dataset_type in ['reduce', 'reflux']:
            table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
            # Load existing data for comparison
            existing_data = pd.read_sql_table(table_name, session.bind)
            if 'id' in existing_data.columns:
                existing_data = existing_data.drop('id', axis=1)
            # Compare only on columns both frames share
            compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
            original_df_len = len(dataset_df)
            # Concatenate existing + new; rows marked duplicated in the tail
            # are new rows that already exist in the table
            all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
            duplicates_mask = all_data.duplicated(keep='first')
            duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
            # Keep only the rows not flagged as duplicates (positional mask)
            dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
            logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
        # Remove rows that overlap with the held-out test set
        test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)
        if test_overlap_count > 0:
            # Keep only rows whose index is not in the overlap set
            mask = ~dataset_df.index.isin(test_overlap_indices)
            dataset_df = dataset_df[mask]
            logger.warning(f"移除了 {test_overlap_count} 行与测试集重叠的数据")
        # Append the cleaned rows to the appropriate current table
        if dataset_type == 'reduce':
            insert_data_into_existing_table(session, dataset_df, CurrentReduce)
        elif dataset_type == 'reflux':
            insert_data_into_existing_table(session, dataset_df, CurrentReflux)
        session.commit()
        # After insertion, see whether auto-training should start
        training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
        response_data = {
            'message': f'数据集 {dataset_name} 上传成功!',
            'dataset_id': new_dataset.Dataset_ID,
            'filename': unique_filename,
            'training_triggered': training_triggered,
            'data_stats': {
                'original_count': original_count,
                'duplicates_in_file': duplicates_in_file,
                'duplicates_with_existing': duplicates_with_existing,
                'test_overlap_count': test_overlap_count,
                'final_count': len(dataset_df)
            }
        }
        if training_triggered:
            response_data['task_id'] = task_id
            response_data['message'] += ' 自动训练已触发。'
        # Mention removed duplicates in the user-facing message
        if duplicates_with_existing > 0:
            response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
        # Mention removed test-set overlaps in the user-facing message
        if test_overlap_count > 0:
            response_data['message'] += f' 已移除 {test_overlap_count} 个与测试集重叠的项。'
        return jsonify(response_data), 201
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the dataset upload:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        # Always release the session
        if session:
            session.close()
  185. @bp.route('/train-and-save-model', methods=['POST'])
  186. def train_and_save_model_endpoint():
  187. # 创建 sessionmaker 实例
  188. Session = sessionmaker(bind=db.engine)
  189. session = Session()
  190. data = request.get_json()
  191. # 从请求中解析参数
  192. model_type = data.get('model_type')
  193. model_name = data.get('model_name')
  194. model_description = data.get('model_description')
  195. data_type = data.get('data_type')
  196. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  197. try:
  198. # 调用训练和保存模型的函数
  199. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  200. model_id = result[1] if result else None
  201. # 计算模型评分
  202. if model_id:
  203. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  204. if model_info:
  205. score = calculate_model_score(model_info)
  206. # 更新模型评分
  207. model_info.Performance_score = score
  208. session.commit()
  209. result = {'model_id': model_id, 'model_score': score}
  210. # 返回成功响应
  211. return jsonify({
  212. 'message': 'Model trained and saved successfully',
  213. 'result': result
  214. }), 200
  215. except Exception as e:
  216. session.rollback()
  217. logging.error('Failed to process the model training:', exc_info=True)
  218. return jsonify({
  219. 'error': 'Failed to train and save model',
  220. 'message': str(e)
  221. }), 500
  222. finally:
  223. session.close()
  224. @bp.route('/predict', methods=['POST'])
  225. def predict_route():
  226. # 创建 sessionmaker 实例
  227. Session = sessionmaker(bind=db.engine)
  228. session = Session()
  229. try:
  230. data = request.get_json()
  231. model_id = data.get('model_id') # 提取模型名称
  232. parameters = data.get('parameters', {}) # 提取所有变量
  233. # 根据model_id获取模型Data_type
  234. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  235. if not model_info:
  236. return jsonify({'error': 'Model not found'}), 404
  237. data_type = model_info.Data_type
  238. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  239. # 如果为reduce,则不需要传入target_ph
  240. if data_type == 'reduce':
  241. # 获取传入的init_ph、target_ph参数
  242. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  243. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  244. # 从输入数据中删除'target_pH'列
  245. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  246. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  247. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  248. if data_type == 'reduce':
  249. predictions = predictions[0]
  250. # 将预测结果转换为Q
  251. Q = predict_to_Q(predictions, init_ph, target_ph)
  252. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  253. print(predictions)
  254. return jsonify({'result': predictions}), 200
  255. except Exception as e:
  256. logging.error('Failed to predict:', exc_info=True)
  257. return jsonify({'error': str(e)}), 400
  258. # 为指定模型计算评分Performance_score,需要提供model_id
  259. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  260. def score_model(model_id):
  261. # 创建 sessionmaker 实例
  262. Session = sessionmaker(bind=db.engine)
  263. session = Session()
  264. try:
  265. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  266. if not model_info:
  267. return jsonify({'error': 'Model not found'}), 404
  268. # 计算模型评分
  269. score = calculate_model_score(model_info)
  270. # 更新模型记录中的评分
  271. model_info.Performance_score = score
  272. session.commit()
  273. return jsonify({'message': 'Model scored successfully', 'score': score}), 200
  274. except Exception as e:
  275. logging.error('Failed to process the dataset upload:', exc_info=True)
  276. return jsonify({'error': str(e)}), 400
  277. finally:
  278. session.close()
@bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
def delete_dataset_endpoint(dataset_id):
    """
    Delete a dataset: its Excel file, its dynamic table, and its record.
    Refuses to delete if any model references the dataset.

    @param dataset_id: ID of the dataset to delete
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if not dataset:
            return jsonify({'error': '未找到数据集'}), 404
        # Refuse deletion while any model still references this dataset
        models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
        if models_using_dataset:
            models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
            return jsonify({
                'error': '无法删除数据集,因为以下模型正在使用它',
                'models': models_info
            }), 400
        # Delete the uploaded Excel file (if it is still on disk)
        filename = f"dataset_{dataset.Dataset_ID}.xlsx"
        file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f'删除文件失败: {str(e)}')
                return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
        # Drop the per-dataset dynamic table; the name is derived from an
        # integer primary key, so string interpolation is injection-safe here
        table_name = f"dataset_{dataset.Dataset_ID}"
        session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
        # Remove the dataset record itself
        session.delete(dataset)
        session.commit()
        return jsonify({
            'message': '数据集删除成功',
            'deleted_files': [filename]
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
  327. @bp.route('/tables', methods=['GET'])
  328. def list_tables():
  329. engine = db.engine # 使用 db 实例的 engine
  330. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  331. table_names = inspector.get_table_names() # 获取所有表名
  332. return jsonify(table_names) # 以 JSON 形式返回表名列表
  333. @bp.route('/models/<int:model_id>', methods=['GET'])
  334. def get_model(model_id):
  335. """
  336. 获取单个模型信息的API接口
  337. @param model_id: 模型ID
  338. @return: JSON响应
  339. """
  340. Session = sessionmaker(bind=db.engine)
  341. session = Session()
  342. try:
  343. model = session.query(Models).filter_by(ModelID=model_id).first()
  344. if model:
  345. return jsonify({
  346. 'ModelID': model.ModelID,
  347. 'Model_name': model.Model_name,
  348. 'Model_type': model.Model_type,
  349. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  350. 'Description': model.Description,
  351. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  352. 'Data_type': model.Data_type
  353. })
  354. else:
  355. return jsonify({'message': '未找到模型'}), 404
  356. except Exception as e:
  357. logger.error(f'获取模型信息失败: {str(e)}')
  358. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  359. finally:
  360. session.close()
  361. @bp.route('/models', methods=['GET'])
  362. def get_all_models():
  363. """
  364. 获取所有模型信息的API接口
  365. @return: JSON响应
  366. """
  367. Session = sessionmaker(bind=db.engine)
  368. session = Session()
  369. try:
  370. models = session.query(Models).all()
  371. if models:
  372. result = [
  373. {
  374. 'ModelID': model.ModelID,
  375. 'Model_name': model.Model_name,
  376. 'Model_type': model.Model_type,
  377. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  378. 'Description': model.Description,
  379. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  380. 'Data_type': model.Data_type
  381. }
  382. for model in models
  383. ]
  384. return jsonify(result)
  385. else:
  386. return jsonify({'message': '未找到任何模型'}), 404
  387. except Exception as e:
  388. logger.error(f'获取所有模型信息失败: {str(e)}')
  389. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  390. finally:
  391. session.close()
  392. @bp.route('/model-parameters', methods=['GET'])
  393. def get_all_model_parameters():
  394. """
  395. 获取所有模型参数的API接口
  396. @return: JSON响应
  397. """
  398. Session = sessionmaker(bind=db.engine)
  399. session = Session()
  400. try:
  401. parameters = session.query(ModelParameters).all()
  402. if parameters:
  403. result = [
  404. {
  405. 'ParamID': param.ParamID,
  406. 'ModelID': param.ModelID,
  407. 'ParamName': param.ParamName,
  408. 'ParamValue': param.ParamValue
  409. }
  410. for param in parameters
  411. ]
  412. return jsonify(result)
  413. else:
  414. return jsonify({'message': '未找到任何参数'}), 404
  415. except Exception as e:
  416. logger.error(f'获取所有模型参数失败: {str(e)}')
  417. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  418. finally:
  419. session.close()
  420. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  421. def get_model_parameters(model_id):
  422. try:
  423. model = Models.query.filter_by(ModelID=model_id).first()
  424. if model:
  425. # 获取该模型的所有参数
  426. parameters = [
  427. {
  428. 'ParamID': param.ParamID,
  429. 'ParamName': param.ParamName,
  430. 'ParamValue': param.ParamValue
  431. }
  432. for param in model.parameters
  433. ]
  434. # 返回模型参数信息
  435. return jsonify({
  436. 'ModelID': model.ModelID,
  437. 'ModelName': model.ModelName,
  438. 'ModelType': model.ModelType,
  439. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  440. 'Description': model.Description,
  441. 'Parameters': parameters
  442. })
  443. else:
  444. return jsonify({'message': 'Model not found'}), 404
  445. except Exception as e:
  446. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  447. # 定义添加数据库记录的 API 接口
  448. @bp.route('/add_item', methods=['POST'])
  449. def add_item():
  450. """
  451. 接收 JSON 格式的请求体,包含表名和要插入的数据。
  452. 尝试将数据插入到指定的表中,并进行字段查重。
  453. :return:
  454. """
  455. try:
  456. # 确保请求体是 JSON 格式
  457. data = request.get_json()
  458. if not data:
  459. raise ValueError("No JSON data provided")
  460. table_name = data.get('table')
  461. item_data = data.get('item')
  462. if not table_name or not item_data:
  463. return jsonify({'error': 'Missing table name or item data'}), 400
  464. # 定义各个表的字段查重规则
  465. duplicate_check_rules = {
  466. 'users': ['email', 'username'],
  467. 'products': ['product_code'],
  468. 'current_reduce': [ 'Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'],
  469. 'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'],
  470. # 其他表和规则
  471. }
  472. # 获取该表的查重字段
  473. duplicate_columns = duplicate_check_rules.get(table_name)
  474. if not duplicate_columns:
  475. return jsonify({'error': 'No duplicate check rule for this table'}), 400
  476. # 动态构建查询条件,逐一检查是否有重复数据
  477. condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns])
  478. duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1"
  479. result = db.session.execute(text(duplicate_query), item_data).fetchone()
  480. if result:
  481. return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409
  482. # 动态构建 SQL 语句,进行插入操作
  483. columns = ', '.join(item_data.keys())
  484. placeholders = ', '.join([f":{key}" for key in item_data.keys()])
  485. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  486. # 直接执行插入操作,无需显式的事务管理
  487. db.session.execute(text(sql), item_data)
  488. # 提交事务
  489. db.session.commit()
  490. # 返回成功响应
  491. return jsonify({'success': True, 'message': 'Item added successfully'}), 201
  492. except ValueError as e:
  493. return jsonify({'error': str(e)}), 400
  494. except KeyError as e:
  495. return jsonify({'error': f'Missing data field: {e}'}), 400
  496. except sqlite3.IntegrityError as e:
  497. return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409
  498. except sqlite3.Error as e:
  499. return jsonify({'error': '数据库错误', 'details': str(e)}), 500
  500. @bp.route('/delete_item', methods=['POST'])
  501. def delete_item():
  502. """
  503. 删除数据库记录的 API 接口
  504. """
  505. data = request.get_json()
  506. table_name = data.get('table')
  507. condition = data.get('condition')
  508. # 检查表名和条件是否提供
  509. if not table_name or not condition:
  510. return jsonify({
  511. "success": False,
  512. "message": "缺少表名或条件参数"
  513. }), 400
  514. # 尝试从条件字符串中解析键和值
  515. try:
  516. key, value = condition.split('=')
  517. key = key.strip() # 去除多余的空格
  518. value = value.strip().strip("'\"") # 去除多余的空格和引号
  519. except ValueError:
  520. return jsonify({
  521. "success": False,
  522. "message": "条件格式错误,应为 'key=value'"
  523. }), 400
  524. # 准备 SQL 删除语句
  525. sql = f"DELETE FROM {table_name} WHERE {key} = :value"
  526. try:
  527. # 使用 SQLAlchemy 执行删除
  528. with db.session.begin():
  529. result = db.session.execute(text(sql), {"value": value})
  530. # 检查是否有记录被删除
  531. if result.rowcount == 0:
  532. return jsonify({
  533. "success": False,
  534. "message": "未找到符合条件的记录"
  535. }), 404
  536. return jsonify({
  537. "success": True,
  538. "message": "记录删除成功"
  539. }), 200
  540. except Exception as e:
  541. return jsonify({
  542. "success": False,
  543. "message": f"删除失败: {e}"
  544. }), 500
  545. # 定义修改数据库记录的 API 接口
  546. @bp.route('/update_item', methods=['PUT'])
  547. def update_record():
  548. """
  549. 接收 JSON 格式的请求体,包含表名和更新的数据。
  550. 尝试更新指定的记录。
  551. """
  552. data = request.get_json()
  553. # 检查必要的数据是否提供
  554. if not data or 'table' not in data or 'item' not in data:
  555. return jsonify({
  556. "success": False,
  557. "message": "请求数据不完整"
  558. }), 400
  559. table_name = data['table']
  560. item = data['item']
  561. # 假设 item 的第一个键是 ID
  562. id_key = next(iter(item.keys())) # 获取第一个键
  563. record_id = item.get(id_key)
  564. if not record_id:
  565. return jsonify({
  566. "success": False,
  567. "message": "缺少记录 ID"
  568. }), 400
  569. # 获取更新的字段和值
  570. updates = {key: value for key, value in item.items() if key != id_key}
  571. if not updates:
  572. return jsonify({
  573. "success": False,
  574. "message": "没有提供需要更新的字段"
  575. }), 400
  576. # 动态构建 SQL
  577. set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()])
  578. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value"
  579. # 添加 ID 到参数
  580. updates['id_value'] = record_id
  581. try:
  582. # 使用 SQLAlchemy 执行更新
  583. with db.session.begin():
  584. result = db.session.execute(text(sql), updates)
  585. # 检查是否有更新的记录
  586. if result.rowcount == 0:
  587. return jsonify({
  588. "success": False,
  589. "message": "未找到要更新的记录"
  590. }), 404
  591. return jsonify({
  592. "success": True,
  593. "message": "数据更新成功"
  594. }), 200
  595. except Exception as e:
  596. # 捕获所有异常并返回
  597. return jsonify({
  598. "success": False,
  599. "message": f"更新失败: {str(e)}"
  600. }), 500
  601. # 定义查询数据库记录的 API 接口
  602. @bp.route('/search/record', methods=['GET'])
  603. def sql_search():
  604. """
  605. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  606. 尝试查询指定 ID 的记录并返回结果。
  607. :return:
  608. """
  609. try:
  610. data = request.get_json()
  611. # 表名
  612. sql_table = data['table']
  613. # 要搜索的 ID
  614. Id = data['id']
  615. # 连接到数据库
  616. cur = db.cursor()
  617. # 构造查询语句
  618. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  619. # 执行查询
  620. cur.execute(sql, (Id,))
  621. # 获取查询结果
  622. rows = cur.fetchall()
  623. column_names = [desc[0] for desc in cur.description]
  624. # 检查是否有结果
  625. if not rows:
  626. return jsonify({'error': '未查找到对应数据。'}), 400
  627. # 构造响应数据
  628. results = []
  629. for row in rows:
  630. result = {column_names[i]: row[i] for i in range(len(row))}
  631. results.append(result)
  632. # 关闭游标和数据库连接
  633. cur.close()
  634. db.close()
  635. # 返回 JSON 响应
  636. return jsonify(results), 200
  637. except sqlite3.Error as e:
  638. # 如果发生数据库错误,返回错误信息
  639. return jsonify({'error': str(e)}), 400
  640. except KeyError as e:
  641. # 如果请求数据中缺少必要的键,返回错误信息
  642. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  643. # 定义提供数据库列表,用于展示表格的 API 接口
  644. @bp.route('/table', methods=['POST'])
  645. def get_table():
  646. data = request.get_json()
  647. table_name = data.get('table')
  648. if not table_name:
  649. return jsonify({'error': '需要表名'}), 400
  650. try:
  651. # 创建 sessionmaker 实例
  652. Session = sessionmaker(bind=db.engine)
  653. session = Session()
  654. # 动态获取表的元数据
  655. metadata = MetaData()
  656. table = Table(table_name, metadata, autoload_with=db.engine)
  657. # 从数据库中查询所有记录
  658. query = select(table)
  659. result = session.execute(query).fetchall()
  660. # 将结果转换为列表字典形式
  661. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  662. # 获取列名
  663. headers = [column.name for column in table.columns]
  664. return jsonify(rows=rows, headers=headers), 200
  665. except Exception as e:
  666. return jsonify({'error': str(e)}), 400
  667. finally:
  668. # 关闭 session
  669. session.close()
  670. @bp.route('/train-model-async', methods=['POST'])
  671. def train_model_async():
  672. """
  673. 异步训练模型的API接口
  674. """
  675. try:
  676. data = request.get_json()
  677. # 从请求中获取参数
  678. model_type = data.get('model_type')
  679. model_name = data.get('model_name')
  680. model_description = data.get('model_description')
  681. data_type = data.get('data_type')
  682. dataset_id = data.get('dataset_id', None)
  683. # 验证必要参数
  684. if not all([model_type, model_name, data_type]):
  685. return jsonify({
  686. 'error': 'Missing required parameters'
  687. }), 400
  688. # 如果提供了dataset_id,验证数据集是否存在
  689. if dataset_id:
  690. Session = sessionmaker(bind=db.engine)
  691. session = Session()
  692. try:
  693. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  694. if not dataset:
  695. return jsonify({
  696. 'error': f'Dataset with ID {dataset_id} not found'
  697. }), 404
  698. finally:
  699. session.close()
  700. # 启动异步任务
  701. task = train_model_task.delay(
  702. model_type=model_type,
  703. model_name=model_name,
  704. model_description=model_description,
  705. data_type=data_type,
  706. dataset_id=dataset_id
  707. )
  708. # 返回任务ID
  709. return jsonify({
  710. 'task_id': task.id,
  711. 'message': 'Model training started'
  712. }), 202
  713. except Exception as e:
  714. logging.error('Failed to start async training task:', exc_info=True)
  715. return jsonify({
  716. 'error': str(e)
  717. }), 500
  718. @bp.route('/task-status/<task_id>', methods=['GET'])
  719. def get_task_status(task_id):
  720. """
  721. 获取异步任务状态的API接口
  722. """
  723. try:
  724. task = train_model_task.AsyncResult(task_id)
  725. if task.state == 'PENDING':
  726. response = {
  727. 'state': task.state,
  728. 'status': 'Task is waiting for execution'
  729. }
  730. elif task.state == 'FAILURE':
  731. response = {
  732. 'state': task.state,
  733. 'status': 'Task failed',
  734. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  735. }
  736. elif task.state == 'SUCCESS':
  737. response = {
  738. 'state': task.state,
  739. 'status': 'Task completed successfully',
  740. 'result': task.get()
  741. }
  742. else:
  743. response = {
  744. 'state': task.state,
  745. 'status': 'Task is in progress'
  746. }
  747. return jsonify(response), 200
  748. except Exception as e:
  749. return jsonify({
  750. 'error': str(e)
  751. }), 500
  752. @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
  753. def delete_model_route(model_id):
  754. # 将URL参数转换为布尔值
  755. delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'
  756. # 调用原始函数
  757. return delete_model(model_id, delete_dataset=delete_dataset_param)
  758. def delete_model(model_id, delete_dataset=False):
  759. """
  760. 删除指定模型的API接口
  761. @param model_id: 要删除的模型ID
  762. @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
  763. @return: JSON响应
  764. """
  765. Session = sessionmaker(bind=db.engine)
  766. session = Session()
  767. try:
  768. # 查询模型信息
  769. model = session.query(Models).filter_by(ModelID=model_id).first()
  770. if not model:
  771. return jsonify({'error': '未找到指定模型'}), 404
  772. dataset_id = model.DatasetID
  773. # 1. 先删除模型记录
  774. session.delete(model)
  775. session.commit()
  776. # 2. 删除模型文件
  777. model_path = model.ModelFilePath
  778. try:
  779. if os.path.exists(model_path):
  780. os.remove(model_path)
  781. else:
  782. # 如果删除文件失败,回滚数据库操作
  783. session.rollback()
  784. logger.warning(f'模型文件不存在: {model_path}')
  785. except OSError as e:
  786. # 如果删除文件失败,回滚数据库操作
  787. session.rollback()
  788. logger.error(f'删除模型文件失败: {str(e)}')
  789. return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
  790. # 3. 如果需要删除关联的数据集
  791. if delete_dataset and dataset_id:
  792. try:
  793. dataset_response = delete_dataset_endpoint(dataset_id)
  794. if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
  795. # 如果删除数据集失败,回滚之前的操作
  796. session.rollback()
  797. return jsonify({
  798. 'error': '删除关联数据集失败',
  799. 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
  800. }), 500
  801. except Exception as e:
  802. session.rollback()
  803. logger.error(f'删除关联数据集失败: {str(e)}')
  804. return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
  805. response_data = {
  806. 'message': '模型删除成功',
  807. 'deleted_files': [model_path]
  808. }
  809. if delete_dataset:
  810. response_data['dataset_info'] = {
  811. 'dataset_id': dataset_id,
  812. 'message': '关联数据集已删除'
  813. }
  814. return jsonify(response_data), 200
  815. except Exception as e:
  816. session.rollback()
  817. logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
  818. return jsonify({'error': str(e)}), 500
  819. finally:
  820. session.close()
  821. # 添加一个新的API端点来清空指定数据集
  822. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  823. def clear_dataset(data_type):
  824. """
  825. 清空指定类型的数据集并递增计数
  826. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  827. @return: JSON响应
  828. """
  829. # 创建 sessionmaker 实例
  830. Session = sessionmaker(bind=db.engine)
  831. session = Session()
  832. try:
  833. # 根据数据集类型选择表
  834. if data_type == 'reduce':
  835. table = CurrentReduce
  836. table_name = 'current_reduce'
  837. elif data_type == 'reflux':
  838. table = CurrentReflux
  839. table_name = 'current_reflux'
  840. else:
  841. return jsonify({'error': '无效的数据集类型'}), 400
  842. # 清空表内容
  843. session.query(table).delete()
  844. # 重置自增主键计数器
  845. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  846. session.commit()
  847. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  848. except Exception as e:
  849. session.rollback()
  850. return jsonify({'error': str(e)}), 500
  851. finally:
  852. session.close()
  853. @bp.route('/login', methods=['POST'])
  854. def login_user():
  855. # 获取前端传来的数据
  856. data = request.get_json()
  857. name = data.get('name') # 用户名
  858. password = data.get('password') # 密码
  859. logger.info(f"Login request received: name={name}")
  860. # 检查用户名和密码是否为空
  861. if not name or not password:
  862. logger.warning("用户名和密码不能为空")
  863. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  864. try:
  865. # 查询数据库验证用户名
  866. query = "SELECT * FROM users WHERE name = :name"
  867. conn = get_db()
  868. user = conn.execute(query, {"name": name}).fetchone()
  869. if not user:
  870. logger.warning(f"用户名 '{name}' 不存在")
  871. return jsonify({"success": False, "message": "用户名不存在"}), 400
  872. # 获取数据库中存储的密码(假设密码是哈希存储的)
  873. stored_password = user[2] # 假设密码存储在数据库的第三列
  874. user_id = user[0] # 假设 id 存储在数据库的第一列
  875. # 校验密码是否正确
  876. if check_password_hash(stored_password, password):
  877. logger.info(f"User '{name}' logged in successfully.")
  878. return jsonify({
  879. "success": True,
  880. "message": "登录成功",
  881. "userId": user_id # 返回用户 ID
  882. })
  883. else:
  884. logger.warning(f"Invalid password for user '{name}'")
  885. return jsonify({"success": False, "message": "用户名或密码错误"}), 400
  886. except Exception as e:
  887. # 记录错误日志并返回错误信息
  888. logger.error(f"Error during login: {e}", exc_info=True)
  889. return jsonify({"success": False, "message": "登录失败"}), 500
# Update user info endpoint.
@bp.route('/update_user', methods=['POST'])
def update_user():
    """Update a user's password and/or username after verifying the old password."""
    # Read the JSON payload sent by the frontend.
    data = request.get_json()
    # Log the raw request data for debugging.
    current_app.logger.info(f"Received data: {data}")
    user_id = data.get('userId')  # user id
    name = data.get('name')  # username
    old_password = data.get('oldPassword')  # current password
    new_password = data.get('newPassword')  # replacement password (optional)
    logger.info(f"Update request received: user_id={user_id}, name={name}")
    # The username and the old password are both required.
    if not name or not old_password:
        logger.warning("用户名和旧密码不能为空")
        return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400
    # A new password, if given, must differ from the old one.
    if new_password and old_password == new_password:
        logger.warning(f"新密码与旧密码相同:{name}")
        return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400
    try:
        # Verify the user id exists.
        query = "SELECT * FROM users WHERE id = :user_id"
        conn = get_db()
        user = conn.execute(query, {"user_id": user_id}).fetchone()
        if not user:
            logger.warning(f"用户ID '{user_id}' 不存在")
            return jsonify({"success": False, "message": "用户不存在"}), 400
        # Stored (hashed) password — assumed to be in the third column;
        # confirm against the users table schema.
        stored_password = user[2]
        # Check the old password against the stored hash.
        if not check_password_hash(stored_password, old_password):
            logger.warning(f"旧密码错误:{name}")
            return jsonify({"success": False, "message": "旧密码错误"}), 400
        # If a new password was supplied, hash and store it.
        if new_password:
            hashed_new_password = hash_password(new_password)
            update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
            conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
            conn.commit()
            logger.info(f"User ID '{user_id}' password updated successfully.")
        # If the username changed (column 1 holds the current name), update it.
        if name != user[1]:
            update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
            conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
            conn.commit()
            logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
        return jsonify({"success": True, "message": "用户信息更新成功"})
    except Exception as e:
        # Log the failure with traceback and report a generic error.
        logger.error(f"Error updating user: {e}", exc_info=True)
        return jsonify({"success": False, "message": "更新失败"}), 500
  942. # 注册用户
  943. @bp.route('/register', methods=['POST'])
  944. def register_user():
  945. # 获取前端传来的数据
  946. data = request.get_json()
  947. name = data.get('name') # 用户名
  948. password = data.get('password') # 密码
  949. logger.info(f"Register request received: name={name}")
  950. # 检查用户名和密码是否为空
  951. if not name or not password:
  952. logger.warning("用户名和密码不能为空")
  953. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  954. # 动态获取数据库表的列名
  955. columns = get_column_names('users')
  956. logger.info(f"Database columns for 'users' table: {columns}")
  957. # 检查前端传来的数据是否包含数据库表中所有的必填字段
  958. for column in ['name', 'password']:
  959. if column not in columns:
  960. logger.error(f"缺少必填字段:{column}")
  961. return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400
  962. # 对密码进行哈希处理
  963. hashed_password = hash_password(password)
  964. logger.info(f"Password hashed for user: {name}")
  965. # 插入到数据库
  966. try:
  967. # 检查用户是否已经存在
  968. query = "SELECT * FROM users WHERE name = :name"
  969. conn = get_db()
  970. user = conn.execute(query, {"name": name}).fetchone()
  971. if user:
  972. logger.warning(f"用户名 '{name}' 已存在")
  973. return jsonify({"success": False, "message": "用户名已存在"}), 400
  974. # 向数据库插入数据
  975. query = "INSERT INTO users (name, password) VALUES (:name, :password)"
  976. conn.execute(query, {"name": name, "password": hashed_password})
  977. conn.commit()
  978. logger.info(f"User '{name}' registered successfully.")
  979. return jsonify({"success": True, "message": "注册成功"})
  980. except Exception as e:
  981. # 记录错误日志并返回错误信息
  982. logger.error(f"Error registering user: {e}", exc_info=True)
  983. return jsonify({"success": False, "message": "注册失败"}), 500
  984. def get_column_names(table_name):
  985. """
  986. 动态获取数据库表的列名。
  987. """
  988. try:
  989. conn = get_db()
  990. query = f"PRAGMA table_info({table_name});"
  991. result = conn.execute(query).fetchall()
  992. conn.close()
  993. return [row[1] for row in result] # 第二列是列名
  994. except Exception as e:
  995. logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
  996. return []
  997. # 导出数据
  998. @bp.route('/export_data', methods=['GET'])
  999. def export_data():
  1000. table_name = request.args.get('table')
  1001. file_format = request.args.get('format', 'excel').lower()
  1002. if not table_name:
  1003. return jsonify({'error': '缺少表名参数'}), 400
  1004. if not table_name.isidentifier():
  1005. return jsonify({'error': '无效的表名'}), 400
  1006. try:
  1007. conn = get_db()
  1008. query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"
  1009. table_exists = conn.execute(query, (table_name,)).fetchone()
  1010. if not table_exists:
  1011. return jsonify({'error': f"表 {table_name} 不存在"}), 404
  1012. query = f"SELECT * FROM {table_name};"
  1013. df = pd.read_sql(query, conn)
  1014. output = BytesIO()
  1015. if file_format == 'csv':
  1016. df.to_csv(output, index=False, encoding='utf-8')
  1017. output.seek(0)
  1018. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv')
  1019. elif file_format == 'excel':
  1020. df.to_excel(output, index=False, engine='openpyxl')
  1021. output.seek(0)
  1022. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx',
  1023. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1024. else:
  1025. return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400
  1026. except Exception as e:
  1027. logger.error(f"Error in export_data: {e}", exc_info=True)
  1028. return jsonify({'error': str(e)}), 500
  1029. # 导入数据接口
  1030. @bp.route('/import_data', methods=['POST'])
  1031. def import_data():
  1032. logger.debug("Import data endpoint accessed.")
  1033. if 'file' not in request.files:
  1034. logger.error("No file in request.")
  1035. return jsonify({'success': False, 'message': '文件缺失'}), 400
  1036. file = request.files['file']
  1037. table_name = request.form.get('table')
  1038. if not table_name:
  1039. logger.error("Missing table name parameter.")
  1040. return jsonify({'success': False, 'message': '缺少表名参数'}), 400
  1041. if file.filename == '':
  1042. logger.error("No file selected.")
  1043. return jsonify({'success': False, 'message': '未选择文件'}), 400
  1044. try:
  1045. # 保存文件到临时路径
  1046. temp_path = os.path.join(current_app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
  1047. file.save(temp_path)
  1048. logger.debug(f"File saved to temporary path: {temp_path}")
  1049. # 根据文件类型读取文件
  1050. if file.filename.endswith('.xlsx'):
  1051. df = pd.read_excel(temp_path)
  1052. elif file.filename.endswith('.csv'):
  1053. df = pd.read_csv(temp_path)
  1054. else:
  1055. logger.error("Unsupported file format.")
  1056. return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400
  1057. # 获取数据库列名
  1058. db_columns = get_column_names(table_name)
  1059. if 'id' in db_columns:
  1060. db_columns.remove('id') # 假设 id 列是自增的,不需要处理
  1061. if not set(db_columns).issubset(set(df.columns)):
  1062. logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}")
  1063. return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400
  1064. # 清洗数据并删除空值行
  1065. df_cleaned = df[db_columns].dropna()
  1066. # 统一数据类型,避免 int 和 float 合并问题
  1067. df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce')
  1068. # 获取现有的数据
  1069. conn = get_db()
  1070. with conn:
  1071. existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn)
  1072. # 查找重复数据
  1073. duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner')
  1074. # 如果有重复数据,删除它们
  1075. df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)]
  1076. logger.warning(f"Duplicate data detected and removed: {duplicates}")
  1077. # 获取导入前后的数据量
  1078. total_data = len(df_cleaned) + len(duplicates)
  1079. new_data = len(df_cleaned)
  1080. duplicate_data = len(duplicates)
  1081. # 导入不重复的数据
  1082. df_cleaned.to_sql(table_name, conn, if_exists='append', index=False)
  1083. logger.debug(f"Imported {new_data} new records into the database.")
  1084. # 删除临时文件
  1085. os.remove(temp_path)
  1086. logger.debug(f"Temporary file removed: {temp_path}")
  1087. # 返回结果
  1088. return jsonify({
  1089. 'success': True,
  1090. 'message': '数据导入成功',
  1091. 'total_data': total_data,
  1092. 'new_data': new_data,
  1093. 'duplicate_data': duplicate_data
  1094. }), 200
  1095. except Exception as e:
  1096. logger.error(f"Import failed: {e}", exc_info=True)
  1097. return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500
  1098. # 模板下载接口
  1099. @bp.route('/download_template', methods=['GET'])
  1100. def download_template():
  1101. """
  1102. 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。
  1103. """
  1104. table_name = request.args.get('table')
  1105. if not table_name:
  1106. return jsonify({'error': '表名参数缺失'}), 400
  1107. columns = get_column_names(table_name)
  1108. if not columns:
  1109. return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404
  1110. # 不包括 ID 列
  1111. if 'id' in columns:
  1112. columns.remove('id')
  1113. df = pd.DataFrame(columns=columns)
  1114. file_format = request.args.get('format', 'excel').lower()
  1115. try:
  1116. if file_format == 'csv':
  1117. output = BytesIO()
  1118. df.to_csv(output, index=False, encoding='utf-8')
  1119. output.seek(0)
  1120. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv',
  1121. mimetype='text/csv')
  1122. else:
  1123. output = BytesIO()
  1124. df.to_excel(output, index=False, engine='openpyxl')
  1125. output.seek(0)
  1126. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx',
  1127. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1128. except Exception as e:
  1129. logger.error(f"Failed to generate template: {e}", exc_info=True)
  1130. return jsonify({'error': '生成模板文件失败'}), 500
  1131. @bp.route('/update-threshold', methods=['POST'])
  1132. def update_threshold():
  1133. """
  1134. 更新训练阈值的API接口
  1135. @body_param threshold: 新的阈值值(整数)
  1136. @return: JSON响应
  1137. """
  1138. try:
  1139. data = request.get_json()
  1140. new_threshold = data.get('threshold')
  1141. # 验证新阈值
  1142. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  1143. return jsonify({
  1144. 'error': '无效的阈值值,必须为正数'
  1145. }), 400
  1146. # 更新当前应用的阈值配置
  1147. current_app.config['THRESHOLD'] = int(new_threshold)
  1148. return jsonify({
  1149. 'success': True,
  1150. 'message': f'阈值已更新为 {new_threshold}',
  1151. 'new_threshold': new_threshold
  1152. })
  1153. except Exception as e:
  1154. logging.error(f"更新阈值失败: {str(e)}")
  1155. return jsonify({
  1156. 'error': f'更新阈值失败: {str(e)}'
  1157. }), 500
  1158. @bp.route('/get-threshold', methods=['GET'])
  1159. def get_threshold():
  1160. """
  1161. 获取当前训练阈值的API接口
  1162. @return: JSON响应
  1163. """
  1164. try:
  1165. current_threshold = current_app.config['THRESHOLD']
  1166. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  1167. return jsonify({
  1168. 'current_threshold': current_threshold,
  1169. 'default_threshold': default_threshold
  1170. })
  1171. except Exception as e:
  1172. logging.error(f"获取阈值失败: {str(e)}")
  1173. return jsonify({
  1174. 'error': f'获取阈值失败: {str(e)}'
  1175. }), 500
@bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
def set_current_dataset(data_type, dataset_id):
    """
    Make the given dataset the "current" dataset of its type.

    @param data_type: dataset type ('reduce' or 'reflux')
    @param dataset_id: id of the dataset to promote to current
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # The dataset must exist AND match the requested type.
        dataset = session.query(Datasets)\
            .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
            .first()
        if not dataset:
            return jsonify({
                'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
            }), 404
        # Map the type onto the current-table model and its physical name.
        if data_type == 'reduce':
            table = CurrentReduce
            table_name = 'current_reduce'
        elif data_type == 'reflux':
            table = CurrentReflux
            table_name = 'current_reflux'
        else:
            return jsonify({'error': '无效的数据集类型'}), 400
        # Empty the current table ...
        session.query(table).delete()
        # ... and reset SQLite's autoincrement counter for it.
        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
        # Copy every row of the chosen dataset into the current table.
        # NOTE(review): assumes dataset tables are named dataset_<id> and
        # share the current table's column layout — confirm against the
        # dataset-creation code.
        dataset_table_name = f"dataset_{dataset_id}"
        copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
        session.execute(copy_sql)
        session.commit()
        return jsonify({
            'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
            'dataset_id': dataset_id,
            'dataset_name': dataset.Dataset_name,
            'row_count': dataset.Row_count
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'设置current数据集失败: {str(e)}')
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
  1225. @bp.route('/get-model-history/<string:data_type>', methods=['GET'])
  1226. def get_model_history(data_type):
  1227. """
  1228. 获取模型训练历史数据的API接口
  1229. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1230. @return: JSON响应,包含时间序列的模型性能数据
  1231. """
  1232. Session = sessionmaker(bind=db.engine)
  1233. session = Session()
  1234. try:
  1235. # 查询所有自动生成的数据集,按时间排序
  1236. datasets = session.query(Datasets).filter(
  1237. Datasets.Dataset_type == data_type,
  1238. Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
  1239. ).order_by(Datasets.Uploaded_at).all()
  1240. history_data = []
  1241. for dataset in datasets:
  1242. # 查找对应的自动训练模型
  1243. model = session.query(Models).filter(
  1244. Models.DatasetID == dataset.Dataset_ID,
  1245. Models.Model_name.like(f'auto_trained_{data_type}_%')
  1246. ).first()
  1247. if model and model.Performance_score is not None:
  1248. # 直接使用数据库中的时间,不进行格式化(保持与created_at相同的时区)
  1249. created_at = model.Created_at.isoformat() if model.Created_at else None
  1250. history_data.append({
  1251. 'dataset_id': dataset.Dataset_ID,
  1252. 'row_count': dataset.Row_count,
  1253. 'model_id': model.ModelID,
  1254. 'model_name': model.Model_name,
  1255. 'performance_score': float(model.Performance_score),
  1256. 'timestamp': created_at
  1257. })
  1258. # 按时间戳排序
  1259. history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')
  1260. # 构建返回数据,分离各个指标序列便于前端绘图
  1261. response_data = {
  1262. 'data_type': data_type,
  1263. 'timestamps': [item['timestamp'] for item in history_data],
  1264. 'row_counts': [item['row_count'] for item in history_data],
  1265. 'performance_scores': [item['performance_score'] for item in history_data],
  1266. 'model_details': history_data # 保留完整数据供前端使用
  1267. }
  1268. return jsonify(response_data), 200
  1269. except Exception as e:
  1270. logger.error(f'获取模型历史数据失败: {str(e)}', exc_info=True)
  1271. return jsonify({'error': str(e)}), 500
  1272. finally:
  1273. session.close()
  1274. @bp.route('/batch-delete-datasets', methods=['POST'])
  1275. def batch_delete_datasets():
  1276. """
  1277. 批量删除数据集的API接口
  1278. @body_param dataset_ids: 要删除的数据集ID列表
  1279. @return: JSON响应
  1280. """
  1281. try:
  1282. data = request.get_json()
  1283. dataset_ids = data.get('dataset_ids', [])
  1284. if not dataset_ids:
  1285. return jsonify({'error': '未提供数据集ID列表'}), 400
  1286. results = {
  1287. 'success': [],
  1288. 'failed': [],
  1289. 'protected': [] # 被模型使用的数据集
  1290. }
  1291. for dataset_id in dataset_ids:
  1292. try:
  1293. # 调用单个删除接口
  1294. response = delete_dataset_endpoint(dataset_id)
  1295. # 解析响应
  1296. if response[1] == 200:
  1297. results['success'].append(dataset_id)
  1298. elif response[1] == 400 and 'models' in response[0].json:
  1299. # 数据集被模型保护
  1300. results['protected'].append({
  1301. 'id': dataset_id,
  1302. 'models': response[0].json['models']
  1303. })
  1304. else:
  1305. results['failed'].append({
  1306. 'id': dataset_id,
  1307. 'reason': response[0].json.get('error', '删除失败')
  1308. })
  1309. except Exception as e:
  1310. logger.error(f'删除数据集 {dataset_id} 失败: {str(e)}')
  1311. results['failed'].append({
  1312. 'id': dataset_id,
  1313. 'reason': str(e)
  1314. })
  1315. # 构建响应消息
  1316. message = f"成功删除 {len(results['success'])} 个数据集"
  1317. if results['protected']:
  1318. message += f", {len(results['protected'])} 个数据集被保护"
  1319. if results['failed']:
  1320. message += f", {len(results['failed'])} 个数据集删除失败"
  1321. return jsonify({
  1322. 'message': message,
  1323. 'results': results
  1324. }), 200
  1325. except Exception as e:
  1326. logger.error(f'批量删除数据集失败: {str(e)}')
  1327. return jsonify({'error': str(e)}), 500
  1328. @bp.route('/batch-delete-models', methods=['POST'])
  1329. def batch_delete_models():
  1330. """
  1331. 批量删除模型的API接口
  1332. @body_param model_ids: 要删除的模型ID列表
  1333. @query_param delete_datasets: 布尔值,是否同时删除关联的数据集,默认为False
  1334. @return: JSON响应
  1335. """
  1336. try:
  1337. data = request.get_json()
  1338. model_ids = data.get('model_ids', [])
  1339. delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'
  1340. if not model_ids:
  1341. return jsonify({'error': '未提供模型ID列表'}), 400
  1342. results = {
  1343. 'success': [],
  1344. 'failed': [],
  1345. 'datasets_deleted': [] # 如果delete_datasets为true,记录被删除的数据集
  1346. }
  1347. for model_id in model_ids:
  1348. try:
  1349. # 调用单个删除接口
  1350. response = delete_model(model_id, delete_dataset=delete_datasets)
  1351. # 解析响应
  1352. if response[1] == 200:
  1353. results['success'].append(model_id)
  1354. # 如果删除了关联数据集,记录数据集ID
  1355. if 'dataset_info' in response[0].json:
  1356. results['datasets_deleted'].append(
  1357. response[0].json['dataset_info']['dataset_id']
  1358. )
  1359. else:
  1360. results['failed'].append({
  1361. 'id': model_id,
  1362. 'reason': response[0].json.get('error', '删除失败')
  1363. })
  1364. except Exception as e:
  1365. logger.error(f'删除模型 {model_id} 失败: {str(e)}')
  1366. results['failed'].append({
  1367. 'id': model_id,
  1368. 'reason': str(e)
  1369. })
  1370. # 构建响应消息
  1371. message = f"成功删除 {len(results['success'])} 个模型"
  1372. if results['datasets_deleted']:
  1373. message += f", {len(results['datasets_deleted'])} 个关联数据集"
  1374. if results['failed']:
  1375. message += f", {len(results['failed'])} 个模型删除失败"
  1376. return jsonify({
  1377. 'message': message,
  1378. 'results': results
  1379. }), 200
  1380. except Exception as e:
  1381. logger.error(f'批量删除模型失败: {str(e)}')
  1382. return jsonify({'error': str(e)}), 500
  1383. @bp.route('/kriging_interpolation', methods=['POST'])
  1384. def kriging_interpolation():
  1385. try:
  1386. data = request.get_json()
  1387. required = ['file_name', 'emission_column', 'points']
  1388. if not all(k in data for k in required):
  1389. return jsonify({"error": "Missing parameters"}), 400
  1390. # 添加坐标顺序验证
  1391. points = data['points']
  1392. if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points):
  1393. return jsonify({"error": "Invalid points format"}), 400
  1394. result = create_kriging(
  1395. data['file_name'],
  1396. data['emission_column'],
  1397. data['points']
  1398. )
  1399. return jsonify(result)
  1400. except Exception as e:
  1401. return jsonify({"error": str(e)}), 500
@bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
def get_model_scatter_data(model_id):
    """
    Return scatter-plot data (true vs. predicted values) for a model.

    @param model_id: model id
    @return: JSON response with scatter points, R2 score, value range and
             model metadata
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the model record.
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if not model:
            return jsonify({'error': '未找到指定模型'}), 404
        # Load the trained estimator from disk.
        # NOTE(review): pickle.load executes arbitrary code from the file —
        # safe only if ModelFilePath always points at files produced by this
        # application; confirm no user-supplied paths reach this column.
        with open(model.ModelFilePath, 'rb') as f:
            ML_model = pickle.load(f)
        # Pick the held-out test split matching the model's data type.
        if model.Data_type == 'reflux':
            X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
            Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
        elif model.Data_type == 'reduce':
            X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
            Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
        else:
            return jsonify({'error': '不支持的数据类型'}), 400
        # Predict on the test set.
        y_pred = ML_model.predict(X_test)
        # Pair each true value (first target column) with its prediction.
        scatter_data = [
            [float(true), float(pred)]
            for true, pred in zip(Y_test.iloc[:, 0], y_pred)
        ]
        # Goodness-of-fit score for the frontend.
        r2 = r2_score(Y_test, y_pred)
        # Overall value range, used to draw the y = x diagonal.
        y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
        y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))
        return jsonify({
            'scatter_data': scatter_data,
            'r2_score': float(r2),
            'y_range': [float(y_min), float(y_max)],
            'model_name': model.Model_name,
            'model_type': model.Model_type
        }), 200
    except Exception as e:
        logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True)
        return jsonify({'error': f'获取数据失败: {str(e)}'}), 500
    finally:
        session.close()