"""Model training, persistence and prediction helpers.

Provides utilities to snapshot a "current" data table into a per-dataset
table, train a regressor on that snapshot, pickle the fitted model to disk,
register it in the database, and later load it to run predictions.
"""
import datetime
import os
import pickle

import pandas as pd
from flask_sqlalchemy.session import Session
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split, cross_val_score
from sqlalchemy import text

from .database_models import Models, Datasets


def load_model(model_name):
    """Load a pickled model from ``model_optimize/pkl/<model_name>.pkl``.

    The path is relative to the current working directory.

    Args:
        model_name: File name of the model, without the ``.pkl`` extension.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    # SECURITY: pickle.load executes arbitrary code embedded in the file, and
    # model_name is interpolated into the path unchecked. Only call this with
    # trusted model names that refer to files written by this application.
    # NOTE(review): save_model() writes to 'pkl/' by default while this reads
    # from 'model_optimize/pkl/' -- confirm the two resolve to the same
    # directory in deployment.
    file_path = f'model_optimize/pkl/{model_name}.pkl'
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def predict(input_data: pd.DataFrame, model_name):
    """Run a stored model on ``input_data``.

    Args:
        input_data: Feature rows to predict on; passed unchanged to the
            model's ``predict`` method.
        model_name: Name of the pickled model to load (see ``load_model``).

    Returns:
        The predictions converted to a plain Python list via ``tolist()``.
    """
    model = load_model(model_name)
    predictions = model.predict(input_data)
    return predictions.tolist()


def train_and_save_model(session, model_type, model_name, model_description, data_type, dataset_id=None):
    """Train a model on a dataset table and persist it.

    When ``dataset_id`` is falsy, the current data for ``data_type`` is first
    snapshotted into a new ``dataset_<id>`` table. Every column except the
    last is used as a feature and the last column is used as the target.

    Args:
        session: SQLAlchemy session; ``session.bind`` is used to read the
            dataset table.
        model_type: Model family key understood by ``train_model_by_type``.
        model_name: Name stored on the model record.
        model_description: Description stored on the model record.
        data_type: Dataset type, ``'reduce'`` or ``'reflux'``.
        dataset_id: Existing dataset to train on; a new snapshot is created
            when omitted.

    Returns:
        The path of the pickled model file, as returned by ``save_model``.

    Raises:
        ValueError: If the dataset table is empty, or ``data_type`` /
            ``model_type`` is not supported.
        NotImplementedError: If ``model_type`` names a trainer that has not
            been implemented yet.
    """
    if not dataset_id:
        # No dataset given: snapshot the current data into a new table.
        dataset_id = save_current_dataset(session, data_type)

    dataset_table_name = f"dataset_{dataset_id}"
    dataset = pd.read_sql_table(dataset_table_name, session.bind)

    if dataset.empty:
        raise ValueError(f"Dataset {dataset_id} is empty or not found.")

    # Features are all columns but the last; the last column is the target.
    X = dataset.iloc[:, :-1]
    y = dataset.iloc[:, -1]

    model = train_model_by_type(X, y, model_type)

    # TODO: persist model parameters and report an evaluation metric
    # (e.g. MSE) alongside the saved model.
    return save_model(session, model, model_name, model_type, model_description, dataset_id, data_type)


def save_current_dataset(session, data_type):
    """Create a dataset record and copy the matching source table into it.

    Args:
        session (Session): SQLAlchemy session object.
        data_type (str): Dataset type, ``'reduce'`` or ``'reflux'``.

    Returns:
        int: ID of the newly saved dataset.

    Raises:
        ValueError: If ``data_type`` is not a known type.
    """
    # Resolve the source table first so an invalid data_type fails before
    # anything is written to the session.
    source_table = data_type_table_mapping(data_type)

    new_dataset = Datasets(
        # The timestamp keeps generated names distinct.
        Dataset_name=f"{data_type}_dataset_{datetime.datetime.now():%Y%m%d_%H%M%S}",
        Dataset_description=f"Automatically generated dataset for type {data_type}",
        Row_count=0,  # updated below once the data has been copied
        Status='pending',
        Dataset_type=data_type
    )

    try:
        session.add(new_dataset)
        # flush() emits the INSERT so the ID is available without committing.
        session.flush()
        dataset_id = new_dataset.Dataset_ID

        # Table names cannot be bound parameters. source_table comes from the
        # fixed mapping above and dataset_id is generated by the database, so
        # neither carries caller-controlled text.
        new_table_name = f"dataset_{dataset_id}"
        copy_table_sql = f"CREATE TABLE {new_table_name} AS SELECT * FROM {source_table};"
        session.execute(text(copy_table_sql))

        update_sql = (
            "UPDATE datasets SET status='processed', "
            f"row_count=(SELECT count(*) FROM {new_table_name}) "
            "WHERE dataset_id=:dataset_id;"
        )
        session.execute(text(update_sql), {"dataset_id": dataset_id})
        session.commit()
    except Exception:
        # Leave the session usable for the caller, then propagate the error.
        session.rollback()
        raise

    return dataset_id


def data_type_table_mapping(data_type):
    """Map a dataset type to the name of its source database table.

    Raises:
        ValueError: If ``data_type`` is neither ``'reduce'`` nor ``'reflux'``.
    """
    if data_type == 'reduce':
        return 'current_reduce'
    elif data_type == 'reflux':
        return 'current_reflux'
    else:
        raise ValueError("Invalid data type provided.")


def train_model_by_type(X, y, model_type):
    """Split the data 80/20 and train the model family named by ``model_type``.

    Args:
        X: Feature matrix.
        y: Target values.
        model_type: ``'RandomForest'``, ``'XGBR'`` or ``'GBSTR'``.

    Returns:
        The fitted model.

    Raises:
        ValueError: If ``model_type`` is not one of the supported keys.
        NotImplementedError: For ``'XGBR'`` and ``'GBSTR'``, whose trainers
            are not implemented yet.
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    if model_type == 'RandomForest':
        return train_random_forest(X_train, y_train)
    elif model_type == 'XGBR':
        return train_xgboost(X_train, y_train, X_test, y_test)
    elif model_type == 'GBSTR':
        return train_gradient_boosting(X_train, y_train, X_test, y_test)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def train_random_forest(X_train, y_train):
    """Tune and fit a ``RandomForestRegressor`` with 5-fold cross-validation.

    First selects the number of trees (1-19) with the best mean CV score,
    then, with that tree count fixed, selects the best ``max_depth`` (1-29),
    and finally fits a model with both settings on the full training data.

    Args:
        X_train: Training features.
        y_train: Training targets.

    Returns:
        The fitted ``RandomForestRegressor``.

    Raises:
        ValueError: If no tree count produced a comparable CV score.
    """
    random_state = 43

    # CV scores can be zero or negative, so start below every real score;
    # starting at 0 would leave the selection unset for a poorly fitting
    # dataset.
    best_score = float('-inf')
    best_n_estimators = None
    for n_estimators in range(1, 20):
        model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)
        score = cross_val_score(model, X_train, y_train, cv=5).mean()
        if score > best_score:
            best_score = score
            best_n_estimators = n_estimators

    if best_n_estimators is None:
        # Only reachable when every mean score was NaN.
        raise ValueError("Could not select n_estimators: no valid cross-validation score.")

    print(f"Best number of trees: {best_n_estimators}, Score: {best_score}")

    # Reset the best score before tuning the depth.
    best_score = float('-inf')
    best_max_depth = None  # None means no depth limit if nothing is selected
    for max_depth in range(1, 30):
        model = RandomForestRegressor(n_estimators=best_n_estimators, max_depth=max_depth,
                                      random_state=random_state)
        score = cross_val_score(model, X_train, y_train, cv=5).mean()
        if score > best_score:
            best_score = score
            best_max_depth = max_depth

    print(f"Best max depth: {best_max_depth}, Score: {best_score}")

    best_model = RandomForestRegressor(n_estimators=best_n_estimators, max_depth=best_max_depth,
                                       random_state=random_state)
    best_model.fit(X_train, y_train)

    return best_model


