tasks.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """
  2. Celery任务定义文件
  3. """
  4. from celery import current_task
  5. from app.celery_app import celery
  6. import time
  7. from .database_models import Models, Datasets
  8. from .model import train_and_save_model, calculate_model_score
  9. import logging
  10. from . import create_app, db
  11. import os
  12. from sqlalchemy.orm import sessionmaker
  13. @celery.task(name='train_model_task', bind=True)
  14. def train_model_task(self, model_type, model_name, model_description, data_type, dataset_id=None):
  15. """
  16. 异步训练模型任务
  17. Args:
  18. model_type: 模型类型
  19. model_name: 模型名称
  20. model_description: 模型描述
  21. data_type: 数据类型
  22. dataset_id: 数据集ID
  23. Returns:
  24. dict: 训练结果
  25. """
  26. # 创建应用上下文
  27. app = create_app()
  28. try:
  29. with app.app_context():
  30. Session = sessionmaker(bind=db.engine)
  31. session = Session()
  32. try:
  33. # 初始化 dataset 变量
  34. dataset = None
  35. # 如果指定了dataset_id,验证数据集存在
  36. if dataset_id:
  37. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  38. if not dataset:
  39. raise ValueError(f"Dataset with ID {dataset_id} not found")
  40. # 添加日志记录
  41. logging.info(f"Starting auto-training for {data_type} dataset with {model_type} model")
  42. # 调用训练函数
  43. model_name, model_id, dataset_id = train_and_save_model(
  44. session,
  45. model_type,
  46. model_name,
  47. model_description,
  48. data_type,
  49. dataset_id
  50. )
  51. # 计算模型评分
  52. score = None
  53. if model_id:
  54. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  55. if model_info:
  56. score = calculate_model_score(model_info)
  57. # 更新模型评分
  58. model_info.Performance_score = score
  59. session.commit()
  60. logging.info(f"Auto-training completed successfully. Model ID: {model_id}")
  61. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  62. dataset.Status = 'Training_success'
  63. session.commit()
  64. return {
  65. 'status': 'success',
  66. 'model_name': model_name,
  67. 'model_id': model_id,
  68. 'model_score': score,
  69. 'message': 'Model trained successfully'
  70. }
  71. finally:
  72. session.close()
  73. except Exception as e:
  74. # 使用 Celery 的 update_state 来更新任务状态
  75. self.update_state(
  76. state='FAILURE',
  77. meta={
  78. 'exc_type': type(e).__name__,
  79. 'exc_message': str(e),
  80. 'error': f'Training failed: {str(e)}'
  81. }
  82. )
  83. raise # 重新抛出异常以便 Celery 可以正确处理