# model.py

import datetime
import os
import pickle

import pandas as pd
from flask_sqlalchemy.session import Session
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.model_selection import train_test_split, cross_val_score
from sqlalchemy import text

from .database_models import Models, Datasets

# Load a model from disk
def load_model(model_name):
    file_path = f'model_optimize/pkl/{model_name}.pkl'
    with open(file_path, 'rb') as f:
        return pickle.load(f)


# Model prediction
def predict(input_data: pd.DataFrame, model_name):
    # Load the model specified by model_name
    model = load_model(model_name)
    predictions = model.predict(input_data)
    return predictions.tolist()

def train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id=None):
    if not dataset_id:
        # Create a new dataset entry and copy the current data into it
        dataset_id = save_current_dataset(session, data_type)

    # Load the data from the newly copied dataset table
    dataset_table_name = f"dataset_{dataset_id}"
    dataset = pd.read_sql_table(dataset_table_name, session.bind)
    if dataset.empty:
        raise ValueError(f"Dataset {dataset_id} is empty or not found.")

    # Prepare the data: all columns except the last are features, the last is the target
    X = dataset.iloc[:, :-1]
    y = dataset.iloc[:, -1]

    # Train the model
    model = train_model_by_type(X, y, model_type)

    # Save the model to the database
    save_model(session, model, model_name, model_type, model_description, dataset_id, data_type)

    # # Save the model parameters
    # save_model_parameters(model, saved_model.ModelID)

    # # Compute evaluation metrics (e.g. MSE)
    # y_pred = model.predict(X)
    # mse = mean_squared_error(y, y_pred)
    #
    # return saved_model, mse

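
# A minimal sketch of the evaluation step hinted at by the commented-out code above.
# `evaluate_model` is a hypothetical helper, not part of the original module; it uses
# scikit-learn's mean_squared_error and, like the commented-out code, scores on the
# same data the model was trained on.
def evaluate_model(model, X, y):
    from sklearn.metrics import mean_squared_error
    y_pred = model.predict(X)
    return mean_squared_error(y, y_pred)
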
def save_current_dataset(session, data_type):
    """
    Create a new dataset entry and copy the data from the table for the given data type.

    Args:
        session (Session): SQLAlchemy session object.
        data_type (str): Type of the dataset, e.g. 'reduce' or 'reflux'.

    Returns:
        int: ID of the newly saved dataset.
    """
    # Create a new dataset entry
    new_dataset = Datasets(
        Dataset_name=f"{data_type}_dataset_{datetime.datetime.now():%Y%m%d_%H%M%S}",  # unique name from the current timestamp
        Dataset_description=f"Automatically generated dataset for type {data_type}",
        Row_count=0,  # starts at 0; updated after the data is copied
        Status='pending',  # initial status
        Dataset_type=data_type
    )

    # Add to the session and flush: the INSERT runs and the ID is assigned,
    # but the transaction is not yet committed
    session.add(new_dataset)
    session.flush()
    dataset_id = new_dataset.Dataset_ID

    # Copy the data into a new table
    source_table = data_type_table_mapping(data_type)  # map the data type to its source table name
    new_table_name = f"dataset_{dataset_id}"
    copy_table_sql = f"CREATE TABLE {new_table_name} AS SELECT * FROM {source_table};"
    session.execute(text(copy_table_sql))

    # Update the new dataset's status and row count
    update_sql = f"UPDATE datasets SET status='processed', row_count=(SELECT count(*) FROM {new_table_name}) WHERE dataset_id={dataset_id};"
    session.execute(text(update_sql))
    session.commit()

    return dataset_id

def data_type_table_mapping(data_type):
    """Map a data type to its database table name."""
    if data_type == 'reduce':
        return 'current_reduce'
    elif data_type == 'reflux':
        return 'current_reflux'
    else:
        raise ValueError("Invalid data type provided.")

def train_model_by_type(X, y, model_type):
    # Split the data; note the held-out test set is currently unused by the trainers below
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    if model_type == 'RandomForest':
        # Hyperparameter search for random forest
        return train_random_forest(X_train, y_train)
    elif model_type == 'XGBR':
        # XGBoost training
        return train_xgboost(X_train, y_train)
    elif model_type == 'GBSTR':
        # Gradient boosting training
        return train_gradient_boosting(X_train, y_train)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

def train_random_forest(X_train, y_train):
    best_score = -float('inf')  # cross_val_score returns R^2 here, which can be negative
    best_n_estimators = None
    best_max_depth = None
    random_state = 43

    # Search for the best number of trees
    for n_estimators in range(1, 20):
        model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)
        score = cross_val_score(model, X_train, y_train, cv=5).mean()
        if score > best_score:
            best_score = score
            best_n_estimators = n_estimators

    print(f"Best number of trees: {best_n_estimators}, Score: {best_score}")

    # With the best number of trees fixed, search for the best max depth
    best_score = -float('inf')  # reset the best score before the max-depth search
    for max_depth in range(1, 30):
        model = RandomForestRegressor(n_estimators=best_n_estimators, max_depth=max_depth, random_state=random_state)
        score = cross_val_score(model, X_train, y_train, cv=5).mean()
        if score > best_score:
            best_score = score
            best_max_depth = max_depth

    print(f"Best max depth: {best_max_depth}, Score: {best_score}")

    # Train the final model with the best number of trees and max depth
    best_model = RandomForestRegressor(n_estimators=best_n_estimators, max_depth=best_max_depth,
                                       random_state=random_state)
    best_model.fit(X_train, y_train)
    return best_model

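
# The two-stage loop above tunes n_estimators first and max_depth second. A joint search
# over both parameters is sketched below using scikit-learn's GridSearchCV as an
# alternative; `train_random_forest_grid` is a hypothetical helper, not part of the
# original module.
def train_random_forest_grid(X_train, y_train):
    from sklearn.model_selection import GridSearchCV
    param_grid = {'n_estimators': list(range(1, 20)), 'max_depth': list(range(1, 30))}
    search = GridSearchCV(RandomForestRegressor(random_state=43), param_grid, cv=5)
    search.fit(X_train, y_train)
    print(f"Best params: {search.best_params_}, Score: {search.best_score_}")
    return search.best_estimator_
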
def train_xgboost(X_train, y_train):
    # XGBoost training; minimal placeholder, assuming the xgboost package is installed.
    # A hyperparameter search similar to train_random_forest was intended here.
    from xgboost import XGBRegressor
    model = XGBRegressor(random_state=42)
    model.fit(X_train, y_train)
    return model

def train_gradient_boosting(X_train, y_train):
    # Gradient boosting training; minimal placeholder using scikit-learn's
    # GradientBoostingRegressor. A hyperparameter search similar to
    # train_random_forest was intended here.
    model = GradientBoostingRegressor(random_state=42)
    model.fit(X_train, y_train)
    return model

def save_model(session, model, model_name, model_type, model_description, dataset_id, data_type, custom_path='pkl'):
    """
    Save the model record to the database and write the model file to disk.

    :param session: database session
    :param model: model object to save
    :param model_name: name of the model
    :param model_type: type of the model
    :param model_description: description of the model
    :param dataset_id: dataset ID
    :param data_type: data type of the dataset
    :param custom_path: directory in which to save the model file
    :return: path of the saved model file
    """
    # Choose a file name prefix based on the model type
    # (keys match the model_type strings used in train_model_by_type)
    prefix_dict = {
        'RandomForest': 'rf_model_',
        'XGBR': 'xgbr_model_',
        'GBSTR': 'gbstr_model_'
    }
    prefix = prefix_dict.get(model_type, 'default_model_')  # default prefix if model_type is unknown

    try:
        # Make sure the target directory exists
        os.makedirs(custom_path, exist_ok=True)

        # Current timestamp (format: month-day_hour-minute)
        timestamp = datetime.datetime.now().strftime('%m%d_%H%M')

        # Build the full file name
        file_name = os.path.join(custom_path, f'{prefix}{timestamp}.pkl')

        # Save the model to the file
        with open(file_name, 'wb') as f:
            pickle.dump(model, f)
        print(f"Model saved as: {file_name}")

        # Create the model database record
        new_model = Models(
            Model_name=model_name,
            Model_type=model_type,
            Description=model_description,
            DatasetID=dataset_id,
            Created_at=datetime.datetime.now(),
            ModelFilePath=file_name,
            Data_type=data_type
        )

        # Add the record to the database
        session.add(new_model)
        session.commit()

        return file_name
    except Exception as e:
        session.rollback()
        print(f"Error saving model: {str(e)}")
        raise  # re-raise so the caller can handle the error

if __name__ == '__main__':
    # Acid-reflux model prediction: test the predict function
    input_data = pd.DataFrame([{
        "organic_matter": 5.2,
        "chloride": 3.1,
        "cec": 25.6,
        "h_concentration": 0.5,
        "hn": 12.4,
        "al_concentration": 0.8,
        "free_alumina": 1.2,
        "free_iron": 0.9,
        "delta_ph": -0.2
    }])
    model_name = 'RF_filt'
    Acid_reflux_result = predict(input_data, model_name)
    print("Acid_reflux_result:", Acid_reflux_result)  # prediction result

    # Acid-reduction model prediction: test the predict function
    input_data = pd.DataFrame([{
        "pH": 5.2,
        "OM": 3.1,
        "CL": 25.6,
        "H": 0.5,
        "Al": 12.4
    }])
    model_name = 'rf_model_1214_1008'
    Acid_reduce_result = predict(input_data, model_name)
    print("Acid_reduce_result:", Acid_reduce_result)  # prediction result
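
    # Hypothetical end-to-end training run, left commented out; it assumes a SQLAlchemy
    # session bound to the application's database, which this standalone test block does
    # not configure, and a placeholder database URL and model name.
    # from sqlalchemy import create_engine
    # from sqlalchemy.orm import Session as SASession
    # engine = create_engine('sqlite:///app.db')  # placeholder database URL
    # with SASession(engine) as session:
    #     train_and_save_model(session, 'RandomForest', 'rf_demo', 'demo model', 'reflux')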