Ver Fonte

实现数据集上传接口

drggboy há 5 meses atrás
pai
commit
3e681dc1a7
5 ficheiros alterados com 184 adições e 12 exclusões
  1. BIN
      api/SoilAcidification.db
  2. 1 1
      api/app/config.py
  3. 27 3
      api/app/database_models.py
  4. 99 4
      api/app/routes.py
  5. 57 4
      api/app/utils.py

BIN
api/SoilAcidification.db


+ 1 - 1
api/app/config.py

@@ -6,4 +6,4 @@ class Configs:
     MODEL_PATH = 'model_optimize/pkl/RF_filt.pkl'
     DATABASE = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'SoilAcidification.db')
     SQLALCHEMY_DATABASE_URI = f'sqlite:///{DATABASE}'
-    UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), 'uploads')
+    UPLOAD_FOLDER = 'uploads/datasets'

+ 27 - 3
api/app/database_models.py

@@ -1,4 +1,5 @@
 from . import db
+from datetime import datetime
 
 class Model(db.Model):
     __tablename__ = 'Models'
@@ -9,8 +10,6 @@ class Model(db.Model):
     CreatedAt = db.Column(db.TIMESTAMP, default=db.func.current_timestamp())
     Description = db.Column(db.Text)
     
-    # # 添加与 ModelParameters 表的关系
-    # parameters = db.relationship('ModelParameters', backref='model', lazy=True)
 
 class ModelParameters(db.Model):
     __tablename__ = 'ModelParameters'  # 指定表名
@@ -21,4 +20,29 @@ class ModelParameters(db.Model):
     ParamValue = db.Column(db.Text, nullable=False)  # 参数值
 
     # 定义反向关系
-    model = db.relationship('Model', backref=db.backref('parameters', lazy=True))
+    model = db.relationship('Model', backref=db.backref('parameters', lazy=True))
+
+class Dataset(db.Model):
+    """Metadata record describing one uploaded dataset (table ``Datasets``)."""
+
+    __tablename__ = 'Datasets'
+
+    DatasetID = db.Column(db.Integer, primary_key=True)  # Dataset ID
+    DatasetName = db.Column(db.String(255), nullable=False)  # Dataset name
+    DatasetDescription = db.Column(db.Text, nullable=True)  # Dataset description (optional)
+    UploadedAt = db.Column(db.TIMESTAMP, default=datetime.utcnow)  # Upload time
+    RowCount = db.Column(db.Integer, nullable=False)  # Number of rows in the dataset
+    Status = db.Column(db.String(50), default='pending')  # Dataset status (pending, processed)
+    DatasetType = db.Column(db.String(50), nullable=False)  # Dataset type (re-acidification model training, acid-reduction model training, etc.)
+
+    def __repr__(self):
+        return f'<Dataset {self.DatasetName}>'
+
+    def to_dict(self):
+        """Return the record as a plain dict suitable for JSON serialisation.
+
+        ``UploadedAt`` is formatted as ``YYYY-MM-DD HH:MM:SS``. It is ``None``
+        when the timestamp has not been populated yet (the column default is
+        only applied when the row is inserted), so an unsaved instance can
+        still be serialised instead of raising ``AttributeError``.
+        """
+        uploaded_at = self.UploadedAt
+        return {
+            'DatasetID': self.DatasetID,
+            'DatasetName': self.DatasetName,
+            'DatasetDescription': self.DatasetDescription,
+            'UploadedAt': (
+                uploaded_at.strftime('%Y-%m-%d %H:%M:%S')
+                if uploaded_at is not None else None
+            ),
+            'RowCount': self.RowCount,
+            'Status': self.Status,
+            'DatasetType': self.DatasetType
+        }

+ 99 - 4
api/app/routes.py

@@ -1,16 +1,111 @@
 import sqlite3
 
-from flask import Blueprint, request, jsonify, g, current_app
-from .model import predict
+from flask import Blueprint, request, jsonify, current_app
+from .model import predict, train_and_save_model
 import pandas as pd
-from flask_sqlalchemy import SQLAlchemy
 from . import db  # 从 app 包导入 db 实例
 from sqlalchemy.engine.reflection import Inspector
-from .database_models import Model, ModelParameters
+from .database_models import Model, ModelParameters, Dataset
+import os
+from .utils import create_dynamic_table, allowed_file
+from sqlalchemy.orm import sessionmaker
 
 # 创建蓝图 (Blueprint),用于分离路由
 bp = Blueprint('routes', __name__)
 
