瀏覽代碼

api接口

drggboy 8 月之前
父節點
當前提交
31be29e6e5

+ 3 - 0
api/.gitignore

@@ -0,0 +1,3 @@
+app/__pycache__
+.idea
+model_optimize/__pycache__

+ 6 - 0
api/README.md

@@ -0,0 +1,6 @@
+# 环境搭建
+
+根据`environment.yml`文件新建conda环境
+~~~ 
+conda env create -f environment.yml
+~~~

+ 14 - 0
api/app/__init__.py

@@ -0,0 +1,14 @@
+from flask import Flask
+
+# 创建并配置 Flask 应用
+def create_app():
+    app = Flask(__name__)
+    
+    # 进行初始配置,加载配置文件等
+    # app.config.from_object('config.Config')
+
+    # 导入路由
+    from . import routes
+    app.register_blueprint(routes.bp)
+
+    return app

+ 4 - 0
api/app/config.py

@@ -0,0 +1,4 @@
+class Config:
+    SECRET_KEY = 'your_secret_key'
+    DEBUG = True
+    MODEL_PATH = 'D:/suan/Code_suan/RF_filt.pkl'

+ 32 - 0
api/app/model.py

@@ -0,0 +1,32 @@
+import pickle
+import pandas as pd
+
+# 加载模型
+def load_model():
+    with open('D:\suan\Code_suan\model_optimize\pkl\RF_filt.pkl', 'rb') as f:
+        return pickle.load(f)
+
+# 初始化模型
+model = load_model()   # 模型会在model.py被导入时加载一次
+
+# 预测函数
+def predict(input_data: pd.DataFrame):
+    # 假设输入数据已经准备好
+    predictions = model.predict(input_data)
+    return predictions.tolist()
+
+if __name__ == '__main__':
+    # 测试 predict 函数
+    x = pd.DataFrame([{
+        "organic_matter": 5.2,
+        "chloride": 3.1,
+        "cec": 25.6,
+        "h_concentration": 0.5,
+        "hn": 12.4,
+        "al_concentration": 0.8,
+        "free_alumina": 1.2,
+        "free_iron": 0.9,
+        "delta_ph": -0.2
+    }])
+
+    print(predict(x))  # 预测结果

+ 24 - 0
api/app/routes.py

@@ -0,0 +1,24 @@
+from flask import Blueprint, request, jsonify
+from .model import predict
+import pandas as pd
+
+# 创建蓝图 (Blueprint),用于分离路由
+bp = Blueprint('routes', __name__)
+
+# 路由:预测
+@bp.route('/predict', methods=['POST'])
+def predict_route():
+    try:
+        # 从请求中获取数据
+        data = request.get_json()
+
+        # 将数据转为 pandas DataFrame,确保数据列名一致
+        input_data = pd.DataFrame([data])
+
+        # 调用模型进行预测
+        predictions = predict(input_data)
+
+        # 返回预测结果
+        return jsonify({'predictions': predictions}), 200
+    except Exception as e:
+        return jsonify({'error': str(e)}), 400

+ 4 - 0
api/app/utils.py

@@ -0,0 +1,4 @@
+# 工具模块,用于存放一些工具函数:数据预处理、模型评估等
+def preprocess_data(data):
+    # 在此进行数据清理和转换
+    return data

二進制
api/environment.yml


+ 125 - 0
api/model_optimize/RF_filt.py

