12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- """
- Celery任务定义文件
- """
- from celery import current_task
- from app.celery_app import celery
- import time
- from .database_models import Models, Datasets
- from .model import train_and_save_model, calculate_model_score
- import logging
- from . import create_app, db
- import os
- from sqlalchemy.orm import sessionmaker
@celery.task(name='train_model_task', bind=True)
def train_model_task(self, model_type, model_name, model_description, data_type, dataset_id=None):
    """Train a model asynchronously and record its score and dataset status.

    The task builds an application via ``create_app()``, opens a dedicated
    SQLAlchemy session bound to ``db.engine``, delegates the training to
    ``train_and_save_model``, stores the value returned by
    ``calculate_model_score`` on the model row, and finally marks the
    dataset row as ``'Training_success'``.

    Args:
        model_type: Model type, forwarded to ``train_and_save_model``.
        model_name: Model name, forwarded to ``train_and_save_model``.
        model_description: Model description, forwarded to
            ``train_and_save_model``.
        data_type: Data type, forwarded to ``train_and_save_model``.
        dataset_id: Optional dataset ID. When truthy, the dataset row must
            exist, otherwise the task fails before training starts.

    Returns:
        dict: ``status``, ``model_name``, ``model_id``, ``model_score`` and
        ``message`` describing the successful run.

    Raises:
        ValueError: If ``dataset_id`` is given but no matching dataset exists.
        Exception: Any other error is re-raised after the task state has
            been updated to ``FAILURE``.
    """
    # The worker runs outside of a request, so an application context has to
    # be created explicitly before touching ``db``.
    app = create_app()

    try:
        with app.app_context():
            Session = sessionmaker(bind=db.engine)
            session = Session()

            try:
                # Validate the requested dataset up front so that training is
                # never started against a dataset that does not exist.
                if dataset_id:
                    requested = (
                        session.query(Datasets)
                        .filter_by(Dataset_ID=dataset_id)
                        .first()
                    )
                    if not requested:
                        raise ValueError(f"Dataset with ID {dataset_id} not found")

                logging.info(
                    "Starting auto-training for %s dataset with %s model",
                    data_type,
                    model_type,
                )

                # train_and_save_model returns the effective name and IDs,
                # which replace the values passed in by the caller.
                model_name, model_id, dataset_id = train_and_save_model(
                    session,
                    model_type,
                    model_name,
                    model_description,
                    data_type,
                    dataset_id
                )

                # Compute and persist the score of the freshly trained model.
                score = None
                if model_id:
                    model_info = (
                        session.query(Models)
                        .filter(Models.ModelID == model_id)
                        .first()
                    )
                    if model_info:
                        score = calculate_model_score(model_info)
                        model_info.Performance_score = score
                        session.commit()
                        logging.info(
                            "Auto-training completed successfully. Model ID: %s",
                            model_id,
                        )

                # ``first()`` returns None when no row matches (e.g. the
                # returned dataset_id is None); guard against that instead of
                # failing an otherwise successful training run.
                trained_dataset = (
                    session.query(Datasets)
                    .filter_by(Dataset_ID=dataset_id)
                    .first()
                )
                if trained_dataset is not None:
                    trained_dataset.Status = 'Training_success'
                    session.commit()
                else:
                    logging.warning(
                        "Dataset %s not found after training; status not updated",
                        dataset_id,
                    )

                return {
                    'status': 'success',
                    'model_name': model_name,
                    'model_id': model_id,
                    'model_score': score,
                    'message': 'Model trained successfully'
                }

            finally:
                # Always release the session, on success and on failure.
                session.close()

    except Exception as e:
        logging.exception("Model training task failed")
        # Publish the failure details through Celery's update_state.
        self.update_state(
            state='FAILURE',
            meta={
                'exc_type': type(e).__name__,
                'exc_message': str(e),
                'error': f'Training failed: {str(e)}'
            }
        )
        raise  # Re-raise so Celery can handle the failure properly.
|