def train_xgboost(X_train, y_train, X_test=None, y_test=None):
    """Placeholder for XGBoost training.

    Raises:
        NotImplementedError: Always; failing loudly prevents a ``None``
            "model" from being pickled and registered in the database.
    """
    # TODO: add a tuning/training procedure similar to train_random_forest.
    raise NotImplementedError("XGBoost training is not implemented yet.")


def train_gradient_boosting(X_train, y_train, X_test=None, y_test=None):
    """Placeholder for gradient-boosted-tree training.

    Raises:
        NotImplementedError: Always; failing loudly prevents a ``None``
            "model" from being pickled and registered in the database.
    """
    # TODO: add a tuning/training procedure similar to train_random_forest.
    raise NotImplementedError("Gradient boosting training is not implemented yet.")


def save_model(session, model, model_name, model_type, model_description, dataset_id, data_type, custom_path='pkl'):
    """Pickle the model to disk and register it in the database.

    :param session: database session
    :param model: model object to save
    :param model_name: name of the model
    :param model_type: type of the model; selects the file-name prefix
    :param model_description: description of the model
    :param dataset_id: ID of the dataset the model was trained on
    :param data_type: dataset type stored on the model record
    :param custom_path: directory the model file is written to
    :return: path of the saved model file
    """
    # File-name prefix per model type. Both the short keys used by
    # train_model_by_type ('XGBR', 'GBSTR') and the longer spellings are
    # accepted so the prefix matches whichever key the caller passes.
    prefix_dict = {
        'RandomForest': 'rf_model_',
        'XGBR': 'xgbr_model_',
        'XGBRegressor': 'xgbr_model_',
        'GBSTR': 'gbstr_model_',
        'GBSTRegressor': 'gbstr_model_'
    }
    prefix = prefix_dict.get(model_type, 'default_model_')  # fallback for unknown types

    try:
        os.makedirs(custom_path, exist_ok=True)

        # NOTE: the timestamp has minute resolution and no year (MMDD_HHMM),
        # so two saves of the same model type within one minute write to the
        # same file.
        timestamp = datetime.datetime.now().strftime('%m%d_%H%M')
        file_name = os.path.join(custom_path, f'{prefix}{timestamp}.pkl')

        with open(file_name, 'wb') as f:
            pickle.dump(model, f)

        print(f"模型已保存为: {file_name}")

        new_model = Models(
            Model_name=model_name,
            Model_type=model_type,
            Description=model_description,
            DatasetID=dataset_id,
            Created_at=datetime.datetime.now(),
            ModelFilePath=file_name,
            Data_type=data_type
        )

        session.add(new_model)
        session.commit()

        return file_name

    except Exception as e:
        session.rollback()
        print(f"Error saving model: {str(e)}")
        raise  # re-raise with the original traceback for the caller to handle


def _run_prediction_demo():
    """Smoke-test ``predict`` against two previously saved models."""
    # Acid-reflux model prediction.
    input_data = pd.DataFrame([{
        "organic_matter": 5.2,
        "chloride": 3.1,
        "cec": 25.6,
        "h_concentration": 0.5,
        "hn": 12.4,
        "al_concentration": 0.8,
        "free_alumina": 1.2,
        "free_iron": 0.9,
        "delta_ph": -0.2
    }])
    model_name = 'RF_filt'
    acid_reflux_result = predict(input_data, model_name)
    print("Acid_reflux_result:", acid_reflux_result)

    # Acid-reduction model prediction.
    input_data = pd.DataFrame([{
        "pH": 5.2,
        "OM": 3.1,
        "CL": 25.6,
        "H": 0.5,
        "Al": 12.4
    }])
    model_name = 'rf_model_1214_1008'
    acid_reduce_result = predict(input_data, model_name)
    print("Acid_reduce_result:", acid_reduce_result)


if __name__ == '__main__':
    _run_prediction_demo()