routes.py 19 KB


  1. import os
  2. import logging
  3. import pickle
  4. import sqlite3
  5. import pandas as pd
  6. from flask import Blueprint, request, jsonify, current_app
  7. from sqlalchemy.orm import sessionmaker
  8. from .model import predict, train_and_save_model, calculate_model_score
  9. from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
  10. from . import db
  11. from .utils import create_dynamic_table, allowed_file, infer_column_types, rename_columns_for_model_predict, \
  12. clean_column_names, rename_columns_for_model, insert_data_into_dynamic_table, insert_data_into_existing_table, \
  13. predict_to_Q
  14. from sqlalchemy import text
  15. from sqlalchemy.engine.reflection import Inspector
  16. from sqlalchemy import MetaData, Table, select
# Configure logging: request DEBUG-level output via the root logger and
# create this module's own named logger.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Create the blueprint that every route in this module is registered on.
bp = Blueprint('routes', __name__)
  22. def complex_reflux_calculation(parameters):
  23. """
  24. 假设的复杂 'reflux' 计算函数
  25. :param parameters: 输入的参数
  26. :return: 计算结果
  27. """
  28. try:
  29. # 加载已训练好的模型
  30. with open('path_to_your_reflux_model.pkl', 'rb') as model_file:
  31. model = pickle.load(model_file)
  32. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame(假设是这种格式)
  33. reflux_result = model.predict(input_data) # 根据实际情况调用模型的 predict 方法
  34. # 返回预测结果
  35. return reflux_result.tolist()
  36. except Exception as e:
  37. logger.error('Failed to calculate reflux model prediction:', exc_info=True)
  38. return {'error': str(e)}
  39. @bp.route('/upload-dataset', methods=['POST'])
  40. def upload_dataset():
  41. session = None
  42. try:
  43. if 'file' not in request.files:
  44. return jsonify({'error': 'No file part'}), 400
  45. file = request.files['file']
  46. if file.filename == '' or not allowed_file(file.filename):
  47. return jsonify({'error': 'No selected file or invalid file type'}), 400
  48. dataset_name = request.form.get('dataset_name')
  49. dataset_description = request.form.get('dataset_description', 'No description provided')
  50. dataset_type = request.form.get('dataset_type')
  51. if not dataset_type:
  52. return jsonify({'error': 'Dataset type is required'}), 400
  53. # 创建 sessionmaker 实例
  54. Session = sessionmaker(bind=db.engine)
  55. session = Session()
  56. # 创建新的数据集记录
  57. new_dataset = Datasets(
  58. Dataset_name=dataset_name,
  59. Dataset_description=dataset_description,
  60. Row_count=0,
  61. Status='pending',
  62. Dataset_type=dataset_type
  63. )
  64. session.add(new_dataset)
  65. session.commit()
  66. # 文件保存
  67. unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
  68. upload_folder = current_app.config['UPLOAD_FOLDER']
  69. file_path = os.path.join(upload_folder, unique_filename)
  70. file.save(file_path)
  71. # 处理数据集
  72. dataset_df = pd.read_excel(file_path)
  73. new_dataset.Row_count = len(dataset_df)
  74. new_dataset.Status = 'processed'
  75. session.commit()
  76. # 清理列名和数据处理
  77. dataset_df = clean_column_names(dataset_df)
  78. dataset_df = rename_columns_for_model(dataset_df, dataset_type)
  79. column_types = infer_column_types(dataset_df)
  80. dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
  81. insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
  82. # 根据 dataset_type 决定插入到哪个已有表
  83. if dataset_type == 'reduce':
  84. insert_data_into_existing_table(session, dataset_df, CurrentReduce)
  85. elif dataset_type == 'reflux':
  86. insert_data_into_existing_table(session, dataset_df, CurrentReflux)
  87. session.commit()
  88. return jsonify({
  89. 'message': f'Dataset {dataset_name} uploaded successfully!',
  90. 'dataset_id': new_dataset.Dataset_ID,
  91. 'filename': unique_filename
  92. }), 201
  93. except Exception as e:
  94. if session:
  95. session.rollback()
  96. logger.error('Failed to process the dataset upload:', exc_info=True)
  97. return jsonify({'error': str(e)}), 500
  98. finally:
  99. if session:
  100. session.close()
  101. @bp.route('/train-and-save-model', methods=['POST'])
  102. def train_and_save_model_endpoint():
  103. session = None
  104. data = request.get_json()
  105. try:
  106. # 获取请求中的数据
  107. model_type = data.get('model_type')
  108. model_name = data.get('model_name')
  109. model_description = data.get('model_description')
  110. data_type = data.get('data_type')
  111. dataset_id = data.get('dataset_id', None)
  112. # 创建 session
  113. Session = sessionmaker(bind=db.engine)
  114. session = Session()
  115. # 调用训练和保存模型的函数
  116. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  117. return jsonify({'message': 'Model trained and saved successfully', 'result': result}), 200
  118. except Exception as e:
  119. if session:
  120. session.rollback()
  121. logger.error('Failed to train and save model:', exc_info=True)
  122. return jsonify({'error': 'Failed to train and save model', 'message': str(e)}), 500
  123. finally:
  124. if session:
  125. session.close()
  126. # 预测接口
  127. @bp.route('/predict', methods=['POST'])
  128. def predict_route():
  129. session = None
  130. try:
  131. data = request.get_json()
  132. model_id = data.get('model_id')
  133. parameters = data.get('parameters', {})
  134. # 创建 session
  135. Session = sessionmaker(bind=db.engine)
  136. session = Session()
  137. # 查询模型信息
  138. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  139. if not model_info:
  140. return jsonify({'error': 'Model not found'}), 404
  141. data_type = model_info.Data_type
  142. input_data = pd.DataFrame([parameters])
  143. # 处理 'reduce' 类型模型
  144. if data_type == 'reduce':
  145. init_ph = float(parameters.get('init_pH', 0.0))
  146. target_ph = float(parameters.get('target_pH', 0.0))
  147. input_data = input_data.drop('target_pH', axis=1, errors='ignore')
  148. # 通过模型类型重命名列以适配模型
  149. input_data_rename = rename_columns_for_model_predict(input_data, data_type)
  150. # 使用当前活动的模型进行预测(假设根据 model_id 获取当前模型)
  151. predictions = predict(session, input_data_rename, model_id)
  152. # 如果是 reduce 类型,进行特定的转换
  153. if data_type == 'reduce':
  154. predictions = predictions[0]
  155. predictions = predict_to_Q(predictions, init_ph, target_ph)
  156. return jsonify({'result': predictions}), 200
  157. except Exception as e:
  158. logger.error('Failed to predict:', exc_info=True)
  159. return jsonify({'error': str(e)}), 400
  160. finally:
  161. if session:
  162. session.close()
  163. # 切换模型接口
  164. @bp.route('/switch-model', methods=['POST'])
  165. def switch_model():
  166. session = None
  167. try:
  168. data = request.get_json()
  169. model_id = data.get('model_id')
  170. model_name = data.get('model_name')
  171. # 创建 session
  172. Session = sessionmaker(bind=db.engine)
  173. session = Session()
  174. # 查找模型
  175. model = session.query(Models).filter_by(ModelID=model_id).first()
  176. if not model:
  177. return jsonify({'error': 'Model not found'}), 404
  178. # 更新模型状态(或其他切换逻辑)
  179. # 假设此处是更新模型的某些字段来进行切换
  180. model.status = 'active' # 假设有一个字段记录模型状态
  181. session.commit()
  182. # 记录切换日志
  183. logger.info(f'Model {model_name} (ID: {model_id}) switched successfully.')
  184. return jsonify({'success': True, 'message': f'Model {model_name} switched successfully!'}), 200
  185. except Exception as e:
  186. logger.error('Failed to switch model:', exc_info=True)
  187. return jsonify({'error': str(e)}), 400
  188. finally:
  189. if session:
  190. session.close()
  191. @bp.route('/models', methods=['GET'])
  192. def get_models():
  193. session = None
  194. try:
  195. # 创建 session
  196. Session = sessionmaker(bind=db.engine)
  197. session = Session()
  198. # 查询所有模型
  199. models = session.query(Models).all()
  200. logger.debug(f"Models found: {models}") # 打印查询的模型数据
  201. if not models:
  202. return jsonify({'message': 'No models found'}), 404
  203. # 将模型数据转换为字典列表
  204. models_list = [
  205. {
  206. 'ModelID': model.ModelID,
  207. 'ModelName': model.Model_name,
  208. 'ModelType': model.Model_type,
  209. 'CreatedAt': model.Created_at.strftime('%Y-%m-%d %H:%M:%S'),
  210. 'Description': model.Description,
  211. 'DatasetID': model.DatasetID,
  212. 'ModelFilePath': model.ModelFilePath,
  213. 'DataType': model.Data_type,
  214. 'PerformanceScore': model.Performance_score
  215. }
  216. for model in models
  217. ]
  218. return jsonify(models_list), 200
  219. except Exception as e:
  220. return jsonify({'error': str(e)}), 400
  221. finally:
  222. if session:
  223. session.close()
  224. @bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
  225. def delete_dataset(dataset_id):
  226. # 创建 sessionmaker 实例
  227. Session = sessionmaker(bind=db.engine)
  228. session = Session()
  229. try:
  230. # 查询数据集
  231. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  232. if not dataset:
  233. return jsonify({'error': 'Dataset not found'}), 404
  234. # 删除文件
  235. filename = f"dataset_{dataset.Dataset_ID}.xlsx"
  236. file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
  237. if os.path.exists(file_path):
  238. os.remove(file_path)
  239. # 删除数据表
  240. table_name = f"dataset_{dataset.Dataset_ID}"
  241. session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
  242. # 删除数据集记录
  243. session.delete(dataset)
  244. session.commit()
  245. return jsonify({'message': 'Dataset deleted successfully'}), 200
  246. except Exception as e:
  247. session.rollback()
  248. logging.error(f'Failed to delete dataset {dataset_id}:', exc_info=True)
  249. return jsonify({'error': str(e)}), 500
  250. finally:
  251. session.close()
  252. @bp.route('/tables', methods=['GET'])
  253. def list_tables():
  254. engine = db.engine # 使用 db 实例的 engine
  255. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  256. table_names = inspector.get_table_names() # 获取所有表名
  257. return jsonify(table_names) # 以 JSON 形式返回表名列表
  258. @bp.route('/model-parameters', methods=['GET'])
  259. def get_all_model_parameters():
  260. try:
  261. parameters = ModelParameters.query.all() # 获取所有参数数据
  262. if parameters:
  263. result = [
  264. {
  265. 'ParamID': param.ParamID,
  266. 'ModelID': param.ModelID,
  267. 'ParamName': param.ParamName,
  268. 'ParamValue': param.ParamValue
  269. }
  270. for param in parameters
  271. ]
  272. return jsonify(result)
  273. else:
  274. return jsonify({'message': 'No parameters found'}), 404
  275. except Exception as e:
  276. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  277. # 定义添加数据库记录的 API 接口
  278. @bp.route('/add_item', methods=['POST'])
  279. def add_item():
  280. """
  281. 接收 JSON 格式的请求体,包含表名和要插入的数据。
  282. 尝试将数据插入到指定的表中。
  283. :return:
  284. """
  285. try:
  286. # 确保请求体是JSON格式
  287. data = request.get_json()
  288. if not data:
  289. raise ValueError("No JSON data provided")
  290. table_name = data.get('table')
  291. item_data = data.get('item')
  292. if not table_name or not item_data:
  293. return jsonify({'error': 'Missing table name or item data'}), 400
  294. cur = db.cursor()
  295. # 动态构建 SQL 语句
  296. columns = ', '.join(item_data.keys())
  297. placeholders = ', '.join(['?'] * len(item_data))
  298. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  299. cur.execute(sql, tuple(item_data.values()))
  300. db.commit()
  301. # 返回更详细的成功响应
  302. return jsonify({'success': True, 'message': 'Item added successfully'}), 201
  303. except ValueError as e:
  304. return jsonify({'error': str(e)}), 400
  305. except KeyError as e:
  306. return jsonify({'error': f'Missing data field: {e}'}), 400
  307. except sqlite3.IntegrityError as e:
  308. # 处理例如唯一性约束违反等数据库完整性错误
  309. return jsonify({'error': 'Database integrity error', 'details': str(e)}), 409
  310. except sqlite3.Error as e:
  311. # 处理其他数据库错误
  312. return jsonify({'error': 'Database error', 'details': str(e)}), 500
  313. finally:
  314. db.close()
  315. # 定义删除数据库记录的 API 接口
  316. @bp.route('/delete_item', methods=['POST'])
  317. def delete_item():
  318. data = request.get_json()
  319. table_name = data.get('table')
  320. condition = data.get('condition')
  321. # 检查表名和条件是否提供
  322. if not table_name or not condition:
  323. return jsonify({
  324. "success": False,
  325. "message": "缺少表名或条件参数"
  326. }), 400
  327. # 尝试从条件字符串中分离键和值
  328. try:
  329. key, value = condition.split('=')
  330. except ValueError:
  331. return jsonify({
  332. "success": False,
  333. "message": "条件格式错误,应为 'key=value'"
  334. }), 400
  335. cur = db.cursor()
  336. try:
  337. # 执行删除操作
  338. cur.execute(f"DELETE FROM {table_name} WHERE {key} = ?", (value,))
  339. db.commit()
  340. # 如果没有错误发生,返回成功响应
  341. return jsonify({
  342. "success": True,
  343. "message": "记录删除成功"
  344. }), 200
  345. except sqlite3.Error as e:
  346. # 发生错误,回滚事务
  347. db.rollback()
  348. # 返回失败响应,并包含错误信息
  349. return jsonify({
  350. "success": False,
  351. "message": f"删除失败: {e}"
  352. }), 400
  353. # 定义修改数据库记录的 API 接口
  354. @bp.route('/update_item', methods=['PUT'])
  355. def update_record():
  356. data = request.get_json()
  357. # 检查必要的数据是否提供
  358. if not data or 'table' not in data or 'item' not in data:
  359. return jsonify({
  360. "success": False,
  361. "message": "请求数据不完整"
  362. }), 400
  363. table_name = data['table']
  364. item = data['item']
  365. # 假设 item 的第一个元素是 ID
  366. if not item or next(iter(item.keys())) is None:
  367. return jsonify({
  368. "success": False,
  369. "message": "记录数据为空"
  370. }), 400
  371. # 获取 ID 和其他字段值
  372. id_key = next(iter(item.keys()))
  373. record_id = item[id_key]
  374. updates = {key: value for key, value in item.items() if key != id_key} # 排除 ID
  375. cur = db.cursor()
  376. try:
  377. record_id = int(record_id) # 确保 ID 是整数
  378. except ValueError:
  379. return jsonify({
  380. "success": False,
  381. "message": "ID 必须是整数"
  382. }), 400
  383. # 准备参数列表,包括更新的值和 ID
  384. parameters = list(updates.values()) + [record_id]
  385. # 执行更新操作
  386. set_clause = ','.join([f"{k} = ?" for k in updates.keys()])
  387. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = ?"
  388. try:
  389. cur.execute(sql, parameters)
  390. db.commit()
  391. if cur.rowcount == 0:
  392. return jsonify({
  393. "success": False,
  394. "message": "未找到要更新的记录"
  395. }), 404
  396. return jsonify({
  397. "success": True,
  398. "message": "数据更新成功"
  399. }), 200
  400. except sqlite3.Error as e:
  401. db.rollback()
  402. return jsonify({
  403. "success": False,
  404. "message": f"更新失败: {e}"
  405. }), 400
  406. # 定义查询数据库记录的 API 接口
  407. @bp.route('/search/record', methods=['GET'])
  408. def sql_search():
  409. """
  410. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  411. 尝试查询指定 ID 的记录并返回结果。
  412. :return:
  413. """
  414. try:
  415. data = request.get_json()
  416. # 表名
  417. sql_table = data['table']
  418. # 要搜索的 ID
  419. Id = data['id']
  420. # 连接到数据库
  421. cur = db.cursor()
  422. # 构造查询语句
  423. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  424. # 执行查询
  425. cur.execute(sql, (Id,))
  426. # 获取查询结果
  427. rows = cur.fetchall()
  428. column_names = [desc[0] for desc in cur.description]
  429. # 检查是否有结果
  430. if not rows:
  431. return jsonify({'error': '未查找到对应数据。'}), 400
  432. # 构造响应数据
  433. results = []
  434. for row in rows:
  435. result = {column_names[i]: row[i] for i in range(len(row))}
  436. results.append(result)
  437. # 关闭游标和数据库连接
  438. cur.close()
  439. db.close()
  440. # 返回 JSON 响应
  441. return jsonify(results), 200
  442. except sqlite3.Error as e:
  443. # 如果发生数据库错误,返回错误信息
  444. return jsonify({'error': str(e)}), 400
  445. except KeyError as e:
  446. # 如果请求数据中缺少必要的键,返回错误信息
  447. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  448. # 定义提供数据库列表,用于展示表格的 API 接口
  449. @bp.route('/table', methods=['POST'])
  450. def get_table():
  451. data = request.get_json()
  452. table_name = data.get('table')
  453. if not table_name:
  454. return jsonify({'error': '需要表名'}), 400
  455. try:
  456. # 创建 sessionmaker 实例
  457. Session = sessionmaker(bind=db.engine)
  458. session = Session()
  459. # 动态获取表的元数据
  460. metadata = MetaData()
  461. table = Table(table_name, metadata, autoload_with=db.engine)
  462. # 从数据库中查询所有记录
  463. query = select(table)
  464. result = session.execute(query).fetchall()
  465. # 将结果转换为列表字典形式
  466. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  467. # 获取列名
  468. headers = [column.name for column in table.columns]
  469. return jsonify(rows=rows, headers=headers), 200
  470. except Exception as e:
  471. return jsonify({'error': str(e)}), 400
  472. finally:
  473. # 关闭 session
  474. session.close()