Browse Source

降酸模型文件导入与接口优化

drggboy 5 months ago
parent
commit
fa852071cc

+ 28 - 12
api/app/model.py

@@ -3,25 +3,23 @@ import pandas as pd
 
 
 # 加载模型
-def load_model():
-    with open('model_optimize/pkl/RF_filt.pkl', 'rb') as f:
+def load_model(model_name='RF_filt'):
+    file_path = f'model_optimize/pkl/{model_name}.pkl'
+    with open(file_path, 'rb') as f:
         return pickle.load(f)
 
-
-# 初始化模型
-model = load_model()   # 模型会在model.py被导入时加载一次
-
-
-# 预测函数
-def predict(input_data: pd.DataFrame):
-    # 假设输入数据已经准备好
+# 模型预测
+def predict(input_data: pd.DataFrame, model_name='RF_filt'):
+    # 初始化模型
+    model = load_model(model_name)  # 根据指定的模型名加载模型
     predictions = model.predict(input_data)
     return predictions.tolist()
 
 
 if __name__ == '__main__':
+    # 反酸模型预测
     # 测试 predict 函数
-    x = pd.DataFrame([{
+    input_data = pd.DataFrame([{
         "organic_matter": 5.2,
         "chloride": 3.1,
         "cec": 25.6,
@@ -33,4 +31,22 @@ if __name__ == '__main__':
         "delta_ph": -0.2
     }])
 
-    print(predict(x))  # 预测结果
+    model_name = 'RF_filt'
+
+    Acid_reflux_result = predict(input_data, model_name)
+    print("Acid_reflux_result:", Acid_reflux_result)  # 预测结果
+
+
+    # 降酸模型预测
+    # 测试 predict 函数
+    input_data = pd.DataFrame([{
+        "pH": 5.2,
+        "OM": 3.1,
+        "CL": 25.6,
+        "H": 0.5,
+        "Al": 12.4
+    }])
+
+    model_name = 'rf_model_1214_1008'
+    Acid_reduce_result = predict(input_data, model_name)
+    print("Acid_reduce_result:", Acid_reduce_result)  # 预测结果

BIN
api/model_optimize/pkl/rf_model_1207_1530.pkl


BIN
api/model_optimize/pkl/rf_model_1214_1008.pkl