
Merge branch 'master' of http://121.37.22.236:13000/Ding/AcidificationModel

# Conflicts:
#	api/SoilAcidification.db
yangtaodemon, 8 months ago
parent
commit eb3975ea4f
7 changed files with 337 additions and 19 deletions
  1. api/SoilAcidification.db (BIN)
  2. api/app/__init__.py (+8 -2)
  3. api/app/config.py (+4 -3)
  4. api/app/database_models.py (+48 -0)
  5. api/app/model.py (+27 -0)
  6. api/app/routes.py (+193 -10)
  7. api/app/utils.py (+57 -4)

BIN
api/SoilAcidification.db


+ 8 - 2
api/app/__init__.py

@@ -1,16 +1,22 @@
 from flask import Flask
 from flask_cors import CORS
+from . import config
+from flask_sqlalchemy import SQLAlchemy
 
+# Create a global SQLAlchemy instance
+db = SQLAlchemy()
 
 # Create and configure the Flask application
 def create_app():
     app = Flask(__name__)
     CORS(app)
     # Initial setup: load the configuration object, etc.
-    # app.config.from_object('config.Config')
+    app.config.from_object(config.Configs)
+
+    # Initialize SQLAlchemy
+    db.init_app(app)
 
     # Import routes
     from . import routes
     app.register_blueprint(routes.bp)
-
     return app
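
With db now created at module level and initialized inside the factory, a minimal entry point is enough to run the app. A sketch, assuming a run.py placed next to the app package (the file name, host, and port are assumptions, not part of this commit):

# run.py -- hypothetical entry point using the application factory
from app import create_app, db

app = create_app()

if __name__ == '__main__':
    with app.app_context():
        db.create_all()  # create any missing tables for the registered models
    app.run(host='0.0.0.0', port=5000, debug=True)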

+ 4 - 3
api/app/config.py

@@ -1,8 +1,9 @@
 import os
 
-class Config:
+class Configs:
     SECRET_KEY = 'your_secret_key'
     DEBUG = True
     MODEL_PATH = 'model_optimize/pkl/RF_filt.pkl'
-    DATABASE = os.path.join(os.path.dirname(__file__), 'SoilAcidification.db')
-    UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), 'uploads')
+    DATABASE = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'SoilAcidification.db')
+    SQLALCHEMY_DATABASE_URI = f'sqlite:///{DATABASE}'
+    UPLOAD_FOLDER = 'uploads/datasets'
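
Note that UPLOAD_FOLDER is now the relative path 'uploads/datasets', so it resolves against the working directory of the process and may not exist when the first upload arrives. A small sketch of guarding against that at startup (the helper name is made up; adding it is a suggestion, not part of the commit):

import os
from flask import Flask

def ensure_upload_folder(app: Flask):
    # create the configured upload directory (and parents) if it does not exist yet
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)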

+ 48 - 0
api/app/database_models.py

@@ -0,0 +1,48 @@
+from . import db
+from datetime import datetime
+
+class Model(db.Model):
+    __tablename__ = 'Models'
+
+    ModelID = db.Column(db.Integer, primary_key=True)
+    ModelName = db.Column(db.Text, nullable=False)
+    ModelType = db.Column(db.Text, nullable=False)
+    CreatedAt = db.Column(db.TIMESTAMP, default=db.func.current_timestamp())
+    Description = db.Column(db.Text)
+    
+
+class ModelParameters(db.Model):
+    __tablename__ = 'ModelParameters'  # explicit table name
+
+    ParamID = db.Column(db.Integer, primary_key=True, autoincrement=True)  # primary key
+    ModelID = db.Column(db.Integer, db.ForeignKey('Models.ModelID'), nullable=False)  # foreign key referencing Models.ModelID
+    ParamName = db.Column(db.Text, nullable=False)  # parameter name
+    ParamValue = db.Column(db.Text, nullable=False)  # parameter value
+
+    # Reverse relationship back to Model
+    model = db.relationship('Model', backref=db.backref('parameters', lazy=True))
+
+class Dataset(db.Model):
+    __tablename__ = 'Datasets'
+
+    DatasetID = db.Column(db.Integer, primary_key=True)  # dataset ID
+    DatasetName = db.Column(db.String(255), nullable=False)  # dataset name
+    DatasetDescription = db.Column(db.Text, nullable=True)  # dataset description (optional)
+    UploadedAt = db.Column(db.TIMESTAMP, default=datetime.utcnow)  # upload time
+    RowCount = db.Column(db.Integer, nullable=False)  # number of rows in the dataset
+    Status = db.Column(db.String(50), default='pending')  # dataset status (pending, processed)
+    DatasetType = db.Column(db.String(50), nullable=False)  # dataset type (e.g. re-acidification model training, acid-reduction model training)
+
+    def __repr__(self):
+        return f'<Dataset {self.DatasetName}>'
+
+    def to_dict(self):
+        return {
+            'DatasetID': self.DatasetID,
+            'DatasetName': self.DatasetName,
+            'DatasetDescription': self.DatasetDescription,
+            'UploadedAt': self.UploadedAt.strftime('%Y-%m-%d %H:%M:%S'),
+            'RowCount': self.RowCount,
+            'Status': self.Status,
+            'DatasetType': self.DatasetType
+        }
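
The to_dict helper makes serializing query results straightforward. A minimal usage sketch for a listing endpoint in routes.py (the /datasets route itself is an assumption; this commit does not add it):

from flask import jsonify
from .database_models import Dataset

@bp.route('/datasets', methods=['GET'])
def list_datasets():
    # return every dataset as JSON via the to_dict serializer
    datasets = Dataset.query.all()
    return jsonify([d.to_dict() for d in datasets])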

+ 27 - 0
api/app/model.py

@@ -17,6 +17,33 @@ def predict(input_data: pd.DataFrame, model_name):
     return predictions.tolist()
 
 
+def train_and_save_model(dataset_id, model_type, model_name, model_description):
+    dataset = get_dataset_by_id(dataset_id)
+    if dataset.empty:
+        raise ValueError(f"Dataset {dataset_id} is empty or not found.")
+
+    # Step 1: prepare the data
+    X = dataset.iloc[:, :-1]  # feature columns
+    y = dataset.iloc[:, -1]  # target variable
+
+    # Step 2: train the model
+    model = train_model_by_type(X, y, model_type)
+
+    # Step 3: save the model to the database
+    # using the provided model_name and model_description
+    saved_model = save_model(model_name, model_type, model_description)
+
+    # Step 4: save the model parameters
+    save_model_parameters(model, saved_model.ModelID)
+
+    # Step 5: compute evaluation metrics (e.g. MSE)
+    y_pred = model.predict(X)
+    mse = mean_squared_error(y, y_pred)
+
+    return saved_model, mse
+
+
+
 if __name__ == '__main__':
     # Re-acidification model prediction
     # Test the predict function
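
train_and_save_model relies on helpers that this hunk does not show (get_dataset_by_id, train_model_by_type, save_model, save_model_parameters) as well as sklearn's mean_squared_error. A sketch of what train_model_by_type could look like; the supported type names and hyperparameters below are assumptions, not taken from the repository:

from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

def train_model_by_type(X, y, model_type):
    # pick an estimator for the requested type and fit it on the full dataset
    estimators = {
        'RandomForest': RandomForestRegressor(n_estimators=100, random_state=42),
        'GradientBoosting': GradientBoostingRegressor(random_state=42),
    }
    if model_type not in estimators:
        raise ValueError(f"Unsupported model type: {model_type}")
    model = estimators[model_type]
    model.fit(X, y)
    return model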

+ 193 - 10
api/app/routes.py

@@ -1,21 +1,204 @@
 import sqlite3
 
-from flask_sqlalchemy import SQLAlchemy
-from flask import Blueprint, request, jsonify, g, current_app
-from .model import predict
+from flask import Blueprint, request, jsonify, current_app
+from .model import predict, train_and_save_model
 import pandas as pd
-
+from . import db  # import the db instance from the app package
+from sqlalchemy.engine.reflection import Inspector
+from .database_models import Model, ModelParameters, Dataset
+import os
+from .utils import create_dynamic_table, allowed_file
+from sqlalchemy.orm import sessionmaker
 
 # Create a Blueprint to keep the routes separate
 bp = Blueprint('routes', __name__)
-DATABASE = 'SoilAcidification.db'
 
+@bp.route('/upload-dataset', methods=['POST'])
+def upload_dataset():
+    try:
+        # Check that the request contains a file part
+        if 'file' not in request.files:
+            return jsonify({'error': 'No file part'}), 400
+        file = request.files['file']
+
+        # Reject if no file was selected (empty filename)
+        if file.filename == '':
+            return jsonify({'error': 'No selected file'}), 400
+
+        # Check that the file type is allowed
+        if file and allowed_file(file.filename):
+            # Read the dataset metadata from the form
+            dataset_name = request.form.get('dataset_name')
+            dataset_description = request.form.get('dataset_description', 'No description provided')
+            dataset_type = request.form.get('dataset_type')  # new field: dataset type
+
+            # Validate that dataset_type is present
+            if not dataset_type:
+                return jsonify({'error': 'Dataset type is required'}), 400
+
+            # Create the Dataset entity and save it to the database
+            new_dataset = Dataset(
+                DatasetName=dataset_name,
+                DatasetDescription=dataset_description,
+                RowCount=0,  # row count starts at 0 when the dataset record is first created
+                Status='pending',  # status defaults to 'pending'
+                DatasetType=dataset_type  # store the dataset type
+            )
+            db.session.add(new_dataset)
+            db.session.commit()
+
+            # Get the dataset ID
+            dataset_id = new_dataset.DatasetID
+
+            # Use the database DatasetID as the file name when saving
+            unique_filename = f"dataset_{dataset_id}.xlsx"
+            upload_folder = current_app.config['UPLOAD_FOLDER']
+            file_path = os.path.join(upload_folder, unique_filename)
+
+            # Save the file
+            file.save(file_path)
+
+            # Read the Excel file contents
+            dataset_df = pd.read_excel(file_path)
+
+            # Update the dataset's row count
+            row_count = len(dataset_df)
+            new_dataset.RowCount = row_count
+            new_dataset.Status = 'processed'  # mark the dataset as processed
+            db.session.commit()
+
+            # Dynamically create a data table for this dataset
+            columns = {}
+            for col in dataset_df.columns:
+                if dataset_df[col].dtype == 'int64':
+                    columns[col] = 'int'
+                elif dataset_df[col].dtype == 'float64':
+                    columns[col] = 'float'
+                else:
+                    columns[col] = 'str'
+
+            # Create the new (dynamic) table
+            dynamic_table_class = create_dynamic_table(dataset_id, columns)
+
+            # Open a new database session
+            Session = sessionmaker(bind=db.engine)
+            session = Session()
+
+            # Insert each row into the dynamically created table
+            for _, row in dataset_df.iterrows():
+                record_data = row.to_dict()
+                # Insert the row into the new table
+                session.execute(dynamic_table_class.__table__.insert(), [record_data])
+
+            session.commit()
+            session.close()
+
+            return jsonify({
+                'message': f'Dataset {dataset_name} uploaded successfully!',
+                'dataset_id': new_dataset.DatasetID,
+                'filename': unique_filename
+            }), 201
+
+        else:
+            return jsonify({'error': 'Invalid file type'}), 400
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+
+@bp.route('/tables', methods=['GET'])
+def list_tables():
+    engine = db.engine  # use the engine from the db instance
+    inspector = Inspector.from_engine(engine)  # create an Inspector object
+    table_names = inspector.get_table_names()  # get all table names
+    return jsonify(table_names)  # return the table names as a JSON list
+
+@bp.route('/models/<int:model_id>', methods=['GET'])
+def get_model(model_id):
+    try:
+        model = Model.query.filter_by(ModelID=model_id).first()
+        if model:
+            return jsonify({
+                'ModelID': model.ModelID,
+                'ModelName': model.ModelName,
+                'ModelType': model.ModelType,
+                'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
+                'Description': model.Description
+            })
+        else:
+            return jsonify({'message': 'Model not found'}), 404
+    except Exception as e:
+        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
+
+@bp.route('/models', methods=['GET'])
+def get_all_models():
+    try:
+        models = Model.query.all()  # fetch all model records
+        if models:
+            result = [
+                {
+                    'ModelID': model.ModelID,
+                    'ModelName': model.ModelName,
+                    'ModelType': model.ModelType,
+                    'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
+                    'Description': model.Description
+                }
+                for model in models
+            ]
+            return jsonify(result)
+        else:
+            return jsonify({'message': 'No models found'}), 404
+    except Exception as e:
+        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
 
-def get_db():
-    db = getattr(g, '_database', None)
-    if db is None:
-        db = g._database = sqlite3.connect(DATABASE)
-    return db
+@bp.route('/model-parameters', methods=['GET'])
+def get_all_model_parameters():
+    try:
+        parameters = ModelParameters.query.all()  # fetch all parameter records
+        if parameters:
+            result = [
+                {
+                    'ParamID': param.ParamID,
+                    'ModelID': param.ModelID,
+                    'ParamName': param.ParamName,
+                    'ParamValue': param.ParamValue
+                }
+                for param in parameters
+            ]
+            return jsonify(result)
+        else:
+            return jsonify({'message': 'No parameters found'}), 404
+    except Exception as e:
+        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
+
+@bp.route('/models/<int:model_id>/parameters', methods=['GET'])
+def get_model_parameters(model_id):
+    try:
+        model = Model.query.filter_by(ModelID=model_id).first()
+        if model:
+            # Collect all parameters for this model
+            parameters = [
+                {
+                    'ParamID': param.ParamID,
+                    'ParamName': param.ParamName,
+                    'ParamValue': param.ParamValue
+                }
+                for param in model.parameters
+            ]
+            
+            # Return the model together with its parameters
+            return jsonify({
+                'ModelID': model.ModelID,
+                'ModelName': model.ModelName,
+                'ModelType': model.ModelType,
+                'CreatedAt': model.CreatedAt.strftime('%Y-%m-%d %H:%M:%S'),
+                'Description': model.Description,
+                'Parameters': parameters
+            })
+        else:
+            return jsonify({'message': 'Model not found'}), 404
+    except Exception as e:
+        return jsonify({'error': 'Internal server error', 'message': str(e)}), 500
 
 
 @bp.route('/predict', methods=['POST'])
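
A quick client-side check of the new /upload-dataset endpoint using the requests library; the host, port, file name, and dataset_type value below are assumptions:

import requests

# multipart upload: the Excel file plus the form fields the route expects
with open('training_data.xlsx', 'rb') as f:
    resp = requests.post(
        'http://127.0.0.1:5000/upload-dataset',
        files={'file': ('training_data.xlsx', f)},
        data={
            'dataset_name': 'demo dataset',
            'dataset_description': 'smoke test upload',
            'dataset_type': 'reflux',  # required field; the value here is just an example
        },
    )
print(resp.status_code, resp.json())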

+ 57 - 4
api/app/utils.py

@@ -1,4 +1,57 @@
-# Utility module for helper functions: data preprocessing, model evaluation, etc.
-def preprocess_data(data):
-    # Data cleaning and transformation would go here
-    return data
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import Column, Integer, String, Float, DateTime
+from sqlalchemy import create_engine
+import uuid
+from datetime import datetime, timezone
+
+Base = declarative_base()
+
+def create_dynamic_table(dataset_id, columns):
+    """动态创建数据表"""
+    # 动态构建列
+    dynamic_columns = {
+        'id': Column(Integer, primary_key=True, autoincrement=True)  # 为每个表添加一个主键
+    }
+
+    # Create the fields from the columns dict
+    for col_name, col_type in columns.items():
+        if col_type == 'str':
+            dynamic_columns[col_name] = Column(String(255))
+        elif col_type == 'int':
+            dynamic_columns[col_name] = Column(Integer)
+        elif col_type == 'float':
+            dynamic_columns[col_name] = Column(Float)
+        elif col_type == 'datetime':
+            dynamic_columns[col_name] = Column(DateTime)
+
+    # Generate the model class dynamically; the table name is dataset_{dataset_id}
+    table_name = f"dataset_{dataset_id}"
+
+    # Add `__tablename__` to the generated class
+    dynamic_columns['__tablename__'] = table_name
+
+    # Create the class dynamically
+    dynamic_class = type(table_name, (Base,), dynamic_columns)
+
+    # Print debug info
+    print("table_name:", table_name)
+    print("dynamic_columns:", dynamic_columns)
+
+    # Create the database engine
+    engine = create_engine('sqlite:///SoilAcidification.db')  # replace this with your own database engine
+    Base.metadata.create_all(engine)  # create all tables
+
+    return dynamic_class
+
+# Check whether the file type is allowed
+def allowed_file(filename):
+    ALLOWED_EXTENSIONS = {'xlsx', 'xls'}
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+# Generate a unique file name
+def generate_unique_filename(filename):
+    # Get the file extension
+    ext = filename.rsplit('.', 1)[1].lower()
+    # Build a unique file name from a UUID and the current timestamp (UTC)
+    unique_filename = f"{uuid.uuid4().hex}_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}.{ext}"
+    return unique_filename
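
One thing to watch in create_dynamic_table: it builds its own engine from the hard-coded 'sqlite:///SoilAcidification.db', which can resolve to a different file than Configs.SQLALCHEMY_DATABASE_URI (the app's database sits one directory above api/app, and a relative SQLite path depends on the working directory). A sketch of a variant that takes the engine as a parameter so callers can pass db.engine from the Flask app; this is a suggestion, not what the commit does:

from sqlalchemy import Column, Integer, String, Float, DateTime
from sqlalchemy.orm import declarative_base

Base = declarative_base()

def create_dynamic_table_on(engine, dataset_id, columns):
    """Variant of create_dynamic_table that uses a caller-supplied engine."""
    type_map = {'str': String(255), 'int': Integer, 'float': Float, 'datetime': DateTime}
    attrs = {
        '__tablename__': f"dataset_{dataset_id}",
        '__table_args__': {'extend_existing': True},  # tolerate re-registering the same table
        'id': Column(Integer, primary_key=True, autoincrement=True),
    }
    for col_name, col_type in columns.items():
        attrs[col_name] = Column(type_map.get(col_type, String(255)))
    dynamic_class = type(attrs['__tablename__'], (Base,), attrs)
    # create only this table, on the engine the caller passed in (e.g. db.engine)
    Base.metadata.create_all(engine, tables=[dynamic_class.__table__])
    return dynamic_class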