# routes.py
import logging
import os
import re
import sqlite3
from io import BytesIO

import pandas as pd
from flask import Blueprint, current_app, jsonify, request, send_file
from sqlalchemy import MetaData, Table, func, select, text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import sessionmaker
from werkzeug.security import check_password_hash, generate_password_hash
from werkzeug.utils import secure_filename

from . import db  # db instance exported by the app package
from .database_models import CurrentReduce, CurrentReflux, Datasets, ModelParameters, Models
from .model import calculate_model_score, predict, train_and_save_model
from .tasks import train_model_task
from .utils import (Q_to_t_ha, allowed_file, clean_column_names,
                    create_dynamic_table, create_kriging, infer_column_types,
                    insert_data_into_dynamic_table,
                    insert_data_into_existing_table, predict_to_Q,
                    rename_columns_for_model, rename_columns_for_model_predict)
# Logging configuration for this module
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Blueprint used to keep these routes separate from the app factory
bp = Blueprint('routes', __name__)
  25. # 密码加密
  26. def hash_password(password):
  27. return generate_password_hash(password)
  28. def get_db():
  29. """ 获取数据库连接 """
  30. return sqlite3.connect(bp.config['DATABASE'])
  31. # 添加一个新的辅助函数来检查数据集大小并触发训练
  32. def check_and_trigger_training(session, dataset_type, dataset_df):
  33. """
  34. 检查当前数据集大小是否跨越新的阈值点并触发训练
  35. Args:
  36. session: 数据库会话
  37. dataset_type: 数据集类型 ('reduce' 或 'reflux')
  38. dataset_df: 数据集 DataFrame
  39. Returns:
  40. tuple: (是否触发训练, 任务ID)
  41. """
  42. try:
  43. # 根据数据集类型选择表
  44. table = CurrentReduce if dataset_type == 'reduce' else CurrentReflux
  45. # 获取当前记录数
  46. current_count = session.query(func.count()).select_from(table).scalar()
  47. # 获取新增的记录数(从request.files中获取的DataFrame长度)
  48. new_records = len(dataset_df) # 需要从上层函数传入
  49. # 计算新增数据前的记录数
  50. previous_count = current_count - new_records
  51. # 设置阈值
  52. THRESHOLD = current_app.config['THRESHOLD']
  53. # 计算上一个阈值点(基于新增前的数据量)
  54. last_threshold = previous_count // THRESHOLD * THRESHOLD
  55. # 计算当前所在阈值点
  56. current_threshold = current_count // THRESHOLD * THRESHOLD
  57. # 检查是否跨越了新的阈值点
  58. if current_threshold > last_threshold and current_count >= THRESHOLD:
  59. # 触发异步训练任务
  60. task = train_model_task.delay(
  61. model_type=current_app.config['DEFAULT_MODEL_TYPE'],
  62. model_name=f'auto_trained_{dataset_type}_{current_threshold}',
  63. model_description=f'Auto trained model at {current_threshold} records threshold',
  64. data_type=dataset_type
  65. )
  66. return True, task.id
  67. return False, None
  68. except Exception as e:
  69. logging.error(f"检查并触发训练失败: {str(e)}")
  70. return False, None
@bp.route('/upload-dataset', methods=['POST'])
def upload_dataset():
    """Upload an Excel dataset: save the file, mirror the rows into a
    per-dataset dynamic table and the shared 'current' table, then check
    whether auto-training should be triggered.
    """
    # Dedicated session bound to the application's engine
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '' or not allowed_file(file.filename):
            return jsonify({'error': 'No selected file or invalid file type'}), 400
        dataset_name = request.form.get('dataset_name')
        dataset_description = request.form.get('dataset_description', 'No description provided')
        dataset_type = request.form.get('dataset_type')
        if not dataset_type:
            return jsonify({'error': 'Dataset type is required'}), 400
        # Insert the dataset record first so its generated ID can name the
        # saved file and the dynamic table.
        new_dataset = Datasets(
            Dataset_name=dataset_name,
            Dataset_description=dataset_description,
            Row_count=0,
            Status='Datasets_upgraded',
            Dataset_type=dataset_type
        )
        session.add(new_dataset)
        session.commit()
        unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
        upload_folder = current_app.config['UPLOAD_FOLDER']
        file_path = os.path.join(upload_folder, unique_filename)
        file.save(file_path)
        dataset_df = pd.read_excel(file_path)
        new_dataset.Row_count = len(dataset_df)
        new_dataset.Status = 'excel_file_saved success'
        session.commit()
        # Normalize column names so they match the model's expected fields
        dataset_df = clean_column_names(dataset_df)
        dataset_df = rename_columns_for_model(dataset_df, dataset_type)
        column_types = infer_column_types(dataset_df)
        dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
        insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
        # Mirror the rows into the shared table for this dataset type
        if dataset_type == 'reduce':
            insert_data_into_existing_table(session, dataset_df, CurrentReduce)
        elif dataset_type == 'reflux':
            insert_data_into_existing_table(session, dataset_df, CurrentReflux)
        session.commit()
        # After all inserts succeed, see whether auto-training should start
        training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
        response_data = {
            'message': f'Dataset {dataset_name} uploaded successfully!',
            'dataset_id': new_dataset.Dataset_ID,
            'filename': unique_filename,
            'training_triggered': training_triggered
        }
        if training_triggered:
            response_data['task_id'] = task_id
            response_data['message'] += ' Auto-training has been triggered.'
        return jsonify(response_data), 201
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the dataset upload:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        # Always release the session
        if session:
            session.close()