@@ -0,0 +1,125 @@
+# '''
+#模型筛选
+# '''
+
+## 导入常用基本包
+import os
+import pandas as pd
+import numpy as np
+from PIL import Image
+from model_saver import save_model
+
+# 机器学习模型导入
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.model_selection import cross_val_score,cross_val_predict
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import mean_squared_error
+import numpy as np
+import pandas as pd
+
+## 导入常用辅助函数
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import cross_val_score
+from sklearn.model_selection import cross_val_predict
+
+## 导入数据处理函数
+from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import MinMaxScaler
+
+## 导入评分函数
+from sklearn.metrics import r2_score
+from sklearn.metrics import mean_squared_error 
+from sklearn.metrics import mean_absolute_error
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import log_loss
+from sklearn.metrics import roc_auc_score
+
+
+# 导入数据
+data=pd.read_excel('model_optimize\data\data_filt.xlsx')
+x = data.iloc[:,1:10]
+y = data.iloc[:,-1]
+
+# 为 x 赋予列名
+x.columns = [
+    'organic_matter',        # OM g/kg
+    'chloride',              # CL g/kg
+    'cec',                   # CEC cmol/kg
+    'h_concentration',       # H+ cmol/kg
+    'hn',                    # HN mg/kg
+    'al_concentration',      # Al3+ cmol/kg
+    'free_alumina',          # Free alumina g/kg
+    'free_iron',             # Free iron oxides g/kg
+    'delta_ph'               # ΔpH
+]
+
+y.name = 'target_ph'
+
+Xtrain, Xtest, Ytrain, Ytest=train_test_split(x, y, test_size=0.2)
+
+# 筛选随机种子
+score_5cv_all = []
+for i in range(0, 200, 1):
+    rfc =RandomForestRegressor(random_state=i)
+    score_5cv =cross_val_score(rfc, Xtrain, Ytrain, cv=5).mean()
+    score_5cv_all.append(score_5cv)
+    pass
+score_max_5cv = max(score_5cv_all)
+
+random_state_5cv = range(0, 200)[score_5cv_all.index(max(score_5cv_all))] # 5cv最大得分对应的随机种子
+
+print("最大5cv得分:{}".format(score_max_5cv),
+      "random_5cv:{}".format(random_state_5cv))
+
+
+# 筛选随机树数目
+score_5cv_all = []
+for i in range(1, 400, 1):
+    rfc = RandomForestRegressor(n_estimators=i,
+        random_state=random_state_5cv)
+    score_5cv = cross_val_score(rfc, Xtrain, Ytrain, cv=5).mean()
+    score_5cv_all.append(score_5cv)
+    pass
+score_max_5cv = max(score_5cv_all)
+n_est_5cv = range(1,400)[score_5cv_all.index(score_max_5cv)]    # 5cv最大得分对应的树数目
+
+print("最大5cv得分:{}".format(score_max_5cv),
+      "n_est_5cv:{}".format(n_est_5cv))  # 5cv最大得分对应的树数目??
+score_test_all = []
+
+
+# 筛选最大深度
+score_5cv_all = []
+for i in range(1, 300, 1):
+    rfc = RandomForestRegressor(n_estimators=n_est_5cv
+                                , random_state=random_state_5cv
+                                , max_depth=i)
+    score_5cv = cross_val_score(rfc, Xtrain, Ytrain, cv=5).mean()
+    score_5cv_all.append(score_5cv)
+    pass
+score_max_5cv = max(score_5cv_all)
+max_depth_5cv = range(1,300)[score_5cv_all.index(score_max_5cv)]    
+print(
+      "最大5cv得分:{}".format(score_max_5cv),
+      "max_depth_5cv:{}".format(max_depth_5cv))      
+
+# 确定参数进行训练
+rfc = RandomForestRegressor(n_estimators=n_est_5cv,random_state=random_state_5cv,max_depth=max_depth_5cv)
+CV_score = cross_val_score(rfc, Xtrain, Ytrain, cv=5).mean()
+CV_predictions = cross_val_predict(rfc, Xtrain, Ytrain, cv=5)
+rmse1 = np.sqrt(mean_squared_error(Ytrain,CV_predictions))
+regressor = rfc.fit(Xtrain, Ytrain)
+test_predictions = regressor.predict(Xtest)
+score_test = regressor.score(Xtest,Ytest)
+rmse2 = np.sqrt(mean_squared_error(Ytest,test_predictions))
+print("5cv:",CV_score)
+print("rmse_5CV",rmse1)
+print("test:",score_test)
+print("rmse_test",rmse2)
+
+# 保存训练好的模型
+custom_path='model_optimize\pkl'         # 模型保存路径
+prefix='rf_model_'          # 模型文件名前缀
+save_model(rfc, custom_path, prefix)
+