+@bp.route('/upload-dataset', methods=['POST'])
+def upload_dataset():
+    """Receive an Excel dataset, register it and load its rows into a new table.
+
+    Expects a multipart form with:
+      * ``file``                - the Excel workbook (extension checked by ``allowed_file``)
+      * ``dataset_name``        - required
+      * ``dataset_type``        - required
+      * ``dataset_description`` - optional
+
+    Returns a JSON body with ``message``, ``dataset_id`` and ``filename`` and
+    status 201 on success, status 400 for an invalid request, and status 500
+    with the exception text for any unexpected failure.
+    """
+    try:
+        # The request must contain a file part.
+        if 'file' not in request.files:
+            return jsonify({'error': 'No file part'}), 400
+        file = request.files['file']
+
+        # An empty filename means the user did not select a file.
+        if file.filename == '':
+            return jsonify({'error': 'No selected file'}), 400
+
+        # Reject file types that are not allowed.
+        if not file or not allowed_file(file.filename):
+            return jsonify({'error': 'Invalid file type'}), 400
+
+        # Dataset metadata supplied alongside the file.
+        dataset_name = request.form.get('dataset_name')
+        dataset_description = request.form.get('dataset_description', 'No description provided')
+        dataset_type = request.form.get('dataset_type')
+
+        # DatasetName is NOT NULL in the model; answer 400 instead of letting
+        # the database constraint surface as a 500.
+        if not dataset_name:
+            return jsonify({'error': 'Dataset name is required'}), 400
+        if not dataset_type:
+            return jsonify({'error': 'Dataset type is required'}), 400
+
+        # Register the dataset first so its ID can name the file and the table.
+        new_dataset = Dataset(
+            DatasetName=dataset_name,
+            DatasetDescription=dataset_description,
+            RowCount=0,  # real row count is filled in once the file is loaded
+            Status='pending',
+            DatasetType=dataset_type
+        )
+        db.session.add(new_dataset)
+        # Flush so the generated primary key can be read before committing.
+        db.session.flush()
+        dataset_id = new_dataset.DatasetID
+        db.session.commit()
+
+        # Store the upload under a name derived from the DatasetID.
+        unique_filename = f"dataset_{dataset_id}.xlsx"
+        upload_folder = current_app.config['UPLOAD_FOLDER']
+        # The folder is not guaranteed to exist; create it on demand.
+        os.makedirs(upload_folder, exist_ok=True)
+        file_path = os.path.join(upload_folder, unique_filename)
+        file.save(file_path)
+
+        # Read the workbook contents.
+        dataset_df = pd.read_excel(file_path)
+
+        # Map each DataFrame column to the type names create_dynamic_table
+        # understands ('int', 'float', 'datetime', 'str').
+        columns = {}
+        for col in dataset_df.columns:
+            col_dtype = dataset_df[col].dtype
+            if col_dtype == 'int64':
+                columns[col] = 'int'
+            elif col_dtype == 'float64':
+                columns[col] = 'float'
+            elif pd.api.types.is_datetime64_any_dtype(col_dtype):
+                columns[col] = 'datetime'
+            else:
+                columns[col] = 'str'
+
+        # Create the per-dataset table.
+        dynamic_table_class = create_dynamic_table(dataset_id, columns)
+
+        # Build all row payloads up front; missing cells (NaN/NaT) become NULL.
+        records = []
+        for _, row in dataset_df.iterrows():
+            records.append({
+                key: (None if pd.isna(value) else value)
+                for key, value in row.to_dict().items()
+            })
+
+        # Insert the rows through a dedicated session and always release it.
+        Session = sessionmaker(bind=db.engine)
+        session = Session()
+        try:
+            if records:
+                # One batched insert instead of one statement per row.
+                session.execute(dynamic_table_class.__table__.insert(), records)
+            session.commit()
+        except Exception:
+            session.rollback()
+            raise
+        finally:
+            session.close()
+
+        # Only mark the dataset as processed once its rows are stored.
+        new_dataset.RowCount = len(dataset_df)
+        new_dataset.Status = 'processed'
+        db.session.commit()
+
+        return jsonify({
+            'message': f'Dataset {dataset_name} uploaded successfully!',
+            'dataset_id': dataset_id,
+            'filename': unique_filename
+        }), 201
+
+    except Exception as e:
+        # Leave the shared session usable for subsequent requests.
+        db.session.rollback()
+        return jsonify({'error': str(e)}), 500
+
+
 @bp.route('/tables', methods=['GET'])
 def list_tables():
     engine = db.engine  # 使用 db 实例的 engine

