drggboy 5 місяців тому
батько
коміт
bb181d4142
2 змінених файлів з 27 додано та 2 видалено
  1. 27 0
      api/app/model.py
  2. 0 2
      api/app/routes.py

+ 27 - 0
api/app/model.py

@@ -17,6 +17,33 @@ def predict(input_data: pd.DataFrame, model_name):
     return predictions.tolist()
 
 
+def train_and_save_model(dataset_id, model_type, model_name, model_description):
+    dataset = get_dataset_by_id(dataset_id)
+    if dataset.empty:
+        raise ValueError(f"Dataset {dataset_id} is empty or not found.")
+
+    # Step 1: 数据准备
+    X = dataset.iloc[:, :-1]  # 特征数据
+    y = dataset.iloc[:, -1]  # 目标变量
+
+    # Step 2: 训练模型
+    model = train_model_by_type(X, y, model_type)
+
+    # Step 3: 保存模型到数据库
+    # 使用提供的 model_name 和 model_description
+    saved_model = save_model(model_name, model_type, model_description)
+
+    # Step 4: 保存模型参数
+    save_model_parameters(model, saved_model.ModelID)
+
+    # Step 5: 计算评估指标(比如MSE)
+    y_pred = model.predict(X)
+    mse = mean_squared_error(y, y_pred)
+
+    return saved_model, mse
+
+
+
 if __name__ == '__main__':
     # 反酸模型预测
     # 测试 predict 函数

+ 0 - 2
api/app/routes.py

@@ -12,8 +12,6 @@ from sqlalchemy.orm import sessionmaker
 
 # 创建蓝图 (Blueprint),用于分离路由
 bp = Blueprint('routes', __name__)
-DATABASE = 'SoilAcidification.db'
-
 
 @bp.route('/upload-dataset', methods=['POST'])
 def upload_dataset():