123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- """
- Celery任务定义文件
- """
- from celery import current_task
- from app.celery_app import celery
- import time
- from .database_models import Models, Datasets
- from .model import train_and_save_model, calculate_model_score
- import logging
- from . import create_app, db
- import os
def _set_dataset_status(session, dataset_id, status):
    """Set ``Datasets.Status`` for ``dataset_id`` to ``status`` and commit.

    Args:
        session: SQLAlchemy session to query and commit with.
        dataset_id: Value matched against ``Datasets.Dataset_ID``.
        status: New status string to store.

    Returns:
        bool: True if a matching dataset row was found and updated, False if
        no row matched (nothing is written in that case).
    """
    dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
    if dataset is None:
        return False
    dataset.Status = status
    session.commit()
    return True


@celery.task(name='train_model_task')
def train_model_task(model_type, model_name, model_description, data_type, dataset_id=None):
    """Train a model asynchronously, score it, and record the outcome.

    Runs inside a fresh Flask app context with a dedicated SQLAlchemy engine
    and session (independent of ``db.session``). Both are released when the
    task finishes, whether it succeeds or fails.

    Args:
        model_type: Model type, passed through to ``train_and_save_model``.
        model_name: Requested model name. It is replaced by the name returned
            from ``train_and_save_model``.
        model_description: Model description.
        data_type: Data type, passed through to ``train_and_save_model``.
        dataset_id: ID of the dataset being trained on. When given, the
            dataset's ``Status`` is set to ``'training_success'`` or
            ``'training_failed'``. When ``None``, no dataset status is touched.

    Returns:
        dict: Keys ``status``, ``model_name``, ``model_id``, ``model_score``
        and ``message``. ``model_score`` is None when no model row was found
        to score.

    Raises:
        ValueError: If ``dataset_id`` is given but no matching dataset exists.
        Exception: Any error raised during setup, training or scoring. It is
            logged, recorded on the task and dataset, then re-raised so Celery
            marks the task as failed.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.orm import scoped_session, sessionmaker

    logger = logging.getLogger(__name__)
    app = create_app()
    engine = None
    Session = None
    session = None

    try:
        with app.app_context():
            # Create the tables on first run, when the database file is missing.
            if not os.path.exists(app.config['DATABASE']):
                db.create_all()

            # Use a dedicated engine/session for this task; both are released
            # in the finally clause below.
            engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI'])
            Session = scoped_session(sessionmaker(bind=engine))
            session = Session()

            logger.info(
                "Starting auto-training for %s dataset with %s model",
                data_type, model_type,
            )

            model_name, model_id = train_and_save_model(
                session,
                model_type,
                model_name,
                model_description,
                data_type,
                dataset_id,
            )

            # Bind up front so the result is well-defined even when no model
            # row is scored (previously this raised UnboundLocalError).
            score = None
            if model_id:
                model_info = session.query(Models).filter(Models.ModelID == model_id).first()
                if model_info:
                    score = calculate_model_score(model_info)
                    model_info.Performance_score = score
                    session.commit()

            logger.info("Auto-training completed successfully. Model ID: %s", model_id)

            # Only track dataset status when a dataset was actually specified.
            # Otherwise the default dataset_id=None would always fail here,
            # even after training had succeeded.
            if dataset_id is not None and not _set_dataset_status(
                session, dataset_id, 'training_success'
            ):
                raise ValueError("Dataset not found")

            return {
                'status': 'success',
                'model_name': model_name,
                'model_id': model_id,
                'model_score': score,
                'message': 'Model trained successfully'
            }

    except Exception as e:
        logger.exception("Failed to train model: %s", e)

        current_task.update_state(state='FAILURE', meta={'error': str(e)})

        if session is not None and dataset_id is not None:
            try:
                # A DB error leaves the session needing a rollback. Rolling
                # back also discards partial writes from the failed attempt,
                # so they are not committed together with the failure status.
                session.rollback()
                _set_dataset_status(session, dataset_id, 'training_failed')
            except Exception:
                # Never let status bookkeeping mask the original error.
                logger.exception(
                    "Failed to mark dataset %s as training_failed", dataset_id
                )

        raise  # Re-raise so Celery records the task as failed.
    finally:
        if Session is not None:
            # remove() closes the session and clears the scoped registry.
            Session.remove()
        if engine is not None:
            # Release the per-task connection pool; otherwise every task
            # run leaks one in a long-lived worker.
            engine.dispose()
|