+ 57 - 4
api/app/utils.py

@@ -1,4 +1,57 @@
-# 工具模块,用于存放一些工具函数:数据预处理、模型评估等
-def preprocess_data(data):
-    # 在此进行数据清理和转换
-    return data
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import Column, Integer, String, Float, DateTime
+from sqlalchemy import create_engine
+import uuid
+from datetime import datetime, timezone
+
+# Declarative base that the dynamically generated dataset table classes inherit from.
+Base = declarative_base()
+
+def create_dynamic_table(dataset_id, columns, engine=None):
+    """Define and create the table ``dataset_<dataset_id>``; return its mapped class.
+
+    Args:
+        dataset_id: Identifier used to build the table (and class) name.
+        columns: Mapping of column name to one of the type names ``'str'``,
+            ``'int'``, ``'float'`` or ``'datetime'``. Entries with any other
+            type name are skipped, as before.
+        engine: Optional SQLAlchemy engine on which to create the table. When
+            omitted, the application's own engine (``db.engine``) is used so
+            the table lands in the same database the app reads and writes.
+            If that engine is unavailable (for example outside an application
+            context), the previous hard-coded SQLite URL is used.
+
+    Returns:
+        The dynamically created declarative class mapped to the new table.
+    """
+    import logging
+
+    logger = logging.getLogger(__name__)
+
+    # Column factories keyed by type name; a fresh Column is built per use.
+    column_factories = {
+        'str': lambda: Column(String(255)),
+        'int': lambda: Column(Integer),
+        'float': lambda: Column(Float),
+        'datetime': lambda: Column(DateTime),
+    }
+
+    # Every table gets a surrogate auto-increment primary key.
+    # NOTE(review): a dataset column that is itself named 'id' replaces this
+    # key below - confirm how such uploads should be handled.
+    dynamic_columns = {
+        'id': Column(Integer, primary_key=True, autoincrement=True)
+    }
+
+    # Add one column per requested field.
+    for col_name, col_type in columns.items():
+        factory = column_factories.get(col_type)
+        if factory is not None:
+            dynamic_columns[col_name] = factory()
+
+    # The table name is dataset_{dataset_id}; set last so it always wins.
+    table_name = f"dataset_{dataset_id}"
+    dynamic_columns['__tablename__'] = table_name
+
+    # Build the mapped class dynamically.
+    dynamic_class = type(table_name, (Base,), dynamic_columns)
+
+    logger.debug("table_name: %s", table_name)
+    logger.debug("dynamic_columns: %s", dynamic_columns)
+
+    if engine is None:
+        try:
+            # Imported here to avoid a circular import at module load time.
+            from . import db
+            engine = db.engine
+        except (ImportError, RuntimeError):
+            # Previous behaviour: a SQLite file resolved against the current
+            # working directory.
+            logger.warning(
+                "Application engine unavailable; falling back to "
+                "sqlite:///SoilAcidification.db"
+            )
+            engine = create_engine('sqlite:///SoilAcidification.db')
+
+    # Create only the table defined above (skipped if it already exists).
+    Base.metadata.create_all(engine, tables=[dynamic_class.__table__])
+
+    return dynamic_class
+
+# Check whether the file type is allowed.
+def allowed_file(filename):
+    """Return True when *filename* has an ``xlsx`` or ``xls`` extension (case-insensitive)."""
+    _, dot, suffix = filename.rpartition('.')
+    if not dot:
+        # No dot at all means there is no extension to check.
+        return False
+    return suffix.lower() in ('xlsx', 'xls')
+
+# Generate a unique file name.
+def generate_unique_filename(filename):
+    """Return a unique name that keeps the (lower-cased) extension of *filename*.
+
+    The result is ``<uuid4 hex>_<UTC timestamp YYYYmmddHHMMSS>.<ext>``. When
+    *filename* has no extension, the suffix is omitted instead of raising
+    ``IndexError``.
+    """
+    # Take the text after the last dot as the extension.
+    _, dot, ext = filename.rpartition('.')
+    # Combine a UUID with the current UTC timestamp.
+    timestamp = datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')
+    unique_stem = f"{uuid.uuid4().hex}_{timestamp}"
+    if not dot or not ext:
+        return unique_stem
+    return f"{unique_stem}.{ext.lower()}"