# routes.py
import logging
import os
import sqlite3

import pandas as pd
from flask import Blueprint, request, jsonify, current_app
from sqlalchemy import inspect as sa_inspect
from sqlalchemy import text, select
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import IntegrityError, NoSuchTableError, SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData, Table

from .model import predict, train_and_save_model
from . import db  # db instance imported from the app package
from .database_models import Models, ModelParameters, Datasets, CurrentReduce, CurrentReflux
from .utils import create_dynamic_table, allowed_file
  14. # 配置日志
  15. logging.basicConfig(level=logging.DEBUG)
  16. logger = logging.getLogger(__name__)
  17. # 创建蓝图 (Blueprint),用于分离路由
  18. bp = Blueprint('routes', __name__)
  19. def infer_column_types(df):
  20. type_map = {
  21. 'object': 'str',
  22. 'int64': 'int',
  23. 'float64': 'float',
  24. 'datetime64[ns]': 'datetime' # 适应Pandas datetime类型
  25. }
  26. # 提取列和其数据类型
  27. return {col: type_map.get(str(df[col].dtype), 'str') for col in df.columns}
  28. @bp.route('/upload-dataset', methods=['POST'])
  29. def upload_dataset():
  30. try:
  31. if 'file' not in request.files:
  32. return jsonify({'error': 'No file part'}), 400
  33. file = request.files['file']
  34. if file.filename == '' or not allowed_file(file.filename):
  35. return jsonify({'error': 'No selected file or invalid file type'}), 400
  36. dataset_name = request.form.get('dataset_name')
  37. dataset_description = request.form.get('dataset_description', 'No description provided')
  38. dataset_type = request.form.get('dataset_type')
  39. if not dataset_type:
  40. return jsonify({'error': 'Dataset type is required'}), 400
  41. # 创建 sessionmaker 实例
  42. Session = sessionmaker(bind=db.engine)
  43. session = Session()
  44. new_dataset = Datasets(
  45. Dataset_name=dataset_name,
  46. Dataset_description=dataset_description,
  47. Row_count=0,
  48. Status='pending',
  49. Dataset_type=dataset_type
  50. )
  51. session.add(new_dataset)
  52. session.commit()
  53. unique_filename = f"dataset_{new_dataset.Dataset_ID}.xlsx"
  54. upload_folder = current_app.config['UPLOAD_FOLDER']
  55. file_path = os.path.join(upload_folder, unique_filename)
  56. file.save(file_path)
  57. dataset_df = pd.read_excel(file_path)
  58. new_dataset.Row_count = len(dataset_df)
  59. new_dataset.Status = 'processed'
  60. session.commit()
  61. # 清理列名
  62. dataset_df = clean_column_names(dataset_df)
  63. # 重命名 DataFrame 列以匹配模型字段
  64. dataset_df = rename_columns_for_model(dataset_df, dataset_type)
  65. column_types = infer_column_types(dataset_df)
  66. dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
  67. insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
  68. # 根据 dataset_type 决定插入到哪个已有表
  69. if dataset_type == 'reduce':
  70. insert_data_into_existing_table(session, dataset_df, CurrentReduce)
  71. elif dataset_type == 'reflux':
  72. insert_data_into_existing_table(session, dataset_df, CurrentReflux)
  73. session.commit()
  74. return jsonify({
  75. 'message': f'Dataset {dataset_name} uploaded successfully!',
  76. 'dataset_id': new_dataset.Dataset_ID,
  77. 'filename': unique_filename
  78. }), 201
  79. except Exception as e:
  80. if session:
  81. session.rollback()
  82. logging.error('Failed to process the dataset upload:', exc_info=True)
  83. return jsonify({'error': str(e)}), 500
  84. finally:
  85. session.close()
  86. @bp.route('/train-and-save-model', methods=['POST'])
  87. def train_and_save_model_endpoint():
  88. # 创建 sessionmaker 实例
  89. Session = sessionmaker(bind=db.engine)
  90. session = Session()
  91. # 从请求中解析参数
  92. data = request.get_json()
  93. model_type = data.get('model_type')
  94. model_name = data.get('model_name')
  95. model_description = data.get('model_description')
  96. data_type = data.get('data_type')
  97. dataset_id = data.get('dataset_id', None) # 默认为 None,如果未提供
  98. try:
  99. # 调用训练和保存模型的函数
  100. result = train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id)
  101. # 返回成功响应
  102. return jsonify({'message': 'Model trained and saved successfully', 'result': result}), 200
  103. except Exception as e:
  104. session.rollback()
  105. logging.error('Failed to process the dataset upload:', exc_info=True)
  106. return jsonify({'error': 'Failed to train and save model', 'message': str(e)}), 500
  107. finally:
  108. session.close()
  109. def clean_column_names(dataframe):
  110. # Strip whitespace and replace non-breaking spaces and other non-printable characters
  111. dataframe.columns = [col.strip().replace('\xa0', '') for col in dataframe.columns]
  112. return dataframe
  113. def rename_columns_for_model(dataframe, dataset_type):
  114. if dataset_type == 'reduce':
  115. rename_map = {
  116. '1/b': 'Q_over_b',
  117. 'pH': 'pH',
  118. 'OM': 'OM',
  119. 'CL': 'CL',
  120. 'H': 'H',
  121. 'Al': 'Al'
  122. }
  123. elif dataset_type == 'reflux':
  124. rename_map = {
  125. 'OM g/kg': 'OM',
  126. 'CL g/kg': 'CL',
  127. 'CEC cmol/kg': 'CEC',
  128. 'H+ cmol/kg': 'H_plus',
  129. 'HN mg/kg': 'HN',
  130. 'Al3+cmol/kg': 'Al3_plus',
  131. 'Free alumina g/kg': 'Free_alumina',
  132. 'Free iron oxides g/kg': 'Free_iron_oxides',
  133. 'ΔpH': 'Delta_pH'
  134. }
  135. # 使用 rename() 方法更新列名
  136. dataframe = dataframe.rename(columns=rename_map)
  137. return dataframe
  138. def insert_data_into_existing_table(session, dataframe, model_class):
  139. """Insert data from a DataFrame into an existing SQLAlchemy model table."""
  140. for index, row in dataframe.iterrows():
  141. record = model_class(**row.to_dict())
  142. session.add(record)
  143. def insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class):
  144. for _, row in dataset_df.iterrows():
  145. record_data = row.to_dict()
  146. session.execute(dynamic_table_class.__table__.insert(), [record_data])
  147. def insert_data_by_type(session, dataset_df, dataset_type):
  148. if dataset_type == 'reduce':
  149. for _, row in dataset_df.iterrows():
  150. record = CurrentReduce(**row.to_dict())
  151. session.add(record)
  152. elif dataset_type == 'reflux':
  153. for _, row in dataset_df.iterrows():
  154. record = CurrentReflux(**row.to_dict())
  155. session.add(record)
  156. def get_current_data(session, data_type):
  157. # 根据数据类型选择相应的表模型
  158. if data_type == 'reduce':
  159. model = CurrentReduce
  160. elif data_type == 'reflux':
  161. model = CurrentReflux
  162. else:
  163. raise ValueError("Invalid data type provided. Choose 'reduce' or 'reflux'.")
  164. # 从数据库中查询所有记录
  165. result = session.execute(select(model))
  166. # 将结果转换为DataFrame
  167. dataframe = pd.DataFrame([dict(row) for row in result])
  168. return dataframe
  169. def get_dataset_by_id(session, dataset_id):
  170. # 动态获取表的元数据
  171. metadata = MetaData(bind=session.bind)
  172. dataset_table = Table(dataset_id, metadata, autoload=True, autoload_with=session.bind)
  173. # 从数据库中查询整个表的数据
  174. query = select(dataset_table)
  175. result = session.execute(query).fetchall()
  176. # 检查是否有数据返回
  177. if not result:
  178. raise ValueError(f"No data found for dataset {dataset_id}.")
  179. # 将结果转换为DataFrame
  180. dataframe = pd.DataFrame(result, columns=[column.name for column in dataset_table.columns])
  181. return dataframe
  182. @bp.route('/delete-dataset/<int:dataset_id>', methods=['DELETE'])
  183. def delete_dataset(dataset_id):
  184. # 创建 sessionmaker 实例
  185. Session = sessionmaker(bind=db.engine)
  186. session = Session()
  187. try:
  188. # 查询数据集
  189. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  190. if not dataset:
  191. return jsonify({'error': 'Dataset not found'}), 404
  192. # 删除文件
  193. filename = f"dataset_{dataset.Dataset_ID}.xlsx"
  194. file_path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename)
  195. if os.path.exists(file_path):
  196. os.remove(file_path)
  197. # 删除数据表
  198. table_name = f"dataset_{dataset.Dataset_ID}"
  199. session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
  200. # 删除数据集记录
  201. session.delete(dataset)
  202. session.commit()
  203. return jsonify({'message': 'Dataset deleted successfully'}), 200
  204. except Exception as e:
  205. session.rollback()
  206. logging.error(f'Failed to delete dataset {dataset_id}:', exc_info=True)
  207. return jsonify({'error': str(e)}), 500
  208. finally:
  209. session.close()
  210. @bp.route('/tables', methods=['GET'])
  211. def list_tables():
  212. engine = db.engine # 使用 db 实例的 engine
  213. inspector = Inspector.from_engine(engine) # 创建 Inspector 对象
  214. table_names = inspector.get_table_names() # 获取所有表名
  215. return jsonify(table_names) # 以 JSON 形式返回表名列表
  216. @bp.route('/models/<int:model_id>', methods=['GET'])
  217. def get_model(model_id):
  218. try:
  219. model = Models.query.filter_by(ModelID=model_id).first()
  220. if model:
  221. return jsonify({
  222. 'ModelID': model.ModelID,
  223. 'ModelName': model.ModelName,
  224. 'ModelType': model.ModelType,
  225. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  226. 'Description': model.Description
  227. })
  228. else:
  229. return jsonify({'message': 'Model not found'}), 404
  230. except Exception as e:
  231. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  232. @bp.route('/models', methods=['GET'])
  233. def get_all_models():
  234. try:
  235. models = Models.query.all() # 获取所有模型数据
  236. if models:
  237. result = [
  238. {
  239. 'ModelID': model.ModelID,
  240. 'ModelName': model.ModelName,
  241. 'ModelType': model.ModelType,
  242. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  243. 'Description': model.Description
  244. }
  245. for model in models
  246. ]
  247. return jsonify(result)
  248. else:
  249. return jsonify({'message': 'No models found'}), 404
  250. except Exception as e:
  251. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  252. @bp.route('/model-parameters', methods=['GET'])
  253. def get_all_model_parameters():
  254. try:
  255. parameters = ModelParameters.query.all() # 获取所有参数数据
  256. if parameters:
  257. result = [
  258. {
  259. 'ParamID': param.ParamID,
  260. 'ModelID': param.ModelID,
  261. 'ParamName': param.ParamName,
  262. 'ParamValue': param.ParamValue
  263. }
  264. for param in parameters
  265. ]
  266. return jsonify(result)
  267. else:
  268. return jsonify({'message': 'No parameters found'}), 404
  269. except Exception as e:
  270. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  271. @bp.route('/models/<int:model_id>/parameters', methods=['GET'])
  272. def get_model_parameters(model_id):
  273. try:
  274. model = Models.query.filter_by(ModelID=model_id).first()
  275. if model:
  276. # 获取该模型的所有参数
  277. parameters = [
  278. {
  279. 'ParamID': param.ParamID,
  280. 'ParamName': param.ParamName,
  281. 'ParamValue': param.ParamValue
  282. }
  283. for param in model.parameters
  284. ]
  285. # 返回模型参数信息
  286. return jsonify({
  287. 'ModelID': model.ModelID,
  288. 'ModelName': model.ModelName,
  289. 'ModelType': model.ModelType,
  290. 'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
  291. 'Description': model.Description,
  292. 'Parameters': parameters
  293. })
  294. else:
  295. return jsonify({'message': 'Model not found'}), 404
  296. except Exception as e:
  297. return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
  298. @bp.route('/predict', methods=['POST'])
  299. def predict_route():
  300. try:
  301. data = request.get_json()
  302. model_name = data.get('model_name') # 提取模型名称
  303. parameters = data.get('parameters', {}) # 提取所有参数
  304. input_data = pd.DataFrame([parameters]) # 转换参数为DataFrame
  305. predictions = predict(input_data, model_name) # 调用预测函数
  306. return jsonify({'predictions': predictions}), 200
  307. except Exception as e:
  308. return jsonify({'error': str(e)}), 400
  309. # 定义添加数据库记录的 API 接口
  310. @bp.route('/add_item', methods=['POST'])
  311. def add_item():
  312. """
  313. 接收 JSON 格式的请求体,包含表名和要插入的数据。
  314. 尝试将数据插入到指定的表中。
  315. :return:
  316. """
  317. try:
  318. # 确保请求体是JSON格式
  319. data = request.get_json()
  320. if not data:
  321. raise ValueError("No JSON data provided")
  322. table_name = data.get('table')
  323. item_data = data.get('item')
  324. if not table_name or not item_data:
  325. return jsonify({'error': 'Missing table name or item data'}), 400
  326. cur = db.cursor()
  327. # 动态构建 SQL 语句
  328. columns = ', '.join(item_data.keys())
  329. placeholders = ', '.join(['?'] * len(item_data))
  330. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  331. cur.execute(sql, tuple(item_data.values()))
  332. db.commit()
  333. # 返回更详细的成功响应
  334. return jsonify({'success': True, 'message': 'Item added successfully'}), 201
  335. except ValueError as e:
  336. return jsonify({'error': str(e)}), 400
  337. except KeyError as e:
  338. return jsonify({'error': f'Missing data field: {e}'}), 400
  339. except sqlite3.IntegrityError as e:
  340. # 处理例如唯一性约束违反等数据库完整性错误
  341. return jsonify({'error': 'Database integrity error', 'details': str(e)}), 409
  342. except sqlite3.Error as e:
  343. # 处理其他数据库错误
  344. return jsonify({'error': 'Database error', 'details': str(e)}), 500
  345. finally:
  346. db.close()
  347. # 定义删除数据库记录的 API 接口
  348. @bp.route('/delete_item', methods=['POST'])
  349. def delete_item():
  350. data = request.get_json()
  351. table_name = data.get('table')
  352. condition = data.get('condition')
  353. # 检查表名和条件是否提供
  354. if not table_name or not condition:
  355. return jsonify({
  356. "success": False,
  357. "message": "缺少表名或条件参数"
  358. }), 400
  359. # 尝试从条件字符串中分离键和值
  360. try:
  361. key, value = condition.split('=')
  362. except ValueError:
  363. return jsonify({
  364. "success": False,
  365. "message": "条件格式错误,应为 'key=value'"
  366. }), 400
  367. cur = db.cursor()
  368. try:
  369. # 执行删除操作
  370. cur.execute(f"DELETE FROM {table_name} WHERE {key} = ?", (value,))
  371. db.commit()
  372. # 如果没有错误发生,返回成功响应
  373. return jsonify({
  374. "success": True,
  375. "message": "记录删除成功"
  376. }), 200
  377. except sqlite3.Error as e:
  378. # 发生错误,回滚事务
  379. db.rollback()
  380. # 返回失败响应,并包含错误信息
  381. return jsonify({
  382. "success": False,
  383. "message": f"删除失败: {e}"
  384. }), 400
  385. # 定义修改数据库记录的 API 接口
  386. @bp.route('/update_item', methods=['PUT'])
  387. def update_record():
  388. data = request.get_json()
  389. # 检查必要的数据是否提供
  390. if not data or 'table' not in data or 'item' not in data:
  391. return jsonify({
  392. "success": False,
  393. "message": "请求数据不完整"
  394. }), 400
  395. table_name = data['table']
  396. item = data['item']
  397. # 假设 item 的第一个元素是 ID
  398. if not item or next(iter(item.keys())) is None:
  399. return jsonify({
  400. "success": False,
  401. "message": "记录数据为空"
  402. }), 400
  403. # 获取 ID 和其他字段值
  404. id_key = next(iter(item.keys()))
  405. record_id = item[id_key]
  406. updates = {key: value for key, value in item.items() if key != id_key} # 排除 ID
  407. cur = db.cursor()
  408. try:
  409. record_id = int(record_id) # 确保 ID 是整数
  410. except ValueError:
  411. return jsonify({
  412. "success": False,
  413. "message": "ID 必须是整数"
  414. }), 400
  415. # 准备参数列表,包括更新的值和 ID
  416. parameters = list(updates.values()) + [record_id]
  417. # 执行更新操作
  418. set_clause = ','.join([f"{k} = ?" for k in updates.keys()])
  419. sql = f"UPDATE {table_name} SET {set_clause} WHERE {id_key} = ?"
  420. try:
  421. cur.execute(sql, parameters)
  422. db.commit()
  423. if cur.rowcount == 0:
  424. return jsonify({
  425. "success": False,
  426. "message": "未找到要更新的记录"
  427. }), 404
  428. return jsonify({
  429. "success": True,
  430. "message": "数据更新成功"
  431. }), 200
  432. except sqlite3.Error as e:
  433. db.rollback()
  434. return jsonify({
  435. "success": False,
  436. "message": f"更新失败: {e}"
  437. }), 400
  438. # 定义查询数据库记录的 API 接口
  439. @bp.route('/search/record', methods=['GET'])
  440. def sql_search():
  441. """
  442. 接收 JSON 格式的请求体,包含表名和要查询的 ID。
  443. 尝试查询指定 ID 的记录并返回结果。
  444. :return:
  445. """
  446. try:
  447. data = request.get_json()
  448. # 表名
  449. sql_table = data['table']
  450. # 要搜索的 ID
  451. Id = data['id']
  452. # 连接到数据库
  453. cur = db.cursor()
  454. # 构造查询语句
  455. sql = f"SELECT * FROM {sql_table} WHERE id = ?"
  456. # 执行查询
  457. cur.execute(sql, (Id,))
  458. # 获取查询结果
  459. rows = cur.fetchall()
  460. column_names = [desc[0] for desc in cur.description]
  461. # 检查是否有结果
  462. if not rows:
  463. return jsonify({'error': '未查找到对应数据。'}), 400
  464. # 构造响应数据
  465. results = []
  466. for row in rows:
  467. result = {column_names[i]: row[i] for i in range(len(row))}
  468. results.append(result)
  469. # 关闭游标和数据库连接
  470. cur.close()
  471. db.close()
  472. # 返回 JSON 响应
  473. return jsonify(results), 200
  474. except sqlite3.Error as e:
  475. # 如果发生数据库错误,返回错误信息
  476. return jsonify({'error': str(e)}), 400
  477. except KeyError as e:
  478. # 如果请求数据中缺少必要的键,返回错误信息
  479. return jsonify({'error': f'缺少必要的数据字段: {e}'}), 400
  480. # 定义提供数据库列表,用于展示表格的 API 接口
  481. @bp.route('/table', methods=['POST'])
  482. def get_table():
  483. data = request.get_json()
  484. table_name = data.get('table')
  485. if not table_name:
  486. return jsonify({'error': '需要表名'}), 400
  487. try:
  488. # 创建 sessionmaker 实例
  489. Session = sessionmaker(bind=db.engine)
  490. session = Session()
  491. # 动态获取表的元数据
  492. metadata = MetaData()
  493. table = Table(table_name, metadata, autoload_with=db.engine)
  494. # 从数据库中查询所有记录
  495. query = select(table)
  496. result = session.execute(query).fetchall()
  497. # 将结果转换为列表字典形式
  498. rows = [dict(zip([column.name for column in table.columns], row)) for row in result]
  499. # 获取列名
  500. headers = [column.name for column in table.columns]
  501. return jsonify(rows=rows, headers=headers), 200
  502. except Exception as e:
  503. return jsonify({'error': str(e)}), 400
  504. finally:
  505. # 关闭 session
  506. session.close()