二進制
api/model_optimize/data/data_filt.xlsx


+ 70 - 0
api/model_optimize/model_predict.py

@@ -0,0 +1,70 @@
+import pickle
+import pandas as pd
+import numpy as np
+from sklearn.metrics import mean_squared_error
+from pathlib import Path
+
+model_path = Path('model_optimize\pkl\RF_filt.pkl')
+# 确保路径存在
+if model_path.exists():
+    with open(model_path, 'rb') as f:
+        rfc = pickle.load(f)
+
+
+# 读取数据
+data_path = Path('model_optimize\data\data_filt.xlsx')
+data=pd.read_excel(data_path)
+
+
+x = data.iloc[:,1:10]
+y = data.iloc[:,-1]
+
+# 转换列名
+x.columns = [
+    'organic_matter',        # OM g/kg
+    'chloride',              # CL g/kg
+    'cec',                   # CEC cmol/kg
+    'h_concentration',       # H+ cmol/kg
+    'hn',                    # HN mg/kg
+    'al_concentration',      # Al3+ cmol/kg
+    'free_alumina',          # Free alumina g/kg
+    'free_iron',             # Free iron oxides g/kg
+    'delta_ph'               # ΔpH
+]
+
+
+# 预测
+y_pred = rfc.predict(x)
+
+# y 与 y_pred 的对比
+print('y:',y)
+print('y_pred:',y_pred)
+# 计算预测误差
+errors = y - y_pred
+
+# 图示
+import matplotlib.pyplot as plt
+
+# 绘制散点图
+plt.figure(figsize=(10, 6))
+plt.scatter(y, y_pred, color='blue', label='Predictions', alpha=0.5)
+plt.plot([y.min(), y.max()], [y.min(), y.max()], color='red', lw=2, label='Perfect fit')  # 理想的完美拟合线
+plt.xlabel('True Values')
+plt.ylabel('Predicted Values')
+plt.title('True vs Predicted Values')
+plt.legend()
+plt.show()
+
+# 绘制误差的直方图
+plt.figure(figsize=(10, 6))
+plt.hist(errors, bins=20, edgecolor='black', color='lightblue')
+plt.axvline(x=0, color='red', linestyle='--', lw=2, label='Zero Error Line')  # 添加零误差线
+plt.xlabel('Prediction Error')
+plt.ylabel('Frequency')
+plt.title('Distribution of Prediction Errors')
+plt.legend()
+plt.show()
+
+# 评分
+rmse = np.sqrt(mean_squared_error(y,y_pred))
+print("rmse",rmse)

+ 26 - 0
api/model_optimize/model_saver.py

@@ -0,0 +1,26 @@
+import pickle
+import datetime
+import os
+
+def save_model(model, custom_path='D:/suan/Code_suan/', prefix='my_model_'):
+    """
+    将模型保存为一个文件,文件名包括时间戳,防止覆盖。
+    
+    :param model: 训练好的模型(例如 RandomForestRegressor)
+    :param custom_path: 保存模型的路径,默认是 'D:/suan/Code_suan/'
+    :param prefix: 文件名前缀,默认是 'my_model_'
+    """
+    # 确保路径存在
+    os.makedirs(custom_path, exist_ok=True)
+    
+    # 获取当前时间戳(格式:月日时分)
+    timestamp = datetime.datetime.now().strftime('%m%d_%H%M')
+
+    # 拼接完整的文件名
+    file_name = os.path.join(custom_path, f'{prefix}{timestamp}.pkl')
+
+    # 保存模型
+    with open(file_name, 'wb') as f:
+        pickle.dump(model, f)
+
+    print(f"模型已保存为: {file_name}")

二進制
api/model_optimize/pkl/RF_filt.pkl


+ 8 - 0
api/run.py

@@ -0,0 +1,8 @@
+from app import create_app
+
+# 创建 Flask 应用
+app = create_app()
+
+# 启动服务器
+if __name__ == '__main__':
+    app.run(debug=True)

+ 1 - 0
backend

@@ -0,0 +1 @@
+Subproject commit f03205c58e724debbb130d82ade7f58a0e1458ab