
Merge branch 'zsj_devp' of Ding/AcidificationModel into master

Merged code, resolved conflicts
drggboy 3 months ago
parent
commit
88a7e82efb

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+api/__pycache__/

+ 2 - 1
api/.gitignore

@@ -1,3 +1,4 @@
 app/__pycache__
 .idea
-model_optimize/__pycache__
+model_optimize/__pycache__
+__pycache__

+ 2 - 1
api/app/config.py

@@ -14,7 +14,8 @@ class Config:
     CELERY_RESULT_BACKEND = 'redis://localhost:6379/0'
     
     # threshold configuration
-    THRESHOLD = 20
+    DEFAULT_THRESHOLD = 30  # default threshold
+    THRESHOLD = DEFAULT_THRESHOLD  # threshold currently in use
     
     # default model type for auto-training
     DEFAULT_MODEL_TYPE = 'RandomForest'
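For context, a minimal sketch of how this threshold is presumably consumed. `check_and_trigger_training` does exist in routes.py, but its body is outside this diff, so the implementation below is illustrative only:

```python
# Hypothetical sketch (not the author's code): THRESHOLD gating auto-training.
from flask import current_app

def check_and_trigger_training(session, dataset_type, dataset_df):
    # Read the runtime-adjustable threshold (see /update-threshold below).
    threshold = current_app.config['THRESHOLD']
    if len(dataset_df) >= threshold:
        # Enough accumulated rows: queue an async training run.
        from .tasks import train_model_task  # defined in api/app/tasks.py
        train_model_task.delay(
            model_type=current_app.config['DEFAULT_MODEL_TYPE'],
            model_name=f'auto_{dataset_type}',            # illustrative name
            model_description='auto-triggered training',  # illustrative text
            data_type=dataset_type,
        )
```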

+ 1 - 1
api/app/model.py

@@ -94,7 +94,7 @@ def train_and_save_model(session, model_type, model_name, model_description, dat
 
         # After all operations succeed, commit the transaction manually
         session.commit()
-        return model_name, model_id
+        return model_name, model_id, dataset_id
     except Exception as e:
         # If any stage raises an exception, roll back the transaction
         session.rollback()
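Returning `dataset_id` lets callers learn which dataset was actually used when they passed `dataset_id=None`; the updated tasks.py below relies on it to set the dataset's Status. Call sites now unpack three values:

```python
# Call sites must now unpack three values; the returned dataset_id reflects
# the dataset actually used, which matters when the caller passed None.
model_name, model_id, dataset_id = train_and_save_model(
    session, model_type, model_name, model_description, data_type, dataset_id
)
```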

+ 144 - 15
api/app/routes.py

@@ -1,9 +1,5 @@
 import sqlite3
-from flask import current_app, send_file
-from werkzeug.security import generate_password_hash, check_password_hash
-from flask import Blueprint, request, jsonify, current_app as app
-from werkzeug.utils import secure_filename
-from io import BytesIO
+from flask import Blueprint, request, jsonify, current_app
 from .model import predict, train_and_save_model, calculate_model_score
 import pandas as pd
 from . import db  # import the db instance from the app package
@@ -89,6 +85,10 @@ def check_and_trigger_training(session, dataset_type, dataset_df):
 
 @bp.route('/upload-dataset', methods=['POST'])
 def upload_dataset():
+    # create a session
+    Session = sessionmaker(bind=db.engine)
+    session = Session()
+    
     try:
         if 'file' not in request.files:
             return jsonify({'error': 'No file part'}), 400
@@ -102,9 +102,6 @@ def upload_dataset():
         if not dataset_type:
             return jsonify({'error': 'Dataset type is required'}), 400
 
-        # create a sessionmaker instance
-        Session = sessionmaker(bind=db.engine)
-        session = Session()
         new_dataset = Datasets(
             Dataset_name=dataset_name,
             Dataset_description=dataset_description,
@@ -161,8 +158,11 @@ def upload_dataset():
         session.rollback()
         logging.error('Failed to process the dataset upload:', exc_info=True)
         return jsonify({'error': str(e)}), 500
+        
     finally:
-        session.close()
+        # ensure the session is always closed
+        if session:
+            session.close()
 
 
 @bp.route('/train-and-save-model', methods=['POST'])
@@ -740,13 +740,26 @@ def train_model_async():
                 'error': 'Missing required parameters'
             }), 400
             
+        # if dataset_id was provided, verify the dataset exists
+        if dataset_id:
+            Session = sessionmaker(bind=db.engine)
+            session = Session()
+            try:
+                dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
+                if not dataset:
+                    return jsonify({
+                        'error': f'Dataset with ID {dataset_id} not found'
+                    }), 404
+            finally:
+                session.close()
+            
         # launch the async task
         task = train_model_task.delay(
-            model_type,
-            model_name,
-            model_description,
-            data_type,
-            dataset_id
+            model_type=model_type,
+            model_name=model_name,
+            model_description=model_description,
+            data_type=data_type,
+            dataset_id=dataset_id
         )
         
         # return the task ID
@@ -1264,4 +1277,120 @@ def download_template():
     except Exception as e:
         logger.error(f"Failed to generate template: {e}", exc_info=True)
         return jsonify({'error': 'Failed to generate the template file'}), 500
-    
+
+@bp.route('/update-threshold', methods=['POST'])
+def update_threshold():
+    """
+    更新训练阈值的API接口
+    
+    @body_param threshold: 新的阈值值(整数)
+    @return: JSON响应
+    """
+    try:
+        data = request.get_json()
+        new_threshold = data.get('threshold')
+        
+        # validate the new threshold
+        if not isinstance(new_threshold, (int, float)) or new_threshold <= 0:
+            return jsonify({
+                'error': 'Invalid threshold: must be a positive number'
+            }), 400
+            
+        # update the current app's threshold configuration
+        current_app.config['THRESHOLD'] = int(new_threshold)
+        
+        return jsonify({
+            'success': True,
+            'message': f'Threshold updated to {new_threshold}',
+            'new_threshold': new_threshold
+        })
+        
+    except Exception as e:
+        logging.error(f"更新阈值失败: {str(e)}")
+        return jsonify({
+            'error': f'更新阈值失败: {str(e)}'
+        }), 500
+
+
+@bp.route('/get-threshold', methods=['GET'])
+def get_threshold():
+    """
+    获取当前训练阈值的API接口
+    
+    @return: JSON响应
+    """
+    try:
+        current_threshold = current_app.config['THRESHOLD']
+        default_threshold = current_app.config['DEFAULT_THRESHOLD']
+        
+        return jsonify({
+            'current_threshold': current_threshold,
+            'default_threshold': default_threshold
+        })
+        
+    except Exception as e:
+        logging.error(f"获取阈值失败: {str(e)}")
+        return jsonify({
+            'error': f'获取阈值失败: {str(e)}'
+        }), 500
+
+@bp.route('/set-current-dataset/<string:data_type>/<int:dataset_id>', methods=['POST'])
+def set_current_dataset(data_type, dataset_id):
+    """
+    将指定数据集设置为current数据集
+    
+    @param data_type: 数据集类型 ('reduce' 或 'reflux')
+    @param dataset_id: 要设置为current的数据集ID
+    @return: JSON响应
+    """
+    Session = sessionmaker(bind=db.engine)
+    session = Session()
+    
+    try:
+        # verify the dataset exists and its type matches
+        dataset = session.query(Datasets)\
+            .filter_by(Dataset_ID=dataset_id, Dataset_type=data_type)\
+            .first()
+            
+        if not dataset:
+            return jsonify({
+                'error': f'No dataset found with ID {dataset_id} and type {data_type}'
+            }), 404
+            
+        # select the table based on data type
+        if data_type == 'reduce':
+            table = CurrentReduce
+            table_name = 'current_reduce'
+        elif data_type == 'reflux':
+            table = CurrentReflux
+            table_name = 'current_reflux'
+        else:
+            return jsonify({'error': 'Invalid dataset type'}), 400
+            
+        # empty the current table
+        session.query(table).delete()
+        
+        # reset the auto-increment primary key counter
+        session.execute(text(f"DELETE FROM sqlite_sequence WHERE name='{table_name}'"))
+        
+        # copy rows from the specified dataset into the current table
+        dataset_table_name = f"dataset_{dataset_id}"
+        copy_sql = text(f"INSERT INTO {table_name} SELECT * FROM {dataset_table_name}")
+        session.execute(copy_sql)
+        
+        session.commit()
+        
+        return jsonify({
+            'message': f'{data_type} current dataset set to dataset ID: {dataset_id}',
+            'dataset_id': dataset_id,
+            'dataset_name': dataset.Dataset_name,
+            'row_count': dataset.Row_count
+        }), 200
+        
+    except Exception as e:
+        session.rollback()
+        logger.error(f'Failed to set current dataset: {str(e)}')
+        return jsonify({'error': str(e)}), 500
+        
+    finally:
+        session.close()
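To exercise the new endpoints, a hedged sketch using the requests library — assuming the blueprint is mounted at the application root and the server from run.py listens on https://localhost:5000 with a self-signed certificate (hence verify=False):

```python
import requests

BASE = 'https://localhost:5000'  # assumed host/port from run.py

# Raise the auto-training threshold to 50.
r = requests.post(f'{BASE}/update-threshold', json={'threshold': 50}, verify=False)
print(r.json())  # e.g. {'success': True, 'message': 'Threshold updated to 50', ...}

# Read back the current and default thresholds.
print(requests.get(f'{BASE}/get-threshold', verify=False).json())

# Make dataset 29 the current 'reflux' dataset.
print(requests.post(f'{BASE}/set-current-dataset/reflux/29', verify=False).json())
```

One design note: the copy uses INSERT INTO … SELECT *, which assumes the dataset_<id> table and the current table share an identical column order.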

+ 61 - 73
api/app/tasks.py

@@ -2,18 +2,17 @@
 Celery task definitions
 """
 from celery import current_task
-
 from app.celery_app import celery
 import time
-
 from .database_models import Models, Datasets
 from .model import train_and_save_model, calculate_model_score
 import logging
 from . import create_app, db
 import os
+from sqlalchemy.orm import sessionmaker
 
-@celery.task(name='train_model_task')
-def train_model_task(model_type, model_name, model_description, data_type, dataset_id=None):
+@celery.task(name='train_model_task', bind=True)
+def train_model_task(self, model_type, model_name, model_description, data_type, dataset_id=None):
     """
     Asynchronous model-training task
     
@@ -29,81 +28,70 @@ def train_model_task(model_type, model_name, model_description, data_type, datas
     """
     # create the application context
     app = create_app()
-    session = None
     
     try:
         with app.app_context():
-            # make sure the database file exists
-            if not os.path.exists(app.config['DATABASE']):
-                db.create_all()
-            
-            # create a new database engine and session
-            from sqlalchemy import create_engine
-            from sqlalchemy.orm import scoped_session, sessionmaker
-            
-            engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI'])
-            session_factory = sessionmaker(bind=engine)
-            Session = scoped_session(session_factory)
+            Session = sessionmaker(bind=db.engine)
             session = Session()
             
-            # add logging
-            logging.info(f"Starting auto-training for {data_type} dataset with {model_type} model")
-            
-            # call the training function
-            model_name, model_id = train_and_save_model(
-                session, 
-                model_type, 
-                model_name, 
-                model_description, 
-                data_type, 
-                dataset_id
-            )
-
-            # compute the model score
-            if model_id:
-                model_info = session.query(Models).filter(Models.ModelID == model_id).first()
-                if model_info:
-                    score = calculate_model_score(model_info)
-                    # update the model score
-                    model_info.Performance_score = score
-                    session.commit()
+            try:
+                # initialize the dataset variable
+                dataset = None
+                
+                # if dataset_id was given, verify the dataset exists
+                if dataset_id:
+                    dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
+                    if not dataset:
+                        raise ValueError(f"Dataset with ID {dataset_id} not found")
+                        
+                # add logging
+                logging.info(f"Starting auto-training for {data_type} dataset with {model_type} model")
+                
+                # call the training function
+                model_name, model_id, dataset_id = train_and_save_model(
+                    session, 
+                    model_type, 
+                    model_name, 
+                    model_description, 
+                    data_type, 
+                    dataset_id
+                )
 
-            
-            logging.info(f"Auto-training completed successfully. Model ID: {model_id}")
-            
-            # simulate the training process
-            # update status after training succeeds
-            dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
-            if not dataset:
-                raise ValueError("Dataset not found")
+                # compute the model score
+                score = None
+                if model_id:
+                    model_info = session.query(Models).filter(Models.ModelID == model_id).first()
+                    if model_info:
+                        score = calculate_model_score(model_info)
+                        # update the model score
+                        model_info.Performance_score = score
+                        session.commit()
 
-            # if training succeeded
-            dataset.Status = 'training_success'
-            session.commit()
+                logging.info(f"Auto-training completed successfully. Model ID: {model_id}")
+                
+                dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
+                dataset.Status = 'Training_success'
+                session.commit()
 
-            return {
-                'status': 'success',
-                'model_name': model_name,
-                'model_id': model_id,
-                'model_score': score,
-                'message': 'Model trained successfully'
-            }
-            
+                return {
+                    'status': 'success',
+                    'model_name': model_name,
+                    'model_id': model_id,
+                    'model_score': score,
+                    'message': 'Model trained successfully'
+                }
+                
+            finally:
+                session.close()
+                
     except Exception as e:
-        logging.error(f"Failed to train model: {str(e)}")
-        
-        # set the task state to FAILURE
-        current_task.update_state(state='FAILURE', meta={'error': str(e)})
-        
-        # mark the dataset status as training_failed
-        if session:
-            dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
-            if dataset:
-                dataset.Status = 'training_failed'
-                session.commit()
-        
-        raise  # re-raise so Celery records the task failure
-    finally:
-        if session:
-            session.close()
-            Session.remove()  # clean up the scoped session
+        # use Celery's update_state to report the task state
+        self.update_state(
+            state='FAILURE',
+            meta={
+                'exc_type': type(e).__name__,
+                'exc_message': str(e),
+                'error': f'Training failed: {str(e)}'
+            }
+        )
+        raise  # re-raise so Celery can handle the failure properly
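With bind=True, the task reports failure details through self.update_state, and a client holding the task id from /train-model-async can poll for the outcome. A minimal sketch, assuming the Celery app object in app/celery_app.py is the `celery` imported above:

```python
from app.celery_app import celery  # the app's Celery instance

def poll(task_id: str) -> None:
    result = celery.AsyncResult(task_id)
    if result.state == 'SUCCESS':
        info = result.result  # the dict returned by train_model_task
        print('model', info['model_id'], 'score', info['model_score'])
    elif result.state == 'FAILURE':
        # Once the re-raise propagates, Celery stores the exception itself;
        # the custom meta set via update_state is overwritten at that point.
        print('failed:', result.result)
    else:
        print('state:', result.state)
```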

+ 7 - 2
api/model_optimize/data_increase.py

@@ -2,8 +2,7 @@
 import os
 import pandas as pd
 import numpy as np
-from PIL import Image
-from model_saver import save_model
+
 
 # machine learning model imports
 from sklearn.ensemble import RandomForestRegressor
@@ -30,6 +29,7 @@ from pathlib import Path
 
 # dataset configurations
 DATASET_CONFIGS = {
+    # soil acid reflux data: 64 samples, 9 features (including delta_ph), target 105_day_ph
     'soil_acid_9features': {
         'file_path': 'model_optimize/data/data_filt.xlsx',
         'x_columns': range(1, 10),  # 9 features (including delta_ph)
@@ -47,6 +47,7 @@ DATASET_CONFIGS = {
         ],
         'target_name': 'target_ph'
     },
+    # soil acid reflux data: 64 samples, 8 features, target delta_ph
     'soil_acid_8features': {
         'file_path': 'model_optimize/data/data_filt - 副本.xlsx',
         'x_columns': range(1, 9),  # 8 features
@@ -63,6 +64,7 @@ DATASET_CONFIGS = {
         ],
         'target_name': 'target_ph'
     },
+    # soil acid reflux data: 64 samples, 8 features, target delta_ph
     'soil_acid_8features_original': {
         'file_path': 'model_optimize/data/data_filt.xlsx',
         'x_columns': range(1, 9),  # 8 features
@@ -79,6 +81,7 @@ DATASET_CONFIGS = {
         ],
         'target_name': 'target_ph'
     },
+    # soil acid reflux data: 60 samples (outliers removed), 6 features, target delta_ph
     'soil_acid_6features': {
         'file_path': 'model_optimize/data/data_reflux2.xlsx',
         'x_columns': range(0, 6),  # 6 features
@@ -93,6 +96,7 @@ DATASET_CONFIGS = {
         ],
         'target_name': 'delta_ph'
     },
+    # precision acid-reduction data: 54 samples, 5 features, target is 1/b
     'acidity_reduce': {
         'file_path': 'model_optimize/data/Acidity_reduce.xlsx',
         'x_columns': range(1, 6),  # 5 features
@@ -106,6 +110,7 @@ DATASET_CONFIGS = {
         ],
         'target_name': 'target'
     },
+    # precision acid-reduction data (updated data): 54 samples, 5 features, target is 1/b
     'acidity_reduce_new': {
         'file_path': 'model_optimize/data/Acidity_reduce_new.xlsx',
         'x_columns': range(1, 6),  # 5 features
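A hypothetical loader for these configs (not part of the file), assuming the Excel sheets are read with pandas and that x_columns indexes feature columns by position:

```python
import pandas as pd

def load_dataset(name: str):
    """Hypothetical helper: split a configured sheet into features/target."""
    cfg = DATASET_CONFIGS[name]
    df = pd.read_excel(cfg['file_path'])
    X = df.iloc[:, list(cfg['x_columns'])]  # positional feature columns
    # target_name looks like a column name; hedge in case it is a label only
    y = df[cfg['target_name']] if cfg['target_name'] in df.columns else None
    return X, y

X, y = load_dataset('acidity_reduce_new')
```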

+ 3 - 1
api/run.py

@@ -4,6 +4,8 @@ from app import create_app
 import os
 # create the Flask app
 app = create_app()
+
+# use HTTPS
 context = ('ssl/cert.crt', 'ssl/cert.key')
 @app.before_request
 def force_https():
@@ -16,4 +18,4 @@ def force_https():
 # start the server
 if __name__ == '__main__':
     app.run(host="0.0.0.0", port=5000, debug=True, ssl_context=context)
-
+    # app.run(debug=True)
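The body of force_https is outside this hunk; a typical implementation of such a hook, offered here only as a hypothetical sketch, would be:

```python
from flask import redirect, request

@app.before_request
def force_https():
    # Hypothetical body (elided from the diff): redirect plain HTTP to HTTPS.
    if not request.is_secure:
        return redirect(request.url.replace('http://', 'https://', 1), code=301)
```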

BIN
api/uploads/datasets/dataset_29.xlsx


BIN
api/uploads/datasets/dataset_30.xlsx


BIN
api/uploads/datasets/dataset_31.xlsx


BIN
api/uploads/datasets/dataset_32.xlsx


BIN
api/uploads/datasets/dataset_34.xlsx


BIN
api/uploads/datasets/dataset_35.xlsx


BIN
api/uploads/datasets/dataset_36.xlsx


BIN
api/uploads/datasets/dataset_37.xlsx


BIN
api/uploads/datasets/dataset_38.xlsx


BIN
api/uploads/datasets/dataset_39.xlsx


BIN
api/uploads/datasets/dataset_40.xlsx


BIN
api/uploads/datasets/dataset_41.xlsx


BIN
api/uploads/datasets/dataset_42.xlsx