routes.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772
  1. import sqlite3
  2. from io import BytesIO
  3. import pickle
  4. from flask import Blueprint, request, jsonify, current_app, send_file
  5. from werkzeug.security import check_password_hash, generate_password_hash
  6. from werkzeug.utils import secure_filename
  7. from .model import predict, train_and_save_model, calculate_model_score
  8. import pandas as pd
  9. from . import db # 从 app 包导入 db 实例
  10. from sqlalchemy.engine.reflection import Inspector
  11. from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
  12. import os
  13. from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
  14. clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
  15. predict_to_Q, Q_to_t_ha, create_kriging
  16. from sqlalchemy.orm import sessionmaker
  17. import logging
  18. from sqlalchemy import text, select, MetaData, Table, func
  19. from .tasks import train_model_task
  20. from datetime import datetime
  21. from sklearn.metrics import r2_score
# Logging configuration: DEBUG level for the whole application.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Create the Blueprint used to group all routes in this module.
bp = Blueprint('routes', __name__)
  27. # 密码加密
  28. def hash_password(password):
  29. return generate_password_hash(password)
  30. def get_db():
  31. """ 获取数据库连接 """
  32. return sqlite3.connect(current_app.config['DATABASE'])
  33. # 添加一个新的辅助函数来检查数据集大小并触发训练
  34. def check_and_trigger_training(session, dataset_type, dataset_df):
  35. """
  36. 检查当前数据集大小是否跨越新的阈值点并触发训练
  37. Args:
  38. session: 数据库会话
  39. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  40. dataset_df: 数据集 DataFrame
  41. Returns:
  42. tuple: (是否触发训练, 任务ID)
  43. """
  44. try:
  45. # 根据数据集类型选择表
  46. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  47. # 获取当前记录数
  48. current_count = session.query(func.count()).select_from(table).scalar()
  49. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  50. new_records = len(dataset_df) # 需要从上层函数传入
  51. # 计算新增数据前的记录数
  52. previous_count = current_count - new_records
  53. # 设置阈值
  54. THRESHOLD = current_app.config['THRESHOLD']
  55. # 计算上一个阈值点(基于新增前的数据量)
  56. last_threshold = previous_count // THRESHOLD * THRESHOLD
  57. # 计算当前所在阈值点
  58. current_threshold = current_count // THRESHOLD * THRESHOLD
  59. # 检查是否跨越了新的阈值点
  60. if current_threshold > last_threshold and current_count >= THRESHOLD:
  61. # 触发异步训练任务
  62. task = train_model_task.delay(
  63. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  64. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  65. model_description=f'Auto trained model at {current_threshold} records threshold',
  66. data_type=dataset_type
  67. )
  68. return True, task.id
  69. return False, None
  70. except Exception as e:
  71. logging.error(f"检查并触发训练失败: {str(e)}")
  72. return False, None
@bp.route('/upload-dataset', methods=['POST'])
def upload_dataset():
    """
    Upload an Excel dataset file.

    Workflow: validate the multipart upload, create a Datasets metadata row,
    save the file under UPLOAD_FOLDER, load it with pandas, store the raw rows
    in a per-dataset dynamic table, de-duplicate against existing data, append
    the remainder to the current_reduce/current_reflux table, and possibly
    trigger auto-training.

    Returns: JSON with dataset id, filename, dedup statistics, 201 on success.
    """
    # Create a dedicated session for this request.
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '' or not allowed_file(file.filename):
            return jsonify({'error': 'No selected file or invalid file type'}), 400
        dataset_name = request.form.get('dataset_name')
        dataset_description = request.form.get('dataset_description', 'No description provided')
        dataset_type = request.form.get('dataset_type')
        if not dataset_type:
            return jsonify({'error': 'Dataset type is required'}), 400
        # Create the metadata record first so its auto-generated Dataset_ID
        # can be used in the file name and dynamic table name.
        new_dataset = Datasets(
            Dataset_name=dataset_name,
            Dataset_description=dataset_description,
            Row_count=0,
            Status='Datasets_upgraded',
            Dataset_type=dataset_type,
            Uploaded_at=datetime.now()
        )
        session.add(new_dataset)
        session.commit()
        unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
        upload_folder = current_app.config['UPLOAD_FOLDER']
        file_path = os.path.join(upload_folder, unique_filename)
        file.save(file_path)
        dataset_df = pd.read_excel(file_path)
        new_dataset.Row_count = len(dataset_df)
        new_dataset.Status = 'excel_file_saved success'
        session.commit()
        # Normalize column names and map them to the model's schema.
        dataset_df = clean_column_names(dataset_df)
        dataset_df = rename_columns_for_model(dataset_df, dataset_type)
        column_types = infer_column_types(dataset_df)
        dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
        # NOTE(review): the dynamic table receives the rows BEFORE any
        # de-duplication below — confirm this is intentional.
        insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
        # Remove duplicates that occur inside the uploaded file itself.
        original_count = len(dataset_df)
        dataset_df = dataset_df.drop_duplicates()
        duplicates_in_file = original_count - len(dataset_df)
        # Remove rows that already exist in the current data table.
        duplicates_with_existing = 0
        if dataset_type in ['reduce', 'reflux']:
            # Resolve the physical table name for this dataset kind.
            table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
            # Load the existing rows for comparison.
            existing_data = pd.read_sql_table(table_name, session.bind)
            if 'id' in existing_data.columns:
                existing_data = existing_data.drop('id', axis=1)
            # Only compare on columns present in both frames.
            compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
            original_df_len = len(dataset_df)
            # Concatenate existing + new, then mark later occurrences as
            # duplicates; the slice after len(existing_data) flags new rows
            # that collide with existing ones.
            all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
            duplicates_mask = all_data.duplicated(keep='first')
            duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
            # Keep only the non-duplicate new rows.
            dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
            logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
        # Append the surviving rows to the matching permanent table.
        if dataset_type == 'reduce':
            insert_data_into_existing_table(session, dataset_df, CurrentReduce)
        elif dataset_type == 'reflux':
            insert_data_into_existing_table(session, dataset_df, CurrentReflux)
        session.commit()
        # After the insert, check whether auto-training should be triggered.
        training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
        response_data = {
            'message': f'数据集 {dataset_name} 上传成功!',
            'dataset_id': new_dataset.Dataset_ID,
            'filename': unique_filename,
            'training_triggered': training_triggered,
            'data_stats': {
                'original_count': original_count,
                'duplicates_in_file': duplicates_in_file,
                'duplicates_with_existing': duplicates_with_existing,
                'final_count': len(dataset_df)
            }
        }
        if training_triggered:
            response_data['task_id'] = task_id
            response_data['message'] += ' 自动训练已触发。'
        # Append de-duplication info to the user-facing message.
        if duplicates_with_existing > 0:
            response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
        return jsonify(response_data), 201
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the dataset upload:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        # Ensure the session is always closed.
        if session:
            session.close()
  172. @bp.route('/train-and-save-model', methods=['POST'])
  173. def train_and_save_model_endpoint():
  174. # 创建 sessionmaker 实例
  175. Session = sessionmaker(bind=db.engine)
  176. session = Session()
  177. data = request.get_json()
  178. # 从请求中解析参数
  179. model_type = data.get('model_type')
  180. model_name = data.get('model_name')
  181. model_description = data.get('model_description')
  182. data_type = data.get('data_type')
  183. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  184. try:
  185. # 调用训练和保存模型的函数
  186. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  187. model_id = result[1] if result else None
  188. # 计算模型评分
  189. if model_id:
  190. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  191. if model_info:
  192. score = calculate_model_score(model_info)
  193. # 更新模型评分
  194. model_info.Performance_score = score
  195. session.commit()
  196. result = {'model_id': model_id, 'model_score': score}
  197. # 返回成功响应
  198. return jsonify({
  199. 'message': 'Model trained and saved successfully',
  200. 'result': result
  201. }), 200
  202. except Exception as e:
  203. session.rollback()
  204. logging.error('Failed to process the model training:', exc_info=True)
  205. return jsonify({
  206. 'error': 'Failed to train and save model',
  207. 'message': str(e)
  208. }), 500
  209. finally:
  210. session.close()
  211. @bp.route('/predict', methods=['POST'])
  212. def predict_route():
  213. # 创建 sessionmaker 实例
  214. Session = sessionmaker(bind=db.engine)
  215. session = Session()
  216. try:
  217. data = request.get_json()
  218. model_id = data.get('model_id') # 提取模型名称
  219. parameters = data.get('parameters', {}) # 提取所有变量
  220. # 根据model_id获取模型Data_type
  221. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  222. if not model_info:
  223. return jsonify({'error': 'Model not found'}), 404
  224. data_type = model_info.Data_type
  225. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  226. # 如果为reduce,则不需要传入target_ph
  227. if data_type == 'reduce':
  228. # 获取传入的init_ph、target_ph参数
  229. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  230. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  231. # 从输入数据中删除'target_pH'列
  232. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  233. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  234. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  235. if data_type == 'reduce':
  236. predictions = predictions[0]
  237. # 将预测结果转换为Q
  238. Q = predict_to_Q(predictions, init_ph, target_ph)
  239. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  240. print(predictions)
  241. return jsonify({'result': predictions}), 200
  242. except Exception as e:
  243. logging.error('Failed to predict:', exc_info=True)
  244. return jsonify({'error': str(e)}), 400
  245. # 为指定模型计算评分Performance_score,需要提供model_id
  246. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  247. def score_model(model_id):
  248. # 创建 sessionmaker 实例
  249. Session = sessionmaker(bind=db.engine)
  250. session = Session()
  251. try:
  252. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  253. if not model_info:
  254. return jsonify({'error': 'Model not found'}), 404
  255. # 计算模型评分
  256. score = calculate_model_score(model_info)
  257. # 更新模型记录中的评分
  258. model_info.Performance_score = score
  259. session.commit()
  260. return jsonify({'message': 'Model scored successfully', 'score': score}), 200
  261. except Exception as e:
  262. logging.error('Failed to process the dataset upload:', exc_info=True)
  263. return jsonify({'error': str(e)}), 400
  264. finally:
  265. session.close()
@bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
def delete_dataset_endpoint(dataset_id):
    """
    Delete a dataset: its Excel file, its dynamic table, and its metadata row.

    @param dataset_id: ID of the dataset to delete
    @return: JSON response (200 on success, 404 if missing, 400 if in use)
    """
    # Create a dedicated session for this request.
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        # Look up the dataset metadata.
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if not dataset:
            return jsonify({'error': '未找到数据集'}), 404
        # Refuse deletion while any model still references this dataset.
        models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
        if models_using_dataset:
            models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
            return jsonify({
                'error': '无法删除数据集,因为以下模型正在使用它',
                'models': models_info
            }), 400
        # Delete the stored Excel file, if present.
        filename = f"dataset_{dataset.Dataset_ID}.xlsx"
        file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f'删除文件失败: {str(e)}')
                return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
        # Drop the per-dataset dynamic table. The name is built from an
        # integer primary key, so the f-string SQL is not injectable here.
        table_name = f"dataset_{dataset.Dataset_ID}"
        session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
        # Remove the metadata record and commit everything.
        session.delete(dataset)
        session.commit()
        return jsonify({
            'message': '数据集删除成功',
            'deleted_files': [filename]
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
  314. @bp.route('/tables', methods=['GET'])
  315. def list_tables():
  316. engine = db.engine # 使用 db 实例的 engine
  317. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  318. table_names = inspector.get_table_names() # 获取所有表名
  319. return jsonify(table_names) # 以 JSON 形式返回表名列表
  320. @bp.route('/models/<int:model_id>', methods=['GET'])
  321. def get_model(model_id):
  322. """
  323. 获取单个模型信息的API接口
  324. @param model_id: 模型ID
  325. @return: JSON响应
  326. """
  327. Session = sessionmaker(bind=db.engine)
  328. session = Session()
  329. try:
  330. model = session.query(Models).filter_by(ModelID=model_id).first()
  331. if model:
  332. return jsonify({
  333. 'ModelID': model.ModelID,
  334. 'Model_name': model.Model_name,
  335. 'Model_type': model.Model_type,
  336. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  337. 'Description': model.Description,
  338. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  339. 'Data_type': model.Data_type
  340. })
  341. else:
  342. return jsonify({'message': '未找到模型'}), 404
  343. except Exception as e:
  344. logger.error(f'获取模型信息失败: {str(e)}')
  345. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  346. finally:
  347. session.close()
  348. @bp.route('/models', methods=['GET'])
  349. def get_all_models():
  350. """
  351. 获取所有模型信息的API接口
  352. @return: JSON响应
  353. """
  354. Session = sessionmaker(bind=db.engine)
  355. session = Session()
  356. try:
  357. models = session.query(Models).all()
  358. if models:
  359. result = [
  360. {
  361. 'ModelID': model.ModelID,
  362. 'Model_name': model.Model_name,
  363. 'Model_type': model.Model_type,
  364. 'Created_at': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  365. 'Description': model.Description,
  366. 'Performance_score': float(model.Performance_score) if model.Performance_score else None,
  367. 'Data_type': model.Data_type
  368. }
  369. for model in models
  370. ]
  371. return jsonify(result)
  372. else:
  373. return jsonify({'message': '未找到任何模型'}), 404
  374. except Exception as e:
  375. logger.error(f'获取所有模型信息失败: {str(e)}')
  376. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  377. finally:
  378. session.close()
  379. @bp.route('/model-parameters', methods=['GET'])
  380. def get_all_model_parameters():
  381. """
  382. 获取所有模型参数的API接口
  383. @return: JSON响应
  384. """
  385. Session = sessionmaker(bind=db.engine)
  386. session = Session()
  387. try:
  388. parameters = session.query(ModelParameters).all()
  389. if parameters:
  390. result = [
  391. {
  392. 'ParamID': param.ParamID,
  393. 'ModelID': param.ModelID,
  394. 'ParamName': param.ParamName,
  395. 'ParamValue': param.ParamValue
  396. }
  397. for param in parameters
  398. ]
  399. return jsonify(result)
  400. else:
  401. return jsonify({'message': '未找到任何参数'}), 404
  402. except Exception as e:
  403. logger.error(f'获取所有模型参数失败: {str(e)}')
  404. return jsonify({'error': '服务器内部错误', 'message': str(e)}), 500
  405. finally:
  406. session.close()
  407. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  408. def get_model_parameters(model_id):
  409. try:
  410. model = Models.query.filter_by(ModelID=model_id).first()
  411. if model:
  412. # 获取该模型的所有参数
  413. parameters = [
  414. {
  415. 'ParamID': param.ParamID,
  416. 'ParamName': param.ParamName,
  417. 'ParamValue': param.ParamValue
  418. }
  419. for param in model.parameters
  420. ]
  421. # 返回模型参数信息
  422. return jsonify({
  423. 'ModelID': model.ModelID,
  424. 'ModelName': model.ModelName,
  425. 'ModelType': model.ModelType,
  426. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  427. 'Description': model.Description,
  428. 'Parameters': parameters
  429. })
  430. else:
  431. return jsonify({'message': 'Model not found'}), 404
  432. except Exception as e:
  433. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
# API endpoint that inserts one record into a whitelisted table.
@bp.route('/add_item', methods=['POST'])
def add_item():
    """
    Accept a JSON body containing a table name and the record to insert.
    Performs a per-table duplicate check before inserting.

    :return: JSON response (201 on success, 409 on duplicate, 400/500 on error)
    """
    try:
        # The request body must be JSON.
        data = request.get_json()
        if not data:
            raise ValueError("No JSON data provided")
        table_name = data.get('table')
        item_data = data.get('item')
        if not table_name or not item_data:
            return jsonify({'error': 'Missing table name or item data'}), 400
        # Duplicate-check columns per table. This mapping also acts as a
        # whitelist: unknown table names are rejected below, which limits
        # the f-string table interpolation to these known names.
        duplicate_check_rules = {
            'users': ['email', 'username'],
            'products': ['product_code'],
            'current_reduce': [ 'Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'],
            'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'],
            # Other tables and rules go here.
        }
        # Look up the duplicate-check columns for this table.
        duplicate_columns = duplicate_check_rules.get(table_name)
        if not duplicate_columns:
            return jsonify({'error': 'No duplicate check rule for this table'}), 400
        # Build a parameterized existence query over the check columns.
        condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns])
        duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1"
        result = db.session.execute(text(duplicate_query), item_data).fetchone()
        if result:
            return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409
        # Build the INSERT dynamically from the submitted keys.
        # NOTE(review): the keys of item_data are interpolated as column
        # names — they come from the client, so a malicious key could alter
        # the SQL; consider validating them against the table schema.
        columns = ', '.join(item_data.keys())
        placeholders = ', '.join([f":{key}" for key in item_data.keys()])
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        # Execute the insert; values are bound as parameters.
        db.session.execute(text(sql), item_data)
        # Commit the transaction.
        db.session.commit()
        return jsonify({'success': True, 'message': 'Item added successfully'}), 201
    except ValueError as e:
        return jsonify({'error': str(e)}), 400
    except KeyError as e:
        return jsonify({'error': f'Missing data field: {e}'}), 400
    # NOTE(review): db.session.execute typically raises SQLAlchemy exception
    # types that wrap the driver error — confirm these sqlite3 handlers
    # actually fire, and whether a rollback should be issued on failure.
    except sqlite3.IntegrityError as e:
        return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409
    except sqlite3.Error as e:
        return jsonify({'error': '数据库错误', 'details': str(e)}), 500
  487. @bp.route('/delete_item', methods=['POST'])
  488. def delete_item():
  489. """
  490. 删除数据库记录的 API 接口
  491. """
  492. data = request.get_json()
  493. table_name = data.get('table')
  494. condition = data.get('condition')
  495. # 检查表名和条件是否提供
  496. if not table_name or not condition:
  497. return jsonify({
  498. "success": False,
  499. "message": "缺少表名或条件参数"
  500. }), 400
  501. # 尝试从条件字符串中解析键和值
  502. try:
  503. key, value = condition.split('=')
  504. key = key.strip() # 去除多余的空格
  505. value = value.strip().strip("'\"") # 去除多余的空格和引号
  506. except ValueError:
  507. return jsonify({
  508. "success": False,
  509. "message": "条件格式错误,应为 'key=value'"
  510. }), 400
  511. # 准备 SQL 删除语句
  512. sql = f"DELETE FROM {table_name} WHERE {key} = :value"
  513. try:
  514. # 使用 SQLAlchemy 执行删除
  515. with db.session.begin():
  516. result = db.session.execute(text(sql), {"value": value})
  517. # 检查是否有记录被删除
  518. if result.rowcount == 0:
  519. return jsonify({
  520. "success": False,
  521. "message": "未找到符合条件的记录"
  522. }), 404
  523. return jsonify({
  524. "success": True,
  525. "message": "记录删除成功"
  526. }), 200
  527. except Exception as e:
  528. return jsonify({
  529. "success": False,
  530. "message": f"删除失败: {e}"
  531. }), 500
  532. # 定义修改数据库记录的 API 接口
  533. @bp.route('/update_item', methods=['PUT'])
  534. def update_record():
  535. """
  536. 接收 JSON 格式的请求体,包含表名和更新的数据。
  537. 尝试更新指定的记录。
  538. """
  539. data = request.get_json()
  540. # 检查必要的数据是否提供
  541. if not data or 'table' not in data or 'item' not in data:
  542. return jsonify({
  543. "success": False,
  544. "message": "请求数据不完整"
  545. }), 400
  546. table_name = data['table']
  547. item = data['item']
  548. # 假设 item 的第一个键是 ID
  549. id_key = next(iter(item.keys())) # 获取第一个键
  550. record_id = item.get(id_key)
  551. if not record_id:
  552. return jsonify({
  553. "success": False,
  554. "message": "缺少记录 ID"
  555. }), 400
  556. # 获取更新的字段和值
  557. updates = {key: value for key, value in item.items() if key != id_key}
  558. if not updates:
  559. return jsonify({
  560. "success": False,
  561. "message": "没有提供需要更新的字段"
  562. }), 400
  563. # 动态构建 SQL
  564. set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()])
  565. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value"
  566. # 添加 ID 到参数
  567. updates['id_value'] = record_id
  568. try:
  569. # 使用 SQLAlchemy 执行更新
  570. with db.session.begin():
  571. result = db.session.execute(text(sql), updates)
  572. # 检查是否有更新的记录
  573. if result.rowcount == 0:
  574. return jsonify({
  575. "success": False,
  576. "message": "未找到要更新的记录"
  577. }), 404
  578. return jsonify({
  579. "success": True,
  580. "message": "数据更新成功"
  581. }), 200
  582. except Exception as e:
  583. # 捕获所有异常并返回
  584. return jsonify({
  585. "success": False,
  586. "message": f"更新失败: {str(e)}"
  587. }), 500
  588. # 定义查询数据库记录的 API 接口
  589. @bp.route('/search/record', methods=['GET'])
  590. def sql_search():
  591. """
  592. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  593. 尝试查询指定 ID 的记录并返回结果。
  594. :return:
  595. """
  596. try:
  597. data = request.get_json()
  598. # 表名
  599. sql_table = data['table']
  600. # 要搜索的 ID
  601. Id = data['id']
  602. # 连接到数据库
  603. cur = db.cursor()
  604. # 构造查询语句
  605. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  606. # 执行查询
  607. cur.execute(sql, (Id,))
  608. # 获取查询结果
  609. rows = cur.fetchall()
  610. column_names = [desc[0] for desc in cur.description]
  611. # 检查是否有结果
  612. if not rows:
  613. return jsonify({'error': '未查找到对应数据。'}), 400
  614. # 构造响应数据
  615. results = []
  616. for row in rows:
  617. result = {column_names[i]: row[i] for i in range(len(row))}
  618. results.append(result)
  619. # 关闭游标和数据库连接
  620. cur.close()
  621. db.close()
  622. # 返回 JSON 响应
  623. return jsonify(results), 200
  624. except sqlite3.Error as e:
  625. # 如果发生数据库错误,返回错误信息
  626. return jsonify({'error': str(e)}), 400
  627. except KeyError as e:
  628. # 如果请求数据中缺少必要的键,返回错误信息
  629. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  630. # 定义提供数据库列表,用于展示表格的 API 接口
  631. @bp.route('/table', methods=['POST'])
  632. def get_table():
  633. data = request.get_json()
  634. table_name = data.get('table')
  635. if not table_name:
  636. return jsonify({'error': '需要表名'}), 400
  637. try:
  638. # 创建 sessionmaker 实例
  639. Session = sessionmaker(bind=db.engine)
  640. session = Session()
  641. # 动态获取表的元数据
  642. metadata = MetaData()
  643. table = Table(table_name, metadata, autoload_with=db.engine)
  644. # 从数据库中查询所有记录
  645. query = select(table)
  646. result = session.execute(query).fetchall()
  647. # 将结果转换为列表字典形式
  648. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  649. # 获取列名
  650. headers = [column.name for column in table.columns]
  651. return jsonify(rows=rows, headers=headers), 200
  652. except Exception as e:
  653. return jsonify({'error': str(e)}), 400
  654. finally:
  655. # 关闭 session
  656. session.close()
  657. @bp.route('/train-model-async', methods=['POST'])
  658. def train_model_async():
  659. """
  660. 异步训练模型的API接口
  661. """
  662. try:
  663. data = request.get_json()
  664. # 从请求中获取参数
  665. model_type = data.get('model_type')
  666. model_name = data.get('model_name')
  667. model_description = data.get('model_description')
  668. data_type = data.get('data_type')
  669. dataset_id = data.get('dataset_id', None)
  670. # 验证必要参数
  671. if not all([model_type, model_name, data_type]):
  672. return jsonify({
  673. 'error': 'Missing required parameters'
  674. }), 400
  675. # 如果提供了dataset_id,验证数据集是否存在
  676. if dataset_id:
  677. Session = sessionmaker(bind=db.engine)
  678. session = Session()
  679. try:
  680. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  681. if not dataset:
  682. return jsonify({
  683. 'error': f'Dataset with ID {dataset_id} not found'
  684. }), 404
  685. finally:
  686. session.close()
  687. # 启动异步任务
  688. task = train_model_task.delay(
  689. model_type=model_type,
  690. model_name=model_name,
  691. model_description=model_description,
  692. data_type=data_type,
  693. dataset_id=dataset_id
  694. )
  695. # 返回任务ID
  696. return jsonify({
  697. 'task_id': task.id,
  698. 'message': 'Model training started'
  699. }), 202
  700. except Exception as e:
  701. logging.error('Failed to start async training task:', exc_info=True)
  702. return jsonify({
  703. 'error': str(e)
  704. }), 500
  705. @bp.route('/task-status/<task_id>', methods=['GET'])
  706. def get_task_status(task_id):
  707. """
  708. 获取异步任务状态的API接口
  709. """
  710. try:
  711. task = train_model_task.AsyncResult(task_id)
  712. if task.state == 'PENDING':
  713. response = {
  714. 'state': task.state,
  715. 'status': 'Task is waiting for execution'
  716. }
  717. elif task.state == 'FAILURE':
  718. response = {
  719. 'state': task.state,
  720. 'status': 'Task failed',
  721. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  722. }
  723. elif task.state == 'SUCCESS':
  724. response = {
  725. 'state': task.state,
  726. 'status': 'Task completed successfully',
  727. 'result': task.get()
  728. }
  729. else:
  730. response = {
  731. 'state': task.state,
  732. 'status': 'Task is in progress'
  733. }
  734. return jsonify(response), 200
  735. except Exception as e:
  736. return jsonify({
  737. 'error': str(e)
  738. }), 500
  739. @bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
  740. def delete_model_route(model_id):
  741. # 将URL参数转换为布尔值
  742. delete_dataset_param = request.args.get('delete_dataset', 'False').lower() == 'true'
  743. # 调用原始函数
  744. return delete_model(model_id, delete_dataset=delete_dataset_param)
  745. def delete_model(model_id, delete_dataset=False):
  746. """
  747. 删除指定模型的API接口
  748. @param model_id: 要删除的模型ID
  749. @query_param delete_dataset: 布尔值,是否同时删除关联的数据集,默认为False
  750. @return: JSON响应
  751. """
  752. Session = sessionmaker(bind=db.engine)
  753. session = Session()
  754. try:
  755. # 查询模型信息
  756. model = session.query(Models).filter_by(ModelID=model_id).first()
  757. if not model:
  758. return jsonify({'error': '未找到指定模型'}), 404
  759. dataset_id = model.DatasetID
  760. # 1. 先删除模型记录
  761. session.delete(model)
  762. session.commit()
  763. # 2. 删除模型文件
  764. model_path = model.ModelFilePath
  765. try:
  766. if os.path.exists(model_path):
  767. os.remove(model_path)
  768. else:
  769. # 如果删除文件失败,回滚数据库操作
  770. session.rollback()
  771. logger.warning(f'模型文件不存在: {model_path}')
  772. except OSError as e:
  773. # 如果删除文件失败,回滚数据库操作
  774. session.rollback()
  775. logger.error(f'删除模型文件失败: {str(e)}')
  776. return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
  777. # 3. 如果需要删除关联的数据集
  778. if delete_dataset and dataset_id:
  779. try:
  780. dataset_response = delete_dataset_endpoint(dataset_id)
  781. if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
  782. # 如果删除数据集失败,回滚之前的操作
  783. session.rollback()
  784. return jsonify({
  785. 'error': '删除关联数据集失败',
  786. 'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
  787. }), 500
  788. except Exception as e:
  789. session.rollback()
  790. logger.error(f'删除关联数据集失败: {str(e)}')
  791. return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
  792. response_data = {
  793. 'message': '模型删除成功',
  794. 'deleted_files': [model_path]
  795. }
  796. if delete_dataset:
  797. response_data['dataset_info'] = {
  798. 'dataset_id': dataset_id,
  799. 'message': '关联数据集已删除'
  800. }
  801. return jsonify(response_data), 200
  802. except Exception as e:
  803. session.rollback()
  804. logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
  805. return jsonify({'error': str(e)}), 500
  806. finally:
  807. session.close()
  808. # 添加一个新的API端点来清空指定数据集
  809. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  810. def clear_dataset(data_type):
  811. """
  812. 清空指定类型的数据集并递增计数
  813. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  814. @return: JSON响应
  815. """
  816. # 创建 sessionmaker 实例
  817. Session = sessionmaker(bind=db.engine)
  818. session = Session()
  819. try:
  820. # 根据数据集类型选择表
  821. if data_type == 'reduce':
  822. table = CurrentReduce
  823. table_name = 'current_reduce'
  824. elif data_type == 'reflux':
  825. table = CurrentReflux
  826. table_name = 'current_reflux'
  827. else:
  828. return jsonify({'error': '无效的数据集类型'}), 400
  829. # 清空表内容
  830. session.query(table).delete()
  831. # 重置自增主键计数器
  832. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  833. session.commit()
  834. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  835. except Exception as e:
  836. session.rollback()
  837. return jsonify({'error': str(e)}), 500
  838. finally:
  839. session.close()
  840. @bp.route('/login', methods=['POST'])
  841. def login_user():
  842. # 获取前端传来的数据
  843. data = request.get_json()
  844. name = data.get('name') # 用户名
  845. password = data.get('password') # 密码
  846. logger.info(f"Login request received: name={name}")
  847. # 检查用户名和密码是否为空
  848. if not name or not password:
  849. logger.warning("用户名和密码不能为空")
  850. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  851. try:
  852. # 查询数据库验证用户名
  853. query = "SELECT * FROM users WHERE name = :name"
  854. conn = get_db()
  855. user = conn.execute(query, {"name": name}).fetchone()
  856. if not user:
  857. logger.warning(f"用户名 '{name}' 不存在")
  858. return jsonify({"success": False, "message": "用户名不存在"}), 400
  859. # 获取数据库中存储的密码(假设密码是哈希存储的)
  860. stored_password = user[2] # 假设密码存储在数据库的第三列
  861. user_id = user[0] # 假设 id 存储在数据库的第一列
  862. # 校验密码是否正确
  863. if check_password_hash(stored_password, password):
  864. logger.info(f"User '{name}' logged in successfully.")
  865. return jsonify({
  866. "success": True,
  867. "message": "登录成功",
  868. "userId": user_id # 返回用户 ID
  869. })
  870. else:
  871. logger.warning(f"Invalid password for user '{name}'")
  872. return jsonify({"success": False, "message": "用户名或密码错误"}), 400
  873. except Exception as e:
  874. # 记录错误日志并返回错误信息
  875. logger.error(f"Error during login: {e}", exc_info=True)
  876. return jsonify({"success": False, "message": "登录失败"}), 500
