# model.py - load pickled models, run predictions, and train/save new models.
  1. import pickle
  2. import pandas as pd
  3. # 加载模型
  4. def load_model(model_name):
  5. file_path = f'model_optimize/pkl/{model_name}.pkl'
  6. with open(file_path, 'rb') as f:
  7. return pickle.load(f)
  8. # 模型预测
  9. def predict(input_data: pd.DataFrame, model_name):
  10. # 初始化模型
  11. model = load_model(model_name) # 根据指定的模型名加载模型
  12. predictions = model.predict(input_data)
  13. return predictions.tolist()
  14. def train_and_save_model(dataset_id, model_type, model_name, model_description):
  15. dataset = get_dataset_by_id(dataset_id)
  16. if dataset.empty:
  17. raise ValueError(f"Dataset {dataset_id} is empty or not found.")
  18. # Step 1: 数据准备
  19. X = dataset.iloc[:, :-1] # 特征数据
  20. y = dataset.iloc[:, -1] # 目标变量
  21. # Step 2: 训练模型
  22. model = train_model_by_type(X, y, model_type)
  23. # Step 3: 保存模型到数据库
  24. # 使用提供的 model_name 和 model_description
  25. saved_model = save_model(model_name, model_type, model_description)
  26. # Step 4: 保存模型参数
  27. save_model_parameters(model, saved_model.ModelID)
  28. # Step 5: 计算评估指标(比如MSE)
  29. y_pred = model.predict(X)
  30. mse = mean_squared_error(y, y_pred)
  31. return saved_model, mse
  32. if __name__ == '__main__':
  33. # 反酸模型预测
  34. # 测试 predict 函数
  35. input_data = pd.DataFrame([{
  36. "organic_matter": 5.2,
  37. "chloride": 3.1,
  38. "cec": 25.6,
  39. "h_concentration": 0.5,
  40. "hn": 12.4,
  41. "al_concentration": 0.8,
  42. "free_alumina": 1.2,
  43. "free_iron": 0.9,
  44. "delta_ph": -0.2
  45. }])
  46. model_name = 'RF_filt'
  47. Acid_reflux_result = predict(input_data, model_name)
  48. print("Acid_reflux_result:", Acid_reflux_result) # 预测结果
  49. # 降酸模型预测
  50. # 测试 predict 函数
  51. input_data = pd.DataFrame([{
  52. "pH": 5.2,
  53. "OM": 3.1,
  54. "CL": 25.6,
  55. "H": 0.5,
  56. "Al": 12.4
  57. }])
  58. model_name = 'rf_model_1214_1008'
  59. Acid_reduce_result = predict(input_data, model_name)
  60. print("Acid_reduce_result:", Acid_reduce_result) # 预测结果