@bp.route('/train-and-save-model', methods=['POST'])
def train_and_save_model_endpoint():
    """Synchronously train a model from the request payload, persist it,
    compute its performance score, and return the id and score."""
    Session = sessionmaker(bind=db.engine)
    session = Session()
    data = request.get_json()
    # Parse training parameters from the request body
    model_type = data.get('model_type')
    model_name = data.get('model_name')
    model_description = data.get('model_description')
    data_type = data.get('data_type')
    dataset_id = data.get('dataset_id', None)  # optional; None when omitted
    try:
        # Train and persist; result is expected to carry the new model id at index 1
        result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
        model_id = result[1] if result else None
        # Compute and store the model's performance score
        if model_id:
            model_info = session.query(Models).filter(Models.ModelID == model_id).first()
            if model_info:
                score = calculate_model_score(model_info)
                model_info.Performance_score = score
                session.commit()
                result = {'model_id': model_id, 'model_score': score}
        return jsonify({
            'message': 'Model trained and saved successfully',
            'result': result
        }), 200
    except Exception as e:
        session.rollback()
        logging.error('Failed to process the model training:', exc_info=True)
        return jsonify({
            'error': 'Failed to train and save model',
            'message': str(e)
        }), 500
    finally:
        session.close()
  175. @bp.route('/predict', methods=['POST'])
  176. def predict_route():
  177. # 创建 sessionmaker 实例
  178. Session = sessionmaker(bind=db.engine)
  179. session = Session()
  180. try:
  181. data = request.get_json()
  182. model_id = data.get('model_id') # 提取模型名称
  183. parameters = data.get('parameters', {}) # 提取所有变量
  184. # 根据model_id获取模型Data_type
  185. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  186. if not model_info:
  187. return jsonify({'error': 'Model not found'}), 404
  188. data_type = model_info.Data_type
  189. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  190. # 如果为reduce,则不需要传入target_ph
  191. if data_type == 'reduce':
  192. # 获取传入的init_ph、target_ph参数
  193. init_ph = float(parameters.get('init_pH', 0.0)) # 默认值为0.0,防止None导致错误
  194. target_ph = float(parameters.get('target_pH', 0.0)) # 默认值为0.0,防止None导致错误
  195. # 从输入数据中删除'target_pH'列
  196. input_data = input_data.drop('target_pH', axis=1, errors='ignore') # 使用errors='ignore'防止列不存在时出错
  197. input_data_rename = rename_columns_for_model_predict(input_data, data_type) # 重命名列名以匹配模型字段
  198. predictions = predict(session, input_data_rename, model_id) # 调用预测函数
  199. if data_type == 'reduce':
  200. predictions = predictions[0]
  201. # 将预测结果转换为Q
  202. Q = predict_to_Q(predictions, init_ph, target_ph)
  203. predictions = Q_to_t_ha(Q) # 将Q转换为t/ha
  204. print(predictions)
  205. return jsonify({'result': predictions}), 200
  206. except Exception as e:
  207. logging.error('Failed to predict:', exc_info=True)
  208. return jsonify({'error': str(e)}), 400
  209. # 为指定模型计算评分Performance_score,需要提供model_id
  210. @bp.route('/score-model/<int:model_id>', methods=['POST'])
  211. def score_model(model_id):
  212. # 创建 sessionmaker 实例
  213. Session = sessionmaker(bind=db.engine)
  214. session = Session()
  215. try:
  216. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  217. if not model_info:
  218. return jsonify({'error': 'Model not found'}), 404
  219. # 计算模型评分
  220. score = calculate_model_score(model_info)
  221. # 更新模型记录中的评分
  222. model_info.Performance_score = score
  223. session.commit()
  224. return jsonify({'message': 'Model scored successfully', 'score': score}), 200
  225. except Exception as e:
  226. logging.error('Failed to process the dataset upload:', exc_info=True)
  227. return jsonify({'error': str(e)}), 400
  228. finally:
  229. session.close()
@bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
def delete_dataset_endpoint(dataset_id):
    """
    Delete a dataset: its record, its backing Excel file, and its dynamic table.

    @param dataset_id: ID of the dataset to delete
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if not dataset:
            return jsonify({'error': '未找到数据集'}), 404
        # Refuse deletion while any model still references this dataset
        models_using_dataset = session.query(Models).filter_by(DatasetID=dataset_id).all()
        if models_using_dataset:
            models_info = [{'ModelID': model.ModelID, 'Model_name': model.Model_name} for model in models_using_dataset]
            return jsonify({
                'error': '无法删除数据集,因为以下模型正在使用它',
                'models': models_info
            }), 400
        # Remove the uploaded Excel file from disk
        filename = f"dataset_{dataset.Dataset_ID}.xlsx"
        file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f'删除文件失败: {str(e)}')
                return jsonify({'error': f'删除文件失败: {str(e)}'}), 500
        # Drop the per-dataset dynamic table; the name derives from an
        # integer primary key, so the f-string interpolation is not injectable.
        table_name = f"dataset_{dataset.Dataset_ID}"
        session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
        # Finally remove the dataset record itself
        session.delete(dataset)
        session.commit()
        return jsonify({
            'message': '数据集删除成功',
            'deleted_files': [filename]
        }), 200
    except Exception as e:
        session.rollback()
        logger.error(f'删除数据集 {dataset_id} 失败:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
  278. @bp.route('/tables', methods=['GET'])
  279. def list_tables():
  280. engine = db.engine # 使用 db 实例的 engine
  281. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  282. table_names = inspector.get_table_names() # 获取所有表名
  283. return jsonify(table_names) # 以 JSON 形式返回表名列表
  284. @bp.route('/models/<int:model_id>', methods=['GET'])
  285. def get_model(model_id):
  286. try:
  287. model = Models.query.filter_by(ModelID=model_id).first()
  288. if model:
  289. return jsonify({
  290. 'ModelID': model.ModelID,
  291. 'ModelName': model.ModelName,
  292. 'ModelType': model.ModelType,
  293. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  294. 'Description': model.Description
  295. })
  296. else:
  297. return jsonify({'message': 'Model not found'}), 404
  298. except Exception as e:
  299. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  300. @bp.route('/models', methods=['GET'])
  301. def get_all_models():
  302. try:
  303. models = Models.query.all() # 获取所有模型数据
  304. if models:
  305. result = [
  306. {
  307. 'ModelID': model.ModelID,
  308. 'ModelName': model.ModelName,
  309. 'ModelType': model.ModelType,
  310. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  311. 'Description': model.Description
  312. }
  313. for model in models
  314. ]
  315. return jsonify(result)
  316. else:
  317. return jsonify({'message': 'No models found'}), 404
  318. except Exception as e:
  319. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  320. @bp.route('/model-parameters', methods=['GET'])
  321. def get_all_model_parameters():
  322. try:
  323. parameters = ModelParameters.query.all() # 获取所有参数数据
  324. if parameters:
  325. result = [
  326. {
  327. 'ParamID': param.ParamID,
  328. 'ModelID': param.ModelID,
  329. 'ParamName': param.ParamName,
  330. 'ParamValue': param.ParamValue
  331. }
  332. for param in parameters
  333. ]
  334. return jsonify(result)
  335. else:
  336. return jsonify({'message': 'No parameters found'}), 404
  337. except Exception as e:
  338. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  339. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  340. def get_model_parameters(model_id):
  341. try:
  342. model = Models.query.filter_by(ModelID=model_id).first()
  343. if model:
  344. # 获取该模型的所有参数
  345. parameters = [
  346. {
  347. 'ParamID': param.ParamID,
  348. 'ParamName': param.ParamName,
  349. 'ParamValue': param.ParamValue
  350. }
  351. for param in model.parameters
  352. ]
  353. # 返回模型参数信息
  354. return jsonify({
  355. 'ModelID': model.ModelID,
  356. 'ModelName': model.ModelName,
  357. 'ModelType': model.ModelType,
  358. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  359. 'Description': model.Description,
  360. 'Parameters': parameters
  361. })
  362. else:
  363. return jsonify({'message': 'Model not found'}), 404
  364. except Exception as e:
  365. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  366. # 定义添加数据库记录的 API 接口
  367. @bp.route('/add_item', methods=['POST'])
  368. def add_item():
  369. """
  370. 接收 JSON 格式的请求体,包含表名和要插入的数据。
  371. 尝试将数据插入到指定的表中,并进行字段查重。
  372. :return:
  373. """
  374. try:
  375. # 确保请求体是 JSON 格式
  376. data = request.get_json()
  377. if not data:
  378. raise ValueError("No JSON data provided")
  379. table_name = data.get('table')
  380. item_data = data.get('item')
  381. if not table_name or not item_data:
  382. return jsonify({'error': 'Missing table name or item data'}), 400
  383. # 定义各个表的字段查重规则
  384. duplicate_check_rules = {
  385. 'users': ['email', 'username'],
  386. 'products': ['product_code'],
  387. 'current_reduce': [ 'Q_over_b', 'pH', 'OM', 'CL', 'H', 'Al'],
  388. 'current_reflux': ['OM', 'CL', 'CEC', 'H_plus', 'N', 'Al3_plus', 'Delta_pH'],
  389. # 其他表和规则
  390. }
  391. # 获取该表的查重字段
  392. duplicate_columns = duplicate_check_rules.get(table_name)
  393. if not duplicate_columns:
  394. return jsonify({'error': 'No duplicate check rule for this table'}), 400
  395. # 动态构建查询条件,逐一检查是否有重复数据
  396. condition = ' AND '.join([f"{column} = :{column}" for column in duplicate_columns])
  397. duplicate_query = f"SELECT 1 FROM {table_name} WHERE {condition} LIMIT 1"
  398. result = db.session.execute(text(duplicate_query), item_data).fetchone()
  399. if result:
  400. return jsonify({'error': '重复数据,已有相同的数据项存在。'}), 409
  401. # 动态构建 SQL 语句,进行插入操作
  402. columns = ', '.join(item_data.keys())
  403. placeholders = ', '.join([f":{key}" for key in item_data.keys()])
  404. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  405. # 直接执行插入操作,无需显式的事务管理
  406. db.session.execute(text(sql), item_data)
  407. # 提交事务
  408. db.session.commit()
  409. # 返回成功响应
  410. return jsonify({'success': True, 'message': 'Item added successfully'}), 201
  411. except ValueError as e:
  412. return jsonify({'error': str(e)}), 400
  413. except KeyError as e:
  414. return jsonify({'error': f'Missing data field: {e}'}), 400
  415. except sqlite3.IntegrityError as e:
  416. return jsonify({'error': '数据库完整性错误', 'details': str(e)}), 409
  417. except sqlite3.Error as e:
  418. return jsonify({'error': '数据库错误', 'details': str(e)}), 500
  419. @bp.route('/delete_item', methods=['POST'])
  420. def delete_item():
  421. """
  422. 删除数据库记录的 API 接口
  423. """
  424. data = request.get_json()
  425. table_name = data.get('table')
  426. condition = data.get('condition')
  427. # 检查表名和条件是否提供
  428. if not table_name or not condition:
  429. return jsonify({
  430. "success": False,
  431. "message": "缺少表名或条件参数"
  432. }), 400
  433. # 尝试从条件字符串中解析键和值
  434. try:
  435. key, value = condition.split('=')
  436. key = key.strip() # 去除多余的空格
  437. value = value.strip().strip("'\"") # 去除多余的空格和引号
  438. except ValueError:
  439. return jsonify({
  440. "success": False,
  441. "message": "条件格式错误,应为 'key=value'"
  442. }), 400
  443. # 准备 SQL 删除语句
  444. sql = f"DELETE FROM {table_name} WHERE {key} = :value"
  445. try:
  446. # 使用 SQLAlchemy 执行删除
  447. with db.session.begin():
  448. result = db.session.execute(text(sql), {"value": value})
  449. # 检查是否有记录被删除
  450. if result.rowcount == 0:
  451. return jsonify({
  452. "success": False,
  453. "message": "未找到符合条件的记录"
  454. }), 404
  455. return jsonify({
  456. "success": True,
  457. "message": "记录删除成功"
  458. }), 200
  459. except Exception as e:
  460. return jsonify({
  461. "success": False,
  462. "message": f"删除失败: {e}"
  463. }), 500
  464. # 定义修改数据库记录的 API 接口
  465. @bp.route('/update_item', methods=['PUT'])
  466. def update_record():
  467. """
  468. 接收 JSON 格式的请求体,包含表名和更新的数据。
  469. 尝试更新指定的记录。
  470. """
  471. data = request.get_json()
  472. # 检查必要的数据是否提供
  473. if not data or 'table' not in data or 'item' not in data:
  474. return jsonify({
  475. "success": False,
  476. "message": "请求数据不完整"
  477. }), 400
  478. table_name = data['table']
  479. item = data['item']
  480. # 假设 item 的第一个键是 ID
  481. id_key = next(iter(item.keys())) # 获取第一个键
  482. record_id = item.get(id_key)
  483. if not record_id:
  484. return jsonify({
  485. "success": False,
  486. "message": "缺少记录 ID"
  487. }), 400
  488. # 获取更新的字段和值
  489. updates = {key: value for key, value in item.items() if key != id_key}
  490. if not updates:
  491. return jsonify({
  492. "success": False,
  493. "message": "没有提供需要更新的字段"
  494. }), 400
  495. # 动态构建 SQL
  496. set_clause = ', '.join([f"{key} = :{key}" for key in updates.keys()])
  497. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = :id_value"
  498. # 添加 ID 到参数
  499. updates['id_value'] = record_id
  500. try:
  501. # 使用 SQLAlchemy 执行更新
  502. with db.session.begin():
  503. result = db.session.execute(text(sql), updates)
  504. # 检查是否有更新的记录
  505. if result.rowcount == 0:
  506. return jsonify({
  507. "success": False,
  508. "message": "未找到要更新的记录"
  509. }), 404
  510. return jsonify({
  511. "success": True,
  512. "message": "数据更新成功"
  513. }), 200
  514. except Exception as e:
  515. # 捕获所有异常并返回
  516. return jsonify({
  517. "success": False,
  518. "message": f"更新失败: {str(e)}"
  519. }), 500
  520. # 定义查询数据库记录的 API 接口
  521. @bp.route('/search/record', methods=['GET'])
  522. def sql_search():
  523. """
  524. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  525. 尝试查询指定 ID 的记录并返回结果。
  526. :return:
  527. """
  528. try:
  529. data = request.get_json()
  530. # 表名
  531. sql_table = data['table']
  532. # 要搜索的 ID
  533. Id = data['id']
  534. # 连接到数据库
  535. cur = db.cursor()
  536. # 构造查询语句
  537. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  538. # 执行查询
  539. cur.execute(sql, (Id,))
  540. # 获取查询结果
  541. rows = cur.fetchall()
  542. column_names = [desc[0] for desc in cur.description]
  543. # 检查是否有结果
  544. if not rows:
  545. return jsonify({'error': '未查找到对应数据。'}), 400
  546. # 构造响应数据
  547. results = []
  548. for row in rows:
  549. result = {column_names[i]: row[i] for i in range(len(row))}
  550. results.append(result)
  551. # 关闭游标和数据库连接
  552. cur.close()
  553. db.close()
  554. # 返回 JSON 响应
  555. return jsonify(results), 200
  556. except sqlite3.Error as e:
  557. # 如果发生数据库错误,返回错误信息
  558. return jsonify({'error': str(e)}), 400
  559. except KeyError as e:
  560. # 如果请求数据中缺少必要的键,返回错误信息
  561. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  562. # 定义提供数据库列表,用于展示表格的 API 接口
  563. @bp.route('/table', methods=['POST'])
  564. def get_table():
  565. data = request.get_json()
  566. table_name = data.get('table')
  567. if not table_name:
  568. return jsonify({'error': '需要表名'}), 400
  569. try:
  570. # 创建 sessionmaker 实例
  571. Session = sessionmaker(bind=db.engine)
  572. session = Session()
  573. # 动态获取表的元数据
  574. metadata = MetaData()
  575. table = Table(table_name, metadata, autoload_with=db.engine)
  576. # 从数据库中查询所有记录
  577. query = select(table)
  578. result = session.execute(query).fetchall()
  579. # 将结果转换为列表字典形式
  580. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  581. # 获取列名
  582. headers = [column.name for column in table.columns]
  583. return jsonify(rows=rows, headers=headers), 200
  584. except Exception as e:
  585. return jsonify({'error': str(e)}), 400
  586. finally:
  587. # 关闭 session
  588. session.close()
  589. @bp.route('/train-model-async', methods=['POST'])
  590. def train_model_async():
  591. """
  592. 异步训练模型的API接口
  593. """
  594. try:
  595. data = request.get_json()
  596. # 从请求中获取参数
  597. model_type = data.get('model_type')
  598. model_name = data.get('model_name')
  599. model_description = data.get('model_description')
  600. data_type = data.get('data_type')
  601. dataset_id = data.get('dataset_id', None)
  602. # 验证必要参数
  603. if not all([model_type, model_name, data_type]):
  604. return jsonify({
  605. 'error': 'Missing required parameters'
  606. }), 400
  607. # 如果提供了dataset_id,验证数据集是否存在
  608. if dataset_id:
  609. Session = sessionmaker(bind=db.engine)
  610. session = Session()
  611. try:
  612. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  613. if not dataset:
  614. return jsonify({
  615. 'error': f'Dataset with ID {dataset_id} not found'
  616. }), 404
  617. finally:
  618. session.close()
  619. # 启动异步任务
  620. task = train_model_task.delay(
  621. model_type=model_type,
  622. model_name=model_name,
  623. model_description=model_description,
  624. data_type=data_type,
  625. dataset_id=dataset_id
  626. )
  627. # 返回任务ID
  628. return jsonify({
  629. 'task_id': task.id,
  630. 'message': 'Model training started'
  631. }), 202
  632. except Exception as e:
  633. logging.error('Failed to start async training task:', exc_info=True)
  634. return jsonify({
  635. 'error': str(e)
  636. }), 500
  637. @bp.route('/task-status/<task_id>', methods=['GET'])
  638. def get_task_status(task_id):
  639. """
  640. 获取异步任务状态的API接口
  641. """
  642. try:
  643. task = train_model_task.AsyncResult(task_id)
  644. if task.state == 'PENDING':
  645. response = {
  646. 'state': task.state,
  647. 'status': 'Task is waiting for execution'
  648. }
  649. elif task.state == 'FAILURE':
  650. response = {
  651. 'state': task.state,
  652. 'status': 'Task failed',
  653. 'error': task.info.get('error') if isinstance(task.info, dict) else str(task.info)
  654. }
  655. elif task.state == 'SUCCESS':
  656. response = {
  657. 'state': task.state,
  658. 'status': 'Task completed successfully',
  659. 'result': task.get()
  660. }
  661. else:
  662. response = {
  663. 'state': task.state,
  664. 'status': 'Task is in progress'
  665. }
  666. return jsonify(response), 200
  667. except Exception as e:
  668. return jsonify({
  669. 'error': str(e)
  670. }), 500
@bp.route('/delete-model/<int:model_id>', methods=['DELETE'])
def delete_model(model_id):
    """
    Delete the given model, its saved file and (optionally) its dataset.

    @param model_id: ID of the model to delete
    @query_param delete_dataset: bool; also delete the linked dataset (default False)
    @return: JSON response
    """
    Session = sessionmaker(bind=db.engine)
    session = Session()
    try:
        model = session.query(Models).filter_by(ModelID=model_id).first()
        if not model:
            return jsonify({'error': '未找到指定模型'}), 404
        dataset_id = model.DatasetID
        # 1. Delete the database record first
        session.delete(model)
        session.commit()
        # 2. Remove the pickled model file from disk
        model_file = f"rf_model_{model_id}.pkl"
        model_path = os.path.join(current_app.config['MODEL_SAVE_PATH'], model_file)
        if os.path.exists(model_path):
            try:
                os.remove(model_path)
            except OSError as e:
                # NOTE(review): the record deletion was already committed above,
                # so this rollback cannot undo it — confirm intended semantics.
                session.rollback()
                logger.error(f'删除模型文件失败: {str(e)}')
                return jsonify({'error': f'删除模型文件失败: {str(e)}'}), 500
        # 3. Optionally delete the linked dataset as well
        delete_dataset = request.args.get('delete_dataset', 'false').lower() == 'true'
        if delete_dataset and dataset_id:
            try:
                # Reuses the dataset-deletion view function directly
                dataset_response = delete_dataset_endpoint(dataset_id)
                if not isinstance(dataset_response, tuple) or dataset_response[1] != 200:
                    # NOTE(review): as above, the commit already happened, so
                    # this rollback is a no-op for the model record.
                    session.rollback()
                    return jsonify({
                        'error': '删除关联数据集失败',
                        'dataset_error': dataset_response[0].get_json() if hasattr(dataset_response[0], 'get_json') else str(dataset_response[0])
                    }), 500
            except Exception as e:
                session.rollback()
                logger.error(f'删除关联数据集失败: {str(e)}')
                return jsonify({'error': f'删除关联数据集失败: {str(e)}'}), 500
        response_data = {
            'message': '模型删除成功',
            'deleted_files': [model_file]
        }
        if delete_dataset:
            response_data['dataset_info'] = {
                'dataset_id': dataset_id,
                'message': '关联数据集已删除'
            }
        return jsonify(response_data), 200
    except Exception as e:
        session.rollback()
        logger.error(f'删除模型 {model_id} 失败:', exc_info=True)
        return jsonify({'error': str(e)}), 500
    finally:
        session.close()
  733. # 添加一个新的API端点来清空指定数据集
  734. @bp.route('/clear-dataset/<string:data_type>', methods=['DELETE'])
  735. def clear_dataset(data_type):
  736. """
  737. 清空指定类型的数据集并递增计数
  738. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  739. @return: JSON响应
  740. """
  741. # 创建 sessionmaker 实例
  742. Session = sessionmaker(bind=db.engine)
  743. session = Session()
  744. try:
  745. # 根据数据集类型选择表
  746. if data_type == 'reduce':
  747. table = CurrentReduce
  748. table_name = 'current_reduce'
  749. elif data_type == 'reflux':
  750. table = CurrentReflux
  751. table_name = 'current_reflux'
  752. else:
  753. return jsonify({'error': '无效的数据集类型'}), 400
  754. # 清空表内容
  755. session.query(table).delete()
  756. # 重置自增主键计数器
  757. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  758. session.commit()
  759. return jsonify({'message': f'{data_type} 数据集已清空并重置计数器'}), 200
  760. except Exception as e:
  761. session.rollback()
  762. return jsonify({'error': str(e)}), 500
  763. finally:
  764. session.close()
  765. @bp.route('/login', methods=['POST'])
  766. def login_user():
  767. # 获取前端传来的数据
  768. data = request.get_json()
  769. name = data.get('name') # 用户名
  770. password = data.get('password') # 密码
  771. logger.info(f"Login request received: name={name}")
  772. # 检查用户名和密码是否为空
  773. if not name or not password:
  774. logger.warning("用户名和密码不能为空")
  775. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  776. try:
  777. # 查询数据库验证用户名
  778. query = "SELECT * FROM users WHERE name = :name"
  779. conn = get_db()
  780. user = conn.execute(query, {"name": name}).fetchone()
  781. if not user:
  782. logger.warning(f"用户名 '{name}' 不存在")
  783. return jsonify({"success": False, "message": "用户名不存在"}), 400
  784. # 获取数据库中存储的密码(假设密码是哈希存储的)
  785. stored_password = user[2] # 假设密码存储在数据库的第三列
  786. user_id = user[0] # 假设 id 存储在数据库的第一列
  787. # 校验密码是否正确
  788. if check_password_hash(stored_password, password):
  789. logger.info(f"User '{name}' logged in successfully.")
  790. return jsonify({
  791. "success": True,
  792. "message": "登录成功",
  793. "userId": user_id # 返回用户 ID
  794. })
  795. else:
  796. logger.warning(f"Invalid password for user '{name}'")
  797. return jsonify({"success": False, "message": "用户名或密码错误"}), 400
  798. except Exception as e:
  799. # 记录错误日志并返回错误信息
  800. logger.error(f"Error during login: {e}", exc_info=True)
  801. return jsonify({"success": False, "message": "登录失败"}), 500
  802. # 更新用户信息接口
  803. @bp.route('/update_user', methods=['POST'])
  804. def update_user():
  805. # 获取前端传来的数据
  806. data = request.get_json()
  807. # 打印收到的请求数据
  808. bp.logger.info(f"Received data: {data}")
  809. user_id = data.get('userId') # 用户ID
  810. name = data.get('name') # 用户名
  811. old_password = data.get('oldPassword') # 旧密码
  812. new_password = data.get('newPassword') # 新密码
  813. logger.info(f"Update request received: user_id={user_id}, name={name}")
  814. # 校验传入的用户名和密码是否为空
  815. if not name or not old_password:
  816. logger.warning("用户名和旧密码不能为空")
  817. return jsonify({"success": False, "message": "用户名和旧密码不能为空"}), 400
  818. # 新密码和旧密码不能相同
  819. if new_password and old_password == new_password:
  820. logger.warning(f"新密码与旧密码相同:{name}")
  821. return jsonify({"success": False, "message": "新密码与旧密码不能相同"}), 400
  822. try:
  823. # 查询数据库验证用户ID
  824. query = "SELECT * FROM users WHERE id = :user_id"
  825. conn = get_db()
  826. user = conn.execute(query, {"user_id": user_id}).fetchone()
  827. if not user:
  828. logger.warning(f"用户ID '{user_id}' 不存在")
  829. return jsonify({"success": False, "message": "用户不存在"}), 400
  830. # 获取数据库中存储的密码(假设密码是哈希存储的)
  831. stored_password = user[2] # 假设密码存储在数据库的第三列
  832. # 校验旧密码是否正确
  833. if not check_password_hash(stored_password, old_password):
  834. logger.warning(f"旧密码错误:{name}")
  835. return jsonify({"success": False, "message": "旧密码错误"}), 400
  836. # 如果新密码非空,则更新新密码
  837. if new_password:
  838. hashed_new_password = hash_password(new_password)
  839. update_query = "UPDATE users SET password = :new_password WHERE id = :user_id"
  840. conn.execute(update_query, {"new_password": hashed_new_password, "user_id": user_id})
  841. conn.commit()
  842. logger.info(f"User ID '{user_id}' password updated successfully.")
  843. # 如果用户名发生更改,则更新用户名
  844. if name != user[1]:
  845. update_name_query = "UPDATE users SET name = :new_name WHERE id = :user_id"
  846. conn.execute(update_name_query, {"new_name": name, "user_id": user_id})
  847. conn.commit()
  848. logger.info(f"User ID '{user_id}' name updated to '{name}' successfully.")
  849. return jsonify({"success": True, "message": "用户信息更新成功"})
  850. except Exception as e:
  851. # 记录错误日志并返回错误信息
  852. logger.error(f"Error updating user: {e}", exc_info=True)
  853. return jsonify({"success": False, "message": "更新失败"}), 500
  854. # 注册用户
  855. @bp.route('/register', methods=['POST'])
  856. def register_user():
  857. # 获取前端传来的数据
  858. data = request.get_json()
  859. name = data.get('name') # 用户名
  860. password = data.get('password') # 密码
  861. logger.info(f"Register request received: name={name}")
  862. # 检查用户名和密码是否为空
  863. if not name or not password:
  864. logger.warning("用户名和密码不能为空")
  865. return jsonify({"success": False, "message": "用户名和密码不能为空"}), 400
  866. # 动态获取数据库表的列名
  867. columns = get_column_names('users')
  868. logger.info(f"Database columns for 'users' table: {columns}")
  869. # 检查前端传来的数据是否包含数据库表中所有的必填字段
  870. for column in ['name', 'password']:
  871. if column not in columns:
  872. logger.error(f"缺少必填字段:{column}")
  873. return jsonify({"success": False, "message": f"缺少必填字段:{column}"}), 400
  874. # 对密码进行哈希处理
  875. hashed_password = hash_password(password)
  876. logger.info(f"Password hashed for user: {name}")
  877. # 插入到数据库
  878. try:
  879. # 检查用户是否已经存在
  880. query = "SELECT * FROM users WHERE name = :name"
  881. conn = get_db()
  882. user = conn.execute(query, {"name": name}).fetchone()
  883. if user:
  884. logger.warning(f"用户名 '{name}' 已存在")
  885. return jsonify({"success": False, "message": "用户名已存在"}), 400
  886. # 向数据库插入数据
  887. query = "INSERT INTO users (name, password) VALUES (:name, :password)"
  888. conn.execute(query, {"name": name, "password": hashed_password})
  889. conn.commit()
  890. logger.info(f"User '{name}' registered successfully.")
  891. return jsonify({"success": True, "message": "注册成功"})
  892. except Exception as e:
  893. # 记录错误日志并返回错误信息
  894. logger.error(f"Error registering user: {e}", exc_info=True)
  895. return jsonify({"success": False, "message": "注册失败"}), 500
  896. def get_column_names(table_name):
  897. """
  898. 动态获取数据库表的列名。
  899. """
  900. try:
  901. conn = get_db()
  902. query = f"PRAGMA table_info({table_name});"
  903. result = conn.execute(query).fetchall()
  904. conn.close()
  905. return [row[1] for row in result] # 第二列是列名
  906. except Exception as e:
  907. logger.error(f"Error getting column names for table {table_name}: {e}", exc_info=True)
  908. return []
  909. # 导出数据
  910. @bp.route('/export_data', methods=['GET'])
  911. def export_data():
  912. table_name = request.args.get('table')
  913. file_format = request.args.get('format', 'excel').lower()
  914. if not table_name:
  915. return jsonify({'error': '缺少表名参数'}), 400
  916. if not table_name.isidentifier():
  917. return jsonify({'error': '无效的表名'}), 400
  918. try:
  919. conn = get_db()
  920. query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"
  921. table_exists = conn.execute(query, (table_name,)).fetchone()
  922. if not table_exists:
  923. return jsonify({'error': f"表 {table_name} 不存在"}), 404
  924. query = f"SELECT * FROM {table_name};"
  925. df = pd.read_sql(query, conn)
  926. output = BytesIO()
  927. if file_format == 'csv':
  928. df.to_csv(output, index=False, encoding='utf-8')
  929. output.seek(0)
  930. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.csv', mimetype='text/csv')
  931. elif file_format == 'excel':
  932. df.to_excel(output, index=False, engine='openpyxl')
  933. output.seek(0)
  934. return send_file(output, as_attachment=True, download_name=f'{table_name}_data.xlsx',
  935. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  936. else:
  937. return jsonify({'error': '不支持的文件格式,仅支持 CSV 和 Excel'}), 400
  938. except Exception as e:
  939. logger.error(f"Error in export_data: {e}", exc_info=True)
  940. return jsonify({'error': str(e)}), 500
  941. # 导入数据接口
  942. @bp.route('/import_data', methods=['POST'])
  943. def import_data():
  944. logger.debug("Import data endpoint accessed.")
  945. if 'file' not in request.files:
  946. logger.error("No file in request.")
  947. return jsonify({'success': False, 'message': '文件缺失'}), 400
  948. file = request.files['file']
  949. table_name = request.form.get('table')
  950. if not table_name:
  951. logger.error("Missing table name parameter.")
  952. return jsonify({'success': False, 'message': '缺少表名参数'}), 400
  953. if file.filename == '':
  954. logger.error("No file selected.")
  955. return jsonify({'success': False, 'message': '未选择文件'}), 400
  956. try:
  957. # 保存文件到临时路径
  958. temp_path = os.path.join(bp.config['UPLOAD_FOLDER'], secure_filename(file.filename))
  959. file.save(temp_path)
  960. logger.debug(f"File saved to temporary path: {temp_path}")
  961. # 根据文件类型读取文件
  962. if file.filename.endswith('.xlsx'):
  963. df = pd.read_excel(temp_path)
  964. elif file.filename.endswith('.csv'):
  965. df = pd.read_csv(temp_path)
  966. else:
  967. logger.error("Unsupported file format.")
  968. return jsonify({'success': False, 'message': '仅支持 Excel 和 CSV 文件'}), 400
  969. # 获取数据库列名
  970. db_columns = get_column_names(table_name)
  971. if 'id' in db_columns:
  972. db_columns.remove('id') # 假设 id 列是自增的,不需要处理
  973. if not set(db_columns).issubset(set(df.columns)):
  974. logger.error(f"File columns do not match database columns. File columns: {df.columns.tolist()}, Expected: {db_columns}")
  975. return jsonify({'success': False, 'message': '文件列名与数据库表不匹配'}), 400
  976. # 清洗数据并删除空值行
  977. df_cleaned = df[db_columns].dropna()
  978. # 统一数据类型,避免 int 和 float 合并问题
  979. df_cleaned[db_columns] = df_cleaned[db_columns].apply(pd.to_numeric, errors='coerce')
  980. # 获取现有的数据
  981. conn = get_db()
  982. with conn:
  983. existing_data = pd.read_sql(f"SELECT * FROM {table_name}", conn)
  984. # 查找重复数据
  985. duplicates = df_cleaned.merge(existing_data, on=db_columns, how='inner')
  986. # 如果有重复数据,删除它们
  987. df_cleaned = df_cleaned[~df_cleaned.index.isin(duplicates.index)]
  988. logger.warning(f"Duplicate data detected and removed: {duplicates}")
  989. # 获取导入前后的数据量
  990. total_data = len(df_cleaned) + len(duplicates)
  991. new_data = len(df_cleaned)
  992. duplicate_data = len(duplicates)
  993. # 导入不重复的数据
  994. df_cleaned.to_sql(table_name, conn, if_exists='append', index=False)
  995. logger.debug(f"Imported {new_data} new records into the database.")
  996. # 删除临时文件
  997. os.remove(temp_path)
  998. logger.debug(f"Temporary file removed: {temp_path}")
  999. # 返回结果
  1000. return jsonify({
  1001. 'success': True,
  1002. 'message': '数据导入成功',
  1003. 'total_data': total_data,
  1004. 'new_data': new_data,
  1005. 'duplicate_data': duplicate_data
  1006. }), 200
  1007. except Exception as e:
  1008. logger.error(f"Import failed: {e}", exc_info=True)
  1009. return jsonify({'success': False, 'message': f'导入失败: {str(e)}'}), 500
  1010. # 模板下载接口
  1011. @bp.route('/download_template', methods=['GET'])
  1012. def download_template():
  1013. """
  1014. 根据给定的表名,下载表的模板(如 CSV 或 Excel 格式)。
  1015. """
  1016. table_name = request.args.get('table')
  1017. if not table_name:
  1018. return jsonify({'error': '表名参数缺失'}), 400
  1019. columns = get_column_names(table_name)
  1020. if not columns:
  1021. return jsonify({'error': f"Table '{table_name}' not found or empty."}), 404
  1022. # 不包括 ID 列
  1023. if 'id' in columns:
  1024. columns.remove('id')
  1025. df = pd.DataFrame(columns=columns)
  1026. file_format = request.args.get('format', 'excel').lower()
  1027. try:
  1028. if file_format == 'csv':
  1029. output = BytesIO()
  1030. df.to_csv(output, index=False, encoding='utf-8')
  1031. output.seek(0)
  1032. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.csv',
  1033. mimetype='text/csv')
  1034. else:
  1035. output = BytesIO()
  1036. df.to_excel(output, index=False, engine='openpyxl')
  1037. output.seek(0)
  1038. return send_file(output, as_attachment=True, download_name=f'{table_name}_template.xlsx',
  1039. mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
  1040. except Exception as e:
  1041. logger.error(f"Failed to generate template: {e}", exc_info=True)
  1042. return jsonify({'error': '生成模板文件失败'}), 500
  1043. @bp.route('/update-threshold', methods=['POST'])
  1044. def update_threshold():
  1045. """
  1046. 更新训练阈值的API接口
  1047. @body_param threshold: 新的阈值值(整数)
  1048. @return: JSON响应
  1049. """
  1050. try:
  1051. data = request.get_json()
  1052. new_threshold = data.get('threshold')
  1053. # 验证新阈值
  1054. if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
  1055. return jsonify({
  1056. 'error': '无效的阈值值,必须为正数'
  1057. }), 400
  1058. # 更新当前应用的阈值配置
  1059. current_app.config['THRESHOLD'] = int(new_threshold)
  1060. return jsonify({
  1061. 'success': True,
  1062. 'message': f'阈值已更新为 {new_threshold}',
  1063. 'new_threshold': new_threshold
  1064. })
  1065. except Exception as e:
  1066. logging.error(f"更新阈值失败: {str(e)}")
  1067. return jsonify({
  1068. 'error': f'更新阈值失败: {str(e)}'
  1069. }), 500
  1070. @bp.route('/get-threshold', methods=['GET'])
  1071. def get_threshold():
  1072. """
  1073. 获取当前训练阈值的API接口
  1074. @return: JSON响应
  1075. """
  1076. try:
  1077. current_threshold = current_app.config['THRESHOLD']
  1078. default_threshold = current_app.config['DEFAULT_THRESHOLD']
  1079. return jsonify({
  1080. 'current_threshold': current_threshold,
  1081. 'default_threshold': default_threshold
  1082. })
  1083. except Exception as e:
  1084. logging.error(f"获取阈值失败: {str(e)}")
  1085. return jsonify({
  1086. 'error': f'获取阈值失败: {str(e)}'
  1087. }), 500
  1088. @bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
  1089. def set_current_dataset(data_type, dataset_id):
  1090. """
  1091. 将指定数据集设置为current数据集
  1092. @param data_type: 数据集类型 ('reduce' 或 'reflux')
  1093. @param dataset_id: 要设置为current的数据集ID
  1094. @return: JSON响应
  1095. """
  1096. Session = sessionmaker(bind=db.engine)
  1097. session = Session()
  1098. try:
  1099. # 验证数据集存在且类型匹配
  1100. dataset = session.query(Datasets)\
  1101. .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
  1102. .first()
  1103. if not dataset:
  1104. return jsonify({
  1105. 'error': f'未找到ID为 {dataset_id} 且类型为 {data_type} 的数据集'
  1106. }), 404
  1107. # 根据数据类型选择表
  1108. if data_type == 'reduce':
  1109. table = CurrentReduce
  1110. table_name = 'current_reduce'
  1111. elif data_type == 'reflux':
  1112. table = CurrentReflux
  1113. table_name = 'current_reflux'
  1114. else:
  1115. return jsonify({'error': '无效的数据集类型'}), 400
  1116. # 清空current表
  1117. session.query(table).delete()
  1118. # 重置自增主键计数器
  1119. session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
  1120. # 从指定数据集复制数据到current表
  1121. dataset_table_name = f"dataset_{dataset_id}"
  1122. copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
  1123. session.execute(copy_sql)
  1124. session.commit()
  1125. return jsonify({
  1126. 'message': f'{data_type} current数据集已设置为数据集 ID: {dataset_id}',
  1127. 'dataset_id': dataset_id,
  1128. 'dataset_name': dataset.Dataset_name,
  1129. 'row_count': dataset.Row_count
  1130. }), 200
  1131. except Exception as e:
  1132. session.rollback()
  1133. logger.error(f'设置current数据集失败: {str(e)}')
  1134. return jsonify({'error': str(e)}), 500
  1135. finally:
  1136. session.close()
  1137. @bp.route('/kriging_interpolation', methods=['POST'])
  1138. def kriging_interpolation():
  1139. try:
  1140. data = request.get_json()
  1141. required = ['file_name', 'emission_column', 'points']
  1142. if not all(k in data for k in required):
  1143. return jsonify({"error": "Missing parameters"}), 400
  1144. # 添加坐标顺序验证
  1145. points = data['points']
  1146. if not all(len(pt) == 2 and isinstance(pt[0], (int, float)) for pt in points):
  1147. return jsonify({"error": "Invalid points format"}), 400
  1148. result = create_kriging(
  1149. data['file_name'],
  1150. data['emission_column'],
  1151. data['points']
  1152. )
  1153. return jsonify(result)
  1154. except Exception as e:
  1155. return jsonify({"error": str(e)}), 500