tasks.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. """
  2. Celery任务定义文件
  3. """
  4. from celery import current_task
  5. from app.celery_app import celery
  6. import time
  7. from .database_models import Models, Datasets
  8. from .model import train_and_save_model, calculate_model_score
  9. import logging
  10. from . import create_app, db
  11. import os
  12. @celery.task(name='train_model_task')
  13. def train_model_task(model_type, model_name, model_description, data_type, dataset_id=None):
  14. """
  15. 异步训练模型任务
  16. Args:
  17. model_type: 模型类型
  18. model_name: 模型名称
  19. model_description: 模型描述
  20. data_type: 数据类型
  21. dataset_id: 数据集ID
  22. Returns:
  23. dict: 训练结果
  24. """
  25. # 创建应用上下文
  26. app = create_app()
  27. session = None
  28. try:
  29. with app.app_context():
  30. # 确保数据库文件存在
  31. if not os.path.exists(app.config['DATABASE']):
  32. db.create_all()
  33. # 创建新的数据库引擎和会话
  34. from sqlalchemy import create_engine
  35. from sqlalchemy.orm import scoped_session, sessionmaker
  36. engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI'])
  37. session_factory = sessionmaker(bind=engine)
  38. Session = scoped_session(session_factory)
  39. session = Session()
  40. # 添加日志记录
  41. logging.info(f"Starting auto-training for {data_type} dataset with {model_type} model")
  42. # 调用训练函数
  43. model_name, model_id = train_and_save_model(
  44. session,
  45. model_type,
  46. model_name,
  47. model_description,
  48. data_type,
  49. dataset_id
  50. )
  51. # 计算模型评分
  52. if model_id:
  53. model_info = session.query(Models).filter(Models.ModelID == model_id).first()
  54. if model_info:
  55. score = calculate_model_score(model_info)
  56. # 更新模型评分
  57. model_info.Performance_score = score
  58. session.commit()
  59. logging.info(f"Auto-training completed successfully. Model ID: {model_id}")
  60. # 模拟训练过程
  61. # 训练成功后更新状态
  62. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  63. if not dataset:
  64. raise ValueError("Dataset not found")
  65. # 如果训练成功
  66. dataset.Status = 'training_success'
  67. session.commit()
  68. return {
  69. 'status': 'success',
  70. 'model_name': model_name,
  71. 'model_id': model_id,
  72. 'model_score': score,
  73. 'message': 'Model trained successfully'
  74. }
  75. except Exception as e:
  76. logging.error(f"Failed to train model: {str(e)}")
  77. # 更新任务状态为 FAILURE
  78. current_task.update_state(state='FAILURE', meta={'error': str(e)})
  79. # 更新数据集状态为训练失败
  80. if session:
  81. dataset = session.query(Datasets).filter_by(Dataset_ID=dataset_id).first()
  82. if dataset:
  83. dataset.Status = 'training_failed'
  84. session.commit()
  85. raise # 重新抛出异常以确保 Celery 记录任务失败
  86. finally:
  87. if session:
  88. session.close()
  89. Session.remove() # 清理scoped session