# Update user information endpoint.
@bp.route('/update_user', methods=['POST'])
def update_user():
    """
    Update a user's password and/or username.

    Expects JSON with userId, name, oldPassword and an optional newPassword.
    The old password must verify against the stored hash before anything is
    changed; the username is only rewritten when it differs from the stored one.
    """
    # Payload sent by the frontend.
    data = request.get_json()
    # Log the raw request for traceability.
    current_app.logger.info(f"Received data: {data}")
    user_id = data.get('userId')  # user id
    name = data.get('name')  # username
    old_password = data.get('oldPassword')  # current password
    new_password = data.get('newPassword')  # new password (optional)
    logger.info(f"Update request received: user_id={user_id}, name={name}")
    # Username and old password are mandatory.
    if not name or not old_password:
        logger.warning("用户名和旧密码不能为空")
        return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400
    # A new password must differ from the old one.
    if new_password and old_password == new_password:
        logger.warning(f"新密码与旧密码相同:{name}")
        return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400
    try:
        # Look the user up by id.
        query = "SELECT * FROM users WHERE id = :user_id"
        conn = get_db()
        user = conn.execute(query, {"user_id": user_id}).fetchone()
        if not user:
            logger.warning(f"用户ID '{user_id}' 不存在")
            return jsonify({"success": False, "message": "用户不存在"}), 400
        # Stored password hash (assumed to live in the third column of the row).
        stored_password = user[2]
        # Verify the old password against the stored hash.
        if not check_password_hash(stored_password, old_password):
            logger.warning(f"旧密码错误:{name}")
            return jsonify({"success": False, "message": "旧密码错误"}), 400
        # Apply the password change only when a new password was supplied.
        if new_password:
            hashed_new_password = hash_password(new_password)
            update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
            conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
            conn.commit()
            logger.info(f"User ID '{user_id}' password updated successfully.")
        # Rewrite the username only when it changed (column 1 is assumed to be the name).
        if name != user[1]:
            update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
            conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
            conn.commit()
            logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
        return jsonify({"success": True, "message": "用户信息更新成功"})
    except Exception as e:
        # Log the failure and return a generic error.
        logger.error(f"Error updating user: {e}", exc_info=True)
        return jsonify({"success": False, "message": "更新失败"}), 500
  929. # 注册用户
  930. @bp.route('/register', methods=['POST'])
  931. def register_user():
  932. # 获取前端传来的数据
  933. data = request.get_json()
  934. name = data.get('name') # 用户名
  935. password = data.get('password') # 密码
  936. logger.info(f"Register request received: name={name}")
  937. # 检查用户名和密码是否为空
  938. if not name or not password:
  939. logger.warning("用户名和密码不能为空")
  940. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  941. # 动态获取数据库表的列名
  942. columns = get_column_names('users')
  943. logger.info(f"Database columns for 'users' table: {columns}")
  944. # 检查前端传来的数据是否包含数据库表中所有的必填字段
  945. for column in ['name', 'password']:
  946. if column not in columns:
  947. logger.error(f"缺少必填字段:{column}")
  948. return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400
  949. # 对密码进行哈希处理
  950. hashed_password = hash_password(password)
  951. logger.info(f"Password hashed for user: {name}")
  952. # 插入到数据库
  953. try:
  954. # 检查用户是否已经存在
  955. query = "SELECT * FROM users WHERE name = :name"
  956. conn = get_db()
  957. user = conn.execute(query, {"name": name}).fetchone()
  958. if user:
  959. logger.warning(f"用户名 '{name}' 已存在")
  960. return jsonify({"success": False, "message": "用户名已存在"}), 400
  961. # 向数据库插入数据
  962. query = "INSERT INTO users (name, password) VALUES (:name, :password)"
  963. conn.execute(query, {"name": name, "password": hashed_password})
  964. conn.commit()
  965. logger.info(f"User '{name}' registered successfully.")
  966. return jsonify({"success": True, "message": "注册成功"})
  967. except Exception as e:
  968. # 记录错误日志并返回错误信息
  969. logger.error(f"Error registering user: {e}", exc_info=True)
  970. return jsonify({"success": False, "message": "注册失败"}), 500
  971. def get_column_names(table_name):
  972. """
  973. 动态获取数据库表的列名。
  974. """
  975. try:
  976. conn = get_db()
  977. query = f"PRAGMA table_info({table_name});"
  978. result = conn.execute(query).fetchall()
  979. conn.close()
  980. return [row[1] for row in result] # 第二列是列名
  981. except Exception as e:
  982. logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
  983. return []
  984. # 导出数据
  985. @bp.route('/export_data', methods=['GET'])
  986. def export_data():
  987. table_name = request.args.get('table')
  988. file_format = request.args.get('format', 'excel').lower()
  989. if not table_name:
  990. return jsonify({'error': '缺少表名参数'}), 400
  991. if not table_name.isidentifier():
  992. return jsonify({'error': '无效的表名'}), 400
  993. try:
  994. conn = get_db()
  995. query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"
  996. table_exists = conn.execute(query, (table_name,)).fetchone()
  997. if not table_exists:
  998. return jsonify({'error': f"表 {table_name} 不存在"}), 404
  999. query = f"SELECT * FROM {table_name};"
  1000. df = pd.read_sql(query, conn)
  1001. output = BytesIO()
  1002. if file_format == 'csv':
  1003. df.to_csv(output, index=False, encoding='utf-8')
  1004. output.seek(0)
  1005. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv')
  1006. elif file_format == 'excel':
  1007. df.to_excel(output, index=False, engine='openpyxl')
  1008. output.seek(0)
  1009. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx',
  1010. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1011. else:
  1012. return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400
  1013. except Exception as e:
  1014. logger.error(f"Error in export_data: {e}", exc_info=True)
  1015. return jsonify({'error': str(e)}), 500
  1016. # 导入数据接口
  1017. @bp.route('/import_data', methods=['POST'])
  1018. def import_data():
  1019. logger.debug("Import data endpoint accessed.")
  1020. if 'file' not in request.files:
  1021. logger.error("No file in request.")
  1022. return jsonify({'success': False, 'message': '文件缺失'}), 400
  1023. file = request.files['file']
  1024. table_name = request.form.get('table')
  1025. if not table_name:
  1026. logger.error("Missing table name parameter.")
  1027. return jsonify({'success': False, 'message': '缺少表名参数'}), 400
  1028. if file.filename == '':
  1029. logger.error("No file selected.")
  1030. return jsonify({'success': False, 'message': '未选择文件'}), 400
  1031. try:
  1032. # 保存文件到临时路径
  1033. temp_path = os.path.join(current_app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
  1034. file.save(temp_path)
  1035. logger.debug(f"File saved to temporary path: {temp_path}")
  1036. # 根据文件类型读取文件
  1037. if file.filename.endswith('.xlsx'):
  1038. df = pd.read_excel(temp_path)
  1039. elif file.filename.endswith('.csv'):
  1040. df = pd.read_csv(temp_path)
  1041. else:
  1042. logger.error("Unsupported file format.")
  1043. return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400
  1044. # 获取数据库列名
  1045. db_columns = get_column_names(table_name)
  1046. if 'id' in db_columns:
  1047. db_columns.remove('id') # 假设 id 列是自增的,不需要处理
  1048. if not set(db_columns).issubset(set(df.columns)):
  1049. logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}")
  1050. return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400
  1051. # 清洗数据并删除空值行
  1052. df_cleaned = df[db_columns].dropna()
  1053. # 统一数据类型,避免 int 和 float 合并问题
  1054. df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce')
  1055. # 获取现有的数据
  1056. conn = get_db()
  1057. with conn:
  1058. existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn)
  1059. # 查找重复数据
  1060. duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner')
  1061. # 如果有重复数据,删除它们
  1062. df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)]
  1063. logger.warning(f"Duplicate data detected and removed: {duplicates}")
  1064. # 获取导入前后的数据量
  1065. total_data = len(df_cleaned) + len(duplicates)
  1066. new_data = len(df_cleaned)
  1067. duplicate_data = len(duplicates)
  1068. # 导入不重复的数据
  1069. df_cleaned.to_sql(table_name, conn, if_exists='append', index=False)
  1070. logger.debug(f"Imported {new_data} new records into the database.")
  1071. # 删除临时文件
  1072. os.remove(temp_path)
  1073. logger.debug(f"Temporary file removed: {temp_path}")
  1074. # 返回结果
  1075. return jsonify({
  1076. 'success': True,
  1077. 'message': '数据导入成功',
  1078. 'total_data': total_data,
  1079. 'new_data': new_data,
  1080. 'duplicate_data': duplicate_data
  1081. }), 200
  1082. except Exception as e:
  1083. logger.error(f"Import failed: {e}", exc_info=True)
  1084. return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500
  1085. # 模板下载接口
  1086. @bp.route('/download_template', methods=['GET'])
  1087. def download_template():
  1088. """
  1089. 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。
  1090. """
  1091. table_name = request.args.get('table')
  1092. if not table_name:
  1093. return jsonify({'error': '表名参数缺失'}), 400
  1094. columns = get_column_names(table_name)
  1095. if not columns:
  1096. return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404
  1097. # 不包括 ID 列
  1098. if 'id' in columns:
  1099. columns.remove('id')
  1100. df = pd.DataFrame(columns=columns)
  1101. file_format = request.args.get('format', 'excel').lower()
  1102. try:
  1103. if file_format == 'csv':
  1104. output = BytesIO()
  1105. df.to_csv(output, index=False, encoding='utf-8')
  1106. output.seek(0)
  1107. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv',
  1108. mimetype='text/csv')
  1109. else:
  1110. output = BytesIO()
  1111. df.to_excel(output, index=False, engine='openpyxl')
  1112. output.seek(0)
  1113. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx',
  1114. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1115. except Exception as e:
  1116. logger.error(f"Failed to generate template: {e}", exc_info=True)
  1117. return jsonify({'error': '生成模板文件失败'}), 500
  1118. @bp.route('/update-threshold', methods=['POST'])
  1119. def update_threshold():
  1120. """
  1121. 更新训练阈值的API接口
  1122. @body_param threshold: 新的阈值值(整数)
  1123. @return: JSON响应
  1124. """
  1125. try:
  1126. data = request.get_json()
  1127. new_threshold = data.get('threshold')
  1128. # 验证新阈值
  1129. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  1130. return jsonify({
  1131. 'error': '无效的阈值值,必须为正数'
  1132. }), 400
  1133. # 更新当前应用的阈值配置
  1134. current_app.config['THRESHOLD'] = int(new_threshold)
  1135. return jsonify({
  1136. 'success': True,
  1137. 'message': f'阈值已更新为 {new_threshold}',
  1138. 'new_threshold': new_threshold
  1139. })
  1140. except Exception as e:
  1141. logging.error(f"更新阈值失败: {str(e)}")
  1142. return jsonify({
  1143. 'error': f'更新阈值失败: {str(e)}'
  1144. }), 500
  1145. @bp.route('/get-threshold', methods=['GET'])
  1146. def get_threshold():
  1147. """
  1148. 获取当前训练阈值的API接口
  1149. @return: JSON响应
  1150. """
  1151. try:
  1152. current_threshold = current_app.config['THRESHOLD']
  1153. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  1154. return jsonify({
  1155. 'current_threshold': current_threshold,
  1156. 'default_threshold': default_threshold
  1157. })
  1158. except Exception as e:
  1159. logging.error(f"获取阈值失败: {str(e)}")
  1160. return jsonify({
  1161. 'error': f'获取阈值失败: {str(e)}'
  1162. }), 500
  1163. @bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
  1164. def set_current_dataset(data_type, dataset_id):
  1165. """
  1166. 将指定数据集设置为current数据集
  1167. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1168. @param dataset_id: 要设置为current的数据集ID
  1169. @return: JSON响应
  1170. """
  1171. Session = sessionmaker(bind=db.engine)
  1172. session = Session()
  1173. try:
  1174. # 验证数据集存在且类型匹配
  1175. dataset = session.query(Datasets)\
  1176. .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
  1177. .first()
  1178. if not dataset:
  1179. return jsonify({
  1180. 'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
  1181. }), 404
  1182. # 根据数据类型选择表
  1183. if data_type == 'reduce':
  1184. table = CurrentReduce
  1185. table_name = 'current_reduce'
  1186. elif data_type == 'reflux':
  1187. table = CurrentReflux
  1188. table_name = 'current_reflux'
  1189. else:
  1190. return jsonify({'error': '无效的数据集类型'}), 400
  1191. # 清空current表
  1192. session.query(table).delete()
  1193. # 重置自增主键计数器
  1194. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  1195. # 从指定数据集复制数据到current表
  1196. dataset_table_name = f"dataset_{dataset_id}"
  1197. copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
  1198. session.execute(copy_sql)
  1199. session.commit()
  1200. return jsonify({
  1201. 'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
  1202. 'dataset_id': dataset_id,
  1203. 'dataset_name': dataset.Dataset_name,
  1204. 'row_count': dataset.Row_count
  1205. }), 200
  1206. except Exception as e:
  1207. session.rollback()
  1208. logger.error(f'设置current数据集失败: {str(e)}')
  1209. return jsonify({'error': str(e)}), 500
  1210. finally:
  1211. session.close()
  1212. @bp.route('/get-model-history/<string:data_type>', methods=['GET'])
  1213. def get_model_history(data_type):
  1214. """
  1215. 获取模型训练历史数据的API接口
  1216. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1217. @return: JSON响应,包含时间序列的模型性能数据
  1218. """
  1219. Session = sessionmaker(bind=db.engine)
  1220. session = Session()
  1221. try:
  1222. # 查询所有自动生成的数据集,按时间排序
  1223. datasets = session.query(Datasets).filter(
  1224. Datasets.Dataset_type == data_type,
  1225. Datasets.Dataset_description == f"Automatically generated dataset for type {data_type}"
  1226. ).order_by(Datasets.Uploaded_at).all()
  1227. history_data = []
  1228. for dataset in datasets:
  1229. # 查找对应的自动训练模型
  1230. model = session.query(Models).filter(
  1231. Models.DatasetID == dataset.Dataset_ID,
  1232. Models.Model_name.like(f'auto_trained_{data_type}_%')
  1233. ).first()
  1234. if model and model.Performance_score is not None:
  1235. # 直接使用数据库中的时间,不进行格式化(保持与created_at相同的时区)
  1236. created_at = model.Created_at.isoformat() if model.Created_at else None
  1237. history_data.append({
  1238. 'dataset_id': dataset.Dataset_ID,
  1239. 'row_count': dataset.Row_count,
  1240. 'model_id': model.ModelID,
  1241. 'model_name': model.Model_name,
  1242. 'performance_score': float(model.Performance_score),
  1243. 'timestamp': created_at
  1244. })
  1245. # 按时间戳排序
  1246. history_data.sort(key=lambda x: x['timestamp'] if x['timestamp'] else '')
  1247. # 构建返回数据,分离各个指标序列便于前端绘图
  1248. response_data = {
  1249. 'data_type': data_type,
  1250. 'timestamps': [item['timestamp'] for item in history_data],
  1251. 'row_counts': [item['row_count'] for item in history_data],
  1252. 'performance_scores': [item['performance_score'] for item in history_data],
  1253. 'model_details': history_data # 保留完整数据供前端使用
  1254. }
  1255. return jsonify(response_data), 200
  1256. except Exception as e:
  1257. logger.error(f'获取模型历史数据失败: {str(e)}', exc_info=True)
  1258. return jsonify({'error': str(e)}), 500
  1259. finally:
  1260. session.close()
  1261. @bp.route('/batch-delete-datasets', methods=['POST'])
  1262. def batch_delete_datasets():
  1263. """
  1264. 批量删除数据集的API接口
  1265. @body_param dataset_ids: 要删除的数据集ID列表
  1266. @return: JSON响应
  1267. """
  1268. try:
  1269. data = request.get_json()
  1270. dataset_ids = data.get('dataset_ids', [])
  1271. if not dataset_ids:
  1272. return jsonify({'error': '未提供数据集ID列表'}), 400
  1273. results = {
  1274. 'success': [],
  1275. 'failed': [],
  1276. 'protected': [] # 被模型使用的数据集
  1277. }
  1278. for dataset_id in dataset_ids:
  1279. try:
  1280. # 调用单个删除接口
  1281. response = delete_dataset_endpoint(dataset_id)
  1282. # 解析响应
  1283. if response[1] == 200:
  1284. results['success'].append(dataset_id)
  1285. elif response[1] == 400 and 'models' in response[0].json:
  1286. # 数据集被模型保护
  1287. results['protected'].append({
  1288. 'id': dataset_id,
  1289. 'models': response[0].json['models']
  1290. })
  1291. else:
  1292. results['failed'].append({
  1293. 'id': dataset_id,
  1294. 'reason': response[0].json.get('error', '删除失败')
  1295. })
  1296. except Exception as e:
  1297. logger.error(f'删除数据集 {dataset_id} 失败: {str(e)}')
  1298. results['failed'].append({
  1299. 'id': dataset_id,
  1300. 'reason': str(e)
  1301. })
  1302. # 构建响应消息
  1303. message = f"成功删除 {len(results['success'])} 个数据集"
  1304. if results['protected']:
  1305. message += f", {len(results['protected'])} 个数据集被保护"
  1306. if results['failed']:
  1307. message += f", {len(results['failed'])} 个数据集删除失败"
  1308. return jsonify({
  1309. 'message': message,
  1310. 'results': results
  1311. }), 200
  1312. except Exception as e:
  1313. logger.error(f'批量删除数据集失败: {str(e)}')
  1314. return jsonify({'error': str(e)}), 500
  1315. @bp.route('/batch-delete-models', methods=['POST'])
  1316. def batch_delete_models():
  1317. """
  1318. 批量删除模型的API接口
  1319. @body_param model_ids: 要删除的模型ID列表
  1320. @query_param delete_datasets: 布尔值,是否同时删除关联的数据集,默认为False
  1321. @return: JSON响应
  1322. """
  1323. try:
  1324. data = request.get_json()
  1325. model_ids = data.get('model_ids', [])
  1326. delete_datasets = request.args.get('delete_datasets', 'false').lower() == 'true'
  1327. if not model_ids:
  1328. return jsonify({'error': '未提供模型ID列表'}), 400
  1329. results = {
  1330. 'success': [],
  1331. 'failed': [],
  1332. 'datasets_deleted': [] # 如果delete_datasets为true,记录被删除的数据集
  1333. }
  1334. for model_id in model_ids:
  1335. try:
  1336. # 调用单个删除接口
  1337. response = delete_model(model_id, delete_dataset=delete_datasets)
  1338. # 解析响应
  1339. if response[1] == 200:
  1340. results['success'].append(model_id)
  1341. # 如果删除了关联数据集,记录数据集ID
  1342. if 'dataset_info' in response[0].json:
  1343. results['datasets_deleted'].append(
  1344. response[0].json['dataset_info']['dataset_id']
  1345. )
  1346. else:
  1347. results['failed'].append({
  1348. 'id': model_id,
  1349. 'reason': response[0].json.get('error', '删除失败')
  1350. })
  1351. except Exception as e:
  1352. logger.error(f'删除模型 {model_id} 失败: {str(e)}')
  1353. results['failed'].append({
  1354. 'id': model_id,
  1355. 'reason': str(e)
  1356. })
  1357. # 构建响应消息
  1358. message = f"成功删除 {len(results['success'])} 个模型"
  1359. if results['datasets_deleted']:
  1360. message += f", {len(results['datasets_deleted'])} 个关联数据集"
  1361. if results['failed']:
  1362. message += f", {len(results['failed'])} 个模型删除失败"
  1363. return jsonify({
  1364. 'message': message,
  1365. 'results': results
  1366. }), 200
  1367. except Exception as e:
  1368. logger.error(f'批量删除模型失败: {str(e)}')
  1369. return jsonify({'error': str(e)}), 500
  1370. @bp.route('/kriging_interpolation', methods=['POST'])
  1371. def kriging_interpolation():
  1372. try:
  1373. data = request.get_json()
  1374. required = ['file_name', 'emission_column', 'points']
  1375. if not all(k in data for k in required):
  1376. return jsonify({"error": "Missing parameters"}), 400
  1377. # 添加坐标顺序验证
  1378. points = data['points']
  1379. if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points):
  1380. return jsonify({"error": "Invalid points format"}), 400
  1381. result = create_kriging(
  1382. data['file_name'],
  1383. data['emission_column'],
  1384. data['points']
  1385. )
  1386. return jsonify(result)
  1387. except Exception as e:
  1388. return jsonify({"error": str(e)}), 500
  1389. @bp.route('/model-scatter-data/<int:model_id>', methods=['GET'])
  1390. def get_model_scatter_data(model_id):
  1391. """
  1392. 获取指定模型的散点图数据(真实值vs预测值)
  1393. @param model_id: 模型ID
  1394. @return: JSON响应,包含散点图数据
  1395. """
  1396. Session = sessionmaker(bind=db.engine)
  1397. session = Session()
  1398. try:
  1399. # 查询模型信息
  1400. model = session.query(Models).filter_by(ModelID=model_id).first()
  1401. if not model:
  1402. return jsonify({'error': '未找到指定模型'}), 404
  1403. # 加载模型
  1404. with open(model.ModelFilePath, 'rb') as f:
  1405. ML_model = pickle.load(f)
  1406. # 根据数据类型加载测试数据
  1407. if model.Data_type == 'reflux':
  1408. X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
  1409. Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
  1410. elif model.Data_type == 'reduce':
  1411. X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
  1412. Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
  1413. else:
  1414. return jsonify({'error': '不支持的数据类型'}), 400
  1415. # 获取预测值
  1416. y_pred = ML_model.predict(X_test)
  1417. # 生成散点图数据
  1418. scatter_data = [
  1419. [float(true), float(pred)]
  1420. for true, pred in zip(Y_test.iloc[:, 0], y_pred)
  1421. ]
  1422. # 计算R²分数
  1423. r2 = r2_score(Y_test, y_pred)
  1424. # 获取数据范围,用于绘制对角线
  1425. y_min = min(min(Y_test.iloc[:, 0]), min(y_pred))
  1426. y_max = max(max(Y_test.iloc[:, 0]), max(y_pred))
  1427. return jsonify({
  1428. 'scatter_data': scatter_data,
  1429. 'r2_score': float(r2),
  1430. 'y_range': [float(y_min), float(y_max)],
  1431. 'model_name': model.Model_name,
  1432. 'model_type': model.Model_type
  1433. }), 200
  1434. except Exception as e:
  1435. logger.error(f'获取模型散点图数据失败: {str(e)}', exc_info=True)
  1436. return jsonify({'error': f'获取数据失败: {str(e)}'}), 500
  1437. finally:
  1438. session.close()