"""Celery task definitions."""
import logging
import os
import time

from celery import current_task

from app.celery_app import celery
from .database_models import Models, Datasets
from .model import train_and_save_model, calculate_model_score
from . import create_app, db

logger = logging.getLogger(__name__)


def _mark_dataset_failed(session, dataset_id):
    """Best-effort: set the dataset's ``Status`` to ``'training_failed'``.

    Any error raised here is logged and swallowed so that it never masks the
    exception that caused the training failure.

    Args:
        session: SQLAlchemy session used by the failed task.
        dataset_id: Primary key value matched against ``Datasets.Dataset_ID``.
    """
    try:
        # The original error may have left the session in a failed transaction;
        # roll back before issuing new statements.
        session.rollback()
        dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
        if dataset:
            dataset.Status = 'training_failed'
            session.commit()
    except Exception:
        logger.exception("Could not mark dataset %s as training_failed", dataset_id)


@celery.task(name='train_model_task')
def train_model_task(model_type, model_name, model_description, data_type, dataset_id=None):
    """Train a model asynchronously and record the outcome in the database.

    Trains and persists a model via ``train_and_save_model``, stores the score
    from ``calculate_model_score`` on the ``Models`` row, and sets the matching
    ``Datasets`` row's ``Status`` to ``'training_success'``. On any error the
    dataset (if one was given) is marked ``'training_failed'`` and the error is
    re-raised.

    Args:
        model_type: Model type, forwarded to ``train_and_save_model``.
        model_name: Model name, forwarded to ``train_and_save_model``.
        model_description: Model description, forwarded to ``train_and_save_model``.
        data_type: Data type, forwarded to ``train_and_save_model``.
        dataset_id: ID of the dataset whose status is updated. When ``None``,
            no dataset row is looked up or updated.

    Returns:
        dict: Keys ``status``, ``model_name``, ``model_id``, ``model_score``
        (``None`` when no model row was found to score) and ``message``.

    Raises:
        ValueError: If ``dataset_id`` is given but no matching dataset exists.
        Exception: Any error during training is re-raised so Celery records
            the task as failed.
    """
    app = create_app()
    engine = None
    session = None
    try:
        with app.app_context():
            # Make sure the database exists before opening a connection.
            if not os.path.exists(app.config['DATABASE']):
                db.create_all()

            from sqlalchemy import create_engine
            from sqlalchemy.orm import sessionmaker

            # A private engine/session for this task run; both are released in
            # the finally block so repeated task runs do not leak connection pools.
            engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI'])
            session = sessionmaker(bind=engine)()

            logger.info(
                "Starting auto-training for %s dataset with %s model", data_type, model_type
            )

            model_name, model_id = train_and_save_model(
                session, model_type, model_name, model_description, data_type, dataset_id
            )

            # Bound up front so the result is well-defined even when no model
            # row is found to score.
            score = None
            if model_id:
                model_info = session.query(Models).filter(Models.ModelID == model_id).first()
                if model_info:
                    score = calculate_model_score(model_info)
                    model_info.Performance_score = score
                    session.commit()
                    logger.info("Auto-training completed successfully. Model ID: %s", model_id)

            # Only a concrete dataset_id identifies a row to update; None is the
            # declared default and has no dataset row to mark.
            if dataset_id is not None:
                dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
                if not dataset:
                    raise ValueError("Dataset not found")
                dataset.Status = 'training_success'
                session.commit()

            return {
                'status': 'success',
                'model_name': model_name,
                'model_id': model_id,
                'model_score': score,
                'message': 'Model trained successfully',
            }
    except Exception as e:
        logger.exception("Failed to train model: %s", e)
        # current_task is falsy outside a running task (e.g. a direct call);
        # guard so the original error is not replaced by an AttributeError.
        if current_task:
            current_task.update_state(state='FAILURE', meta={'error': str(e)})
        if session is not None and dataset_id is not None:
            _mark_dataset_failed(session, dataset_id)
        raise  # re-raise so Celery records the task failure
    finally:
        if session is not None:
            session.close()
        if engine is not None:
            engine.dispose()