model.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import pickle
  2. import pandas as pd
# Load a model
  4. def load_model(model_name):
  5. file_path = f'model_optimize/pkl/{model_name}.pkl'
  6. with open(file_path, 'rb') as f:
  7. return pickle.load(f)
# Run a model prediction
  9. def predict(input_data: pd.DataFrame, model_name):
  10. # 初始化模型
  11. model = load_model(model_name) # 根据指定的模型名加载模型
  12. predictions = model.predict(input_data)
  13. return predictions.tolist()
  14. if __name__ == '__main__':
  15. # 反酸模型预测
  16. # 测试 predict 函数
  17. input_data = pd.DataFrame([{
  18. "organic_matter": 5.2,
  19. "chloride": 3.1,
  20. "cec": 25.6,
  21. "h_concentration": 0.5,
  22. "hn": 12.4,
  23. "al_concentration": 0.8,
  24. "free_alumina": 1.2,
  25. "free_iron": 0.9,
  26. "delta_ph": -0.2
  27. }])
  28. model_name = 'RF_filt'
  29. Acid_reflux_result = predict(input_data, model_name)
  30. print("Acid_reflux_result:", Acid_reflux_result) # 预测结果
  31. # 降酸模型预测
  32. # 测试 predict 函数
  33. input_data = pd.DataFrame([{
  34. "pH": 5.2,
  35. "OM": 3.1,
  36. "CL": 25.6,
  37. "H": 0.5,
  38. "Al": 12.4
  39. }])
  40. model_name = 'rf_model_1214_1008'
  41. Acid_reduce_result = predict(input_data, model_name)
  42. print("Acid_reduce_result:", Acid_reduce_result) # 预测结果