model.py

import datetime
import os
import pickle
import pandas as pd
from flask_sqlalchemy.session import Session
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split, cross_val_score
from sqlalchemy import text
from xgboost import XGBRegressor
import logging
import numpy as np

from .database_models import Models, Datasets
from .config import Config
from .data_cleaner import clean_dataset


# Load a trained model from disk by its database ID
def load_model(session, model_id):
    model = session.query(Models).filter(Models.ModelID == model_id).first()
    if not model:
        raise ValueError(f"Model with ID {model_id} not found.")
    with open(model.ModelFilePath, 'rb') as f:
        return pickle.load(f)


# Run a prediction with the model identified by model_id
def predict(session, input_data: pd.DataFrame, model_id):
    # Load the model specified by model_id
    ML_model = load_model(session, model_id)
    predictions = ML_model.predict(input_data)
    return predictions.tolist()
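
# A minimal usage sketch for predict(), assuming an existing SQLAlchemy session
# bound to the application database and a model ID present in the Models table
# (both are placeholders here):
#
#     features = pd.DataFrame([{"pH": 5.2, "OM": 3.1, "CL": 25.6, "H": 0.5, "Al": 12.4}])
#     result = predict(session, features, model_id=1)  # returns a plain Python list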


def check_dataset_overlap_with_test(dataset_df, data_type):
    """
    Check whether a dataset overlaps with the held-out test set.

    Args:
        dataset_df (DataFrame): Dataset to check.
        data_type (str): Dataset type ('reflux' or 'reduce').

    Returns:
        tuple: (number of overlapping rows, indices of the overlapping rows)
    """
    # Load the saved test set for the given data type
    if data_type == 'reflux':
        X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
        Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
    elif data_type == 'reduce':
        X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
        Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
    else:
        raise ValueError(f"Unsupported data type: {data_type}")

    # Combine X_test and Y_test into a single frame
    test_df = pd.concat([X_test, Y_test], axis=1)

    # Only columns present in both frames are used for the comparison
    compare_columns = [col for col in dataset_df.columns if col in test_df.columns]
    if not compare_columns:
        return 0, []

    # Find rows that appear in both the dataset and the test set
    merged = dataset_df[compare_columns].merge(test_df[compare_columns], how='inner', indicator=True)
    overlapping_rows = merged[merged['_merge'] == 'both']

    # Map the overlapping rows back to their indices in the original dataset
    if not overlapping_rows.empty:
        overlap_indices = []
        for _, row in overlapping_rows.iterrows():
            # Boolean mask that matches this row in the original dataset
            mask = True
            for col in compare_columns:
                mask = mask & (dataset_df[col] == row[col])
            matching_indices = dataset_df[mask].index.tolist()
            overlap_indices.extend(matching_indices)
        return len(set(overlap_indices)), list(set(overlap_indices))
    return 0, []
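
# An alternative sketch for recovering the original row indices without the
# per-row mask loop above: keep the index as a column before merging and read
# it back from the merge result. Assumes the comparison columns have directly
# comparable dtypes in both frames and that dataset_df uses a default unnamed
# index (so reset_index() yields an 'index' column):
#
#     probe = dataset_df[compare_columns].reset_index()
#     hits = probe.merge(test_df[compare_columns].drop_duplicates(), on=compare_columns, how='inner')
#     overlap_indices = sorted(set(hits['index']))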


# Compute evaluation metrics for a trained model
def calculate_model_score(model_info):
    """
    Compute evaluation metrics for a saved model.

    Args:
        model_info: Model record loaded from the database.

    Returns:
        dict: Dictionary of score metrics.
    """
    # Load the model from disk
    with open(model_info.ModelFilePath, 'rb') as f:
        ML_model = pickle.load(f)
    # print("Model requires the following features:", ML_model.feature_names_in_)

    # Prepare the evaluation data
    if model_info.Data_type == 'reflux':  # acid-reflux dataset
        # Load the saved X_test and Y_test
        X_test = pd.read_csv('uploads/data/X_test_reflux.csv')
        Y_test = pd.read_csv('uploads/data/Y_test_reflux.csv')
        # Predict on the test set
        y_pred = ML_model.predict(X_test)
        # Compute the score metrics
        r2 = r2_score(Y_test, y_pred)
        mae = mean_absolute_error(Y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(Y_test, y_pred))
    elif model_info.Data_type == 'reduce':  # acid-reduction dataset
        # Load the saved X_test and Y_test
        X_test = pd.read_csv('uploads/data/X_test_reduce.csv')
        Y_test = pd.read_csv('uploads/data/Y_test_reduce.csv')
        # Predict on the test set
        y_pred = ML_model.predict(X_test)
        # Compute the score metrics
        r2 = r2_score(Y_test, y_pred)
        mae = mean_absolute_error(Y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(Y_test, y_pred))
    else:
        # Unsupported data type
        return {'r2': 0, 'mae': 0, 'rmse': 0}

    # Return all score metrics (the cross-validation score is stored separately)
    return {
        'r2': float(r2),
        'mae': float(mae),
        'rmse': float(rmse)
    }


def train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id=None):
    """
    Train a model and save it.

    Args:
        session: Database session.
        model_type: Model type.
        model_name: Model name.
        model_description: Model description.
        data_type: Data type ('reflux' or 'reduce').
        dataset_id: Dataset ID.

    Returns:
        tuple: (model name, model ID, dataset ID, cross-validation score)
    """
    try:
        if not dataset_id:
            # Create a new dataset and copy the current data; do not commit yet
            dataset_id = save_current_dataset(session, data_type, commit=False)
            if data_type == 'reflux':
                current_table = 'current_reflux'
            elif data_type == 'reduce':
                current_table = 'current_reduce'
            # Load the data from the "current" dataset table
            dataset = pd.read_sql_table(current_table, session.bind)
        else:
            # Load the data from the copied dataset table
            dataset_table_name = f"dataset_{dataset_id}"
            dataset = pd.read_sql_table(dataset_table_name, session.bind)
            if dataset.empty:
                raise ValueError(f"Dataset {dataset_id} is empty or not found.")

        # Split features and target (the data-cleaning module call is kept commented out)
        if data_type == 'reflux':
            X = dataset.iloc[:, 1:-1]
            y = dataset.iloc[:, -1]
            # target_column = -1  # assumes the target is the last column
            # X, y, clean_stats = clean_dataset(dataset, target_column=target_column)
        elif data_type == 'reduce':
            X = dataset.iloc[:, 2:]
            y = dataset.iloc[:, 1]
            # target_column = 1  # assumes the target is the second column
            # X, y, clean_stats = clean_dataset(dataset, target_column=target_column)
        # Log cleaning statistics
        # logging.info(f"Data cleaning statistics: {clean_stats}")

        # Train the model
        model = train_model_by_type(X, y, model_type)

        # Compute the cross-validation score
        cv_score = cross_val_score(model, X, y, cv=5).mean()

        # Save the model record to the database
        model_id = save_model(session, model, model_name, model_type, model_description, dataset_id, data_type)

        # Update the model's cross-validation score
        model_info = session.query(Models).filter(Models.ModelID == model_id).first()
        if model_info:
            model_info.CV_score = float(cv_score)

        # Commit the transaction once all operations have succeeded
        session.commit()

        return model_name, model_id, dataset_id, cv_score

    except Exception as e:
        session.rollback()
        logging.error(f"Error while training and saving the model: {str(e)}", exc_info=True)
        raise
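
# A minimal usage sketch for train_and_save_model(), assuming the caller already
# holds a SQLAlchemy session bound to the application database; the model name
# and description below are placeholders:
#
#     name, model_id, dataset_id, cv_score = train_and_save_model(
#         session,
#         model_type='RandomForest',
#         model_name='rf_reflux_demo',
#         model_description='demo model',
#         data_type='reflux',
#     )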


def save_current_dataset(session, data_type, commit=True):
    """
    Create a new dataset record and copy the data from the corresponding
    "current" table, without necessarily committing the transaction.

    Args:
        session (Session): SQLAlchemy session object.
        data_type (str): Dataset type.
        commit (bool): Whether to commit the transaction before returning.

    Returns:
        int: ID of the newly saved dataset.
    """
    current_time = datetime.datetime.now()
    new_dataset = Datasets(
        Dataset_name=f"{data_type}_dataset_{current_time:%Y%m%d_%H%M%S}",
        Dataset_description=f"Automatically generated dataset for type {data_type}",
        Row_count=0,
        Status='pending',
        Dataset_type=data_type,
        Uploaded_at=current_time
    )
    session.add(new_dataset)
    session.flush()
    dataset_id = new_dataset.Dataset_ID

    source_table = data_type_table_mapping(data_type)
    new_table_name = f"dataset_{dataset_id}"
    # Copy the source table into a new table named after the dataset ID
    session.execute(text(f"CREATE TABLE {new_table_name} AS SELECT * FROM {source_table};"))
    # Mark the dataset record as populated and store its row count
    session.execute(text(
        f"UPDATE datasets SET status='Datasets upgraded success', "
        f"row_count=(SELECT count(*) FROM {new_table_name}) "
        f"WHERE dataset_id={dataset_id};"
    ))

    if commit:
        session.commit()
    return dataset_id
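
# Note on the UPDATE above: dataset_id is interpolated directly into the SQL
# string. A bound parameter keeps the statement safer; a sketch of the same
# call with SQLAlchemy bound parameters (table names cannot be bound, so the
# table name stays interpolated):
#
#     session.execute(
#         text(f"UPDATE datasets SET row_count=(SELECT count(*) FROM {new_table_name}) "
#              "WHERE dataset_id=:dataset_id"),
#         {"dataset_id": dataset_id},
#     )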


def data_type_table_mapping(data_type):
    """Map a data type to its database table name."""
    if data_type == 'reduce':
        return 'current_reduce'
    elif data_type == 'reflux':
        return 'current_reflux'
    else:
        raise ValueError("Invalid data type provided.")


def train_model_by_type(X, y, model_type):
    # Train/test split is currently disabled; the full dataset is used for training
    # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    X_train, y_train = X, y
    if model_type == 'RandomForest':
        # Hyperparameter search for random forest
        return train_random_forest(X_train, y_train)
    elif model_type == 'XGBR':
        # Hyperparameter search for XGBoost
        return train_xgboost(X_train, y_train)
    elif model_type == 'GBSTR':
        # Hyperparameter search for gradient boosting
        return train_gradient_boosting(X_train, y_train)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def train_random_forest(X_train, y_train):
    best_score = -float('inf')
    best_n_estimators = None
    best_max_depth = None
    random_state = 43

    # Search for the best number of trees
    for n_estimators in range(1, 20, 1):
        model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)
        score = cross_val_score(model, X_train, y_train, cv=5).mean()
        if score > best_score:
            best_score = score
            best_n_estimators = n_estimators
    print(f"Best number of trees: {best_n_estimators}, Score: {best_score}")

    # With the best number of trees fixed, search for the best max depth
    best_score = 0  # reset the best score for the max-depth search
    for max_depth in range(1, 5, 1):
        model = RandomForestRegressor(n_estimators=best_n_estimators, max_depth=max_depth, random_state=random_state)
        score = cross_val_score(model, X_train, y_train, cv=5).mean()
        if score > best_score:
            best_score = score
            best_max_depth = max_depth
    print(f"Best max depth: {best_max_depth}, Score: {best_score}")

    # Train the final model with the best number of trees and max depth
    best_model = RandomForestRegressor(n_estimators=best_n_estimators, max_depth=best_max_depth,
                                       random_state=random_state)
    best_model.fit(X_train, y_train)
    # Record the feature names used for training
    best_model.feature_names_in_ = X_train.columns
    return best_model
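
# The two loops above tune n_estimators and then max_depth sequentially. A joint
# search over the same grid could use scikit-learn's GridSearchCV; a sketch under
# the same parameter ranges (not wired into the training flow here):
#
#     from sklearn.model_selection import GridSearchCV
#     grid = GridSearchCV(
#         RandomForestRegressor(random_state=43),
#         param_grid={'n_estimators': range(1, 20), 'max_depth': range(1, 5)},
#         cv=5,
#     )
#     grid.fit(X_train, y_train)
#     best_model = grid.best_estimator_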


def train_xgboost(X_train, y_train):
    best_score = -float('inf')
    best_params = {'learning_rate': None, 'max_depth': None}
    random_state = 43

    # Grid search over learning rate and max depth
    for learning_rate in [0.01, 0.05, 0.1, 0.2]:
        for max_depth in range(3, 10):
            model = XGBRegressor(learning_rate=learning_rate, max_depth=max_depth, random_state=random_state)
            score = cross_val_score(model, X_train, y_train, cv=5).mean()
            if score > best_score:
                best_score = score
                best_params['learning_rate'] = learning_rate
                best_params['max_depth'] = max_depth
    print(f"Best parameters: {best_params}, Score: {best_score}")

    # Train the final model with the best parameters found
    best_model = XGBRegressor(learning_rate=best_params['learning_rate'], max_depth=best_params['max_depth'],
                              random_state=random_state)
    best_model.fit(X_train, y_train)
    return best_model


def train_gradient_boosting(X_train, y_train):
    best_score = -float('inf')
    best_params = {'learning_rate': None, 'max_depth': None}
    random_state = 43

    # Grid search over learning rate and max depth
    for learning_rate in [0.01, 0.05, 0.1, 0.2]:
        for max_depth in range(3, 10):
            model = GradientBoostingRegressor(learning_rate=learning_rate, max_depth=max_depth, random_state=random_state)
            score = cross_val_score(model, X_train, y_train, cv=5).mean()
            if score > best_score:
                best_score = score
                best_params['learning_rate'] = learning_rate
                best_params['max_depth'] = max_depth
    print(f"Best parameters: {best_params}, Score: {best_score}")

    # Train the final model with the best parameters found
    best_model = GradientBoostingRegressor(learning_rate=best_params['learning_rate'], max_depth=best_params['max_depth'],
                                           random_state=random_state)
    best_model.fit(X_train, y_train)
    return best_model


def save_model(session, model, model_name, model_type, model_description, dataset_id, data_type, commit=False):
    """
    Save the model record to the database and the model file to disk.

    Args:
        session: Database session.
        model: Model object to save.
        model_name: Model name.
        model_type: Model type.
        model_description: Model description.
        dataset_id: Dataset ID.
        data_type: Data type.
        commit: Whether to commit the transaction.

    Returns:
        int: ID of the saved model.
    """
    prefix_dict = {
        'RandomForest': 'rf_model_',
        'XGBR': 'xgbr_model_',
        'GBSTR': 'gbstr_model_'
    }
    prefix = prefix_dict.get(model_type, 'default_model_')
    try:
        # Get the save path from the configuration
        model_save_path = Config.MODEL_SAVE_PATH
        # Make sure the directory exists
        os.makedirs(model_save_path, exist_ok=True)
        # Current timestamp
        timestamp = datetime.datetime.now().strftime('%m%d_%H%M')
        # Build the full file name
        file_name = os.path.join(model_save_path, f'{prefix}{timestamp}.pkl')
        # Save the model to disk
        with open(file_name, 'wb') as f:
            pickle.dump(model, f)
        print(f"Model saved to: {file_name}")

        # Create the model database record
        new_model = Models(
            Model_name=model_name,
            Model_type=model_type,
            Description=model_description,
            DatasetID=dataset_id,
            Created_at=datetime.datetime.now(),
            ModelFilePath=file_name,
            Data_type=data_type
        )
        # Add the record to the database
        session.add(new_model)
        session.flush()
        return new_model.ModelID
    except Exception as e:
        print(f"Error while saving the model: {str(e)}")
        raise


if __name__ == '__main__':
    # Acid-reflux model prediction: exercise the predict function.
    # NOTE: predict() expects (session, input_data, model_id); these calls pass a
    # model name without a session, so they need a database session and a model ID to run.
    input_data = pd.DataFrame([{
        "organic_matter": 5.2,
        "chloride": 3.1,
        "cec": 25.6,
        "h_concentration": 0.5,
        "hn": 12.4,
        "al_concentration": 0.8,
        "free_alumina": 1.2,
        "free_iron": 0.9,
        "delta_ph": -0.2
    }])
    model_name = 'RF_filt'
    Acid_reflux_result = predict(input_data, model_name)
    print("Acid_reflux_result:", Acid_reflux_result)  # prediction result

    # Acid-reduction model prediction: exercise the predict function.
    input_data = pd.DataFrame([{
        "pH": 5.2,
        "OM": 3.1,
        "CL": 25.6,
        "H": 0.5,
        "Al": 12.4
    }])
    model_name = 'rf_model_1214_1008'
    Acid_reduce_result = predict(input_data, model_name)
    print("Acid_reduce_result:", Acid_reduce_result)  # prediction result