
Implement the pre-training data processing module (not yet enabled);
implement data deduplication when uploading datasets

drggboy 3 months ago
parent
commit
983f33f897
3 changed files with 173 additions and 4 deletions
  1. api/app/data_cleaner.py  +116 -0
  2. api/app/model.py  +12 -1
  3. api/app/routes.py  +45 -3

+ 116 - 0
api/app/data_cleaner.py

@@ -0,0 +1,116 @@
+"""
+Data cleaning module providing various data cleaning and preprocessing utilities.
+"""
+import pandas as pd
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+import logging
+
+logger = logging.getLogger(__name__)
+
+def remove_duplicates(df):
+    """
+    Remove duplicate rows from a DataFrame.
+    
+    Args:
+        df: input DataFrame
+        
+    Returns:
+        tuple: (cleaned DataFrame, number of duplicate rows removed)
+    """
+    original_count = len(df)
+    df_clean = df.drop_duplicates()
+    duplicates_removed = original_count - len(df_clean)
+    logger.info(f"移除了 {duplicates_removed} 个重复样本")
+    return df_clean, duplicates_removed
+
+def remove_outliers(df, method='iqr', threshold=1.5):
+    """
+    Detect and remove outliers using the specified method.
+    
+    Args:
+        df: input DataFrame
+        method: outlier detection method ('iqr', 'zscore')
+        threshold: outlier decision threshold
+        
+    Returns:
+        tuple: (cleaned DataFrame, number of outliers removed)
+    """
+    original_count = len(df)
+    
+    if method == 'iqr':
+        Q1 = df.quantile(0.25)
+        Q3 = df.quantile(0.75)
+        IQR = Q3 - Q1
+        outlier_mask = ~((df < (Q1 - threshold * IQR)) | (df > (Q3 + threshold * IQR))).any(axis=1)
+        df_clean = df[outlier_mask]
+    
+    elif method == 'zscore':
+        from scipy import stats
+        z_scores = stats.zscore(df)
+        outlier_mask = ~(np.abs(z_scores) > threshold).any(axis=1)
+        df_clean = df[outlier_mask]
+    
+    outliers_removed = original_count - len(df_clean)
+    logger.info(f"使用 {method} 方法移除了 {outliers_removed} 个异常值")
+    return df_clean, outliers_removed
+
+def clean_dataset(df, target_column=None, remove_dups=False, handle_outliers=False, 
+                 outlier_method='iqr', outlier_threshold=1.5, normalize=False):
+    """
+    General-purpose data cleaning function.
+    
+    Args:
+        df: input DataFrame
+        target_column: column name or index of the target variable
+        remove_dups: whether to remove duplicate rows
+        handle_outliers: whether to handle outliers
+        outlier_method: outlier detection method
+        outlier_threshold: outlier decision threshold
+        normalize: whether to standardize the features
+        
+    Returns:
+        tuple: (feature DataFrame, target variable, cleaning statistics)
+    """
+    stats = {'original_count': len(df)}
+    
+    # Separate features and target variable
+    if target_column is not None:
+        if isinstance(target_column, str):
+            X = df.drop(columns=[target_column])
+            y = df[target_column]
+        else:
+            X = df.drop(df.columns[target_column], axis=1)
+            y = df.iloc[:, target_column]
+    else:
+        X = df
+        y = None
+    
+    # Remove duplicates
+    if remove_dups:
+        if y is not None:
+            combined = pd.concat([X, y], axis=1)
+            combined, stats['duplicates_removed'] = remove_duplicates(combined)
+            X = combined.iloc[:, :-1] if isinstance(target_column, int) else combined.drop(columns=[target_column])
+            y = combined.iloc[:, -1] if isinstance(target_column, int) else combined[target_column]
+        else:
+            X, stats['duplicates_removed'] = remove_duplicates(X)
+    
+    # Handle outliers
+    if handle_outliers:
+        if y is not None:
+            combined = pd.concat([X, y], axis=1)
+            combined, stats['outliers_removed'] = remove_outliers(combined, method=outlier_method, threshold=outlier_threshold)
+            X = combined.iloc[:, :-1] if isinstance(target_column, int) else combined.drop(columns=[target_column])
+            y = combined.iloc[:, -1] if isinstance(target_column, int) else combined[target_column]
+        else:
+            X, stats['outliers_removed'] = remove_outliers(X, method=outlier_method, threshold=outlier_threshold)
+    
+    # Standardize features
+    if normalize:
+        scaler = StandardScaler()
+        X = pd.DataFrame(scaler.fit_transform(X), columns=X.columns, index=X.index)
+        stats['normalized'] = True
+    
+    stats['final_count'] = len(X)
+    return X, y, stats 
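
A minimal usage sketch of the new clean_dataset helper, assuming the module is importable as api.app.data_cleaner; the toy DataFrame and column names below are invented for illustration and are not part of the commit:

import pandas as pd
from api.app.data_cleaner import clean_dataset

# Hypothetical toy dataset: two feature columns plus a target column,
# with one exact duplicate row.
df = pd.DataFrame({
    'organic_matter': [1.2, 1.2, 3.4, 5.6],
    'ph': [5.5, 5.5, 6.1, 6.8],
    'target': [0.1, 0.1, 0.4, 0.9],
})

X, y, stats = clean_dataset(
    df,
    target_column='target',   # a positional index also works
    remove_dups=True,         # drops the duplicated row
    handle_outliers=True,     # IQR-based filtering by default
    normalize=True,           # StandardScaler on the remaining features
)
print(stats)
# {'original_count': 4, 'duplicates_removed': 1, 'outliers_removed': 0, 'normalized': True, 'final_count': 3}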

+ 12 - 1
api/app/model.py

@@ -8,9 +8,11 @@ from sklearn.metrics import r2_score
 from sklearn.model_selection import train_test_split, cross_val_score
 from sqlalchemy import text
 from xgboost import XGBRegressor
+import logging
 
 from .database_models import Models, Datasets
 from .config import Config
+from .data_cleaner import clean_dataset
 
 
 # Load the model
@@ -79,13 +81,22 @@ def train_and_save_model(session, model_type, model_name, model_description, dat
             if dataset.empty:
                 raise ValueError(f"Dataset {dataset_id} is empty or not found.")
 
+        # Use the data cleaning module
         if data_type == 'reflux':
             X = dataset.iloc[:, 1:-1]
             y = dataset.iloc[:, -1]
+
+            # target_column = -1  # assume the target variable is in the last column
+            # X, y, clean_stats = clean_dataset(dataset, target_column=target_column)
         elif data_type == 'reduce':
             X = dataset.iloc[:, 2:]
             y = dataset.iloc[:, 1]
-
+            # target_column = 1  # assume the target variable is in the second column
+            # X, y, clean_stats = clean_dataset(dataset, target_column=target_column)
+        
+        # Log cleaning statistics
+        # logging.info(f"Data cleaning stats: {clean_stats}")
+        
         # Train the model
         model = train_model_by_type(X, y, model_type)
 

+ 45 - 3
api/app/routes.py

@@ -139,6 +139,38 @@ def upload_dataset():
         dynamic_table_class = create_dynamic_table(new_dataset.Dataset_ID, column_types)
         insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class)
 
+        # Remove duplicate rows within the uploaded dataset
+        original_count = len(dataset_df)
+        dataset_df = dataset_df.drop_duplicates()
+        duplicates_in_file = original_count - len(dataset_df)
+
+        # Check for duplicates against existing data
+        duplicates_with_existing = 0
+        if dataset_type in ['reduce', 'reflux']:
+            # Determine the table name
+            table_name = 'current_reduce' if dataset_type == 'reduce' else 'current_reflux'
+            
+            # Load existing data from the table
+            existing_data = pd.read_sql_table(table_name, session.bind)
+            if 'id' in existing_data.columns:
+                existing_data = existing_data.drop('id', axis=1)
+            
+            # Determine the columns used for comparison
+            compare_columns = [col for col in dataset_df.columns if col in existing_data.columns]
+            
+            # Count duplicate rows
+            original_df_len = len(dataset_df)
+            
+            # Use concat and duplicated() to identify non-duplicate rows
+            all_data = pd.concat([existing_data[compare_columns], dataset_df[compare_columns]])
+            duplicates_mask = all_data.duplicated(keep='first')
+            duplicates_with_existing = sum(duplicates_mask[len(existing_data):])
+            
+            # Keep only the non-duplicate rows
+            dataset_df = dataset_df[~duplicates_mask[len(existing_data):].values]
+            
+            logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
+
         # Decide which existing table to insert into based on dataset_type
         if dataset_type == 'reduce':
             insert_data_into_existing_table(session, dataset_df, CurrentReduce)
@@ -151,15 +183,25 @@ def upload_dataset():
         training_triggered, task_id = check_and_trigger_training(session, dataset_type, dataset_df)
 
         response_data = {
-            'message': f'Dataset {dataset_name} uploaded successfully!',
+            'message': f'数据集 {dataset_name} 上传成功!',
             'dataset_id': new_dataset.Dataset_ID,
             'filename': unique_filename,
-            'training_triggered': training_triggered
+            'training_triggered': training_triggered,
+            'data_stats': {
+                'original_count': original_count,
+                'duplicates_in_file': duplicates_in_file,
+                'duplicates_with_existing': duplicates_with_existing,
+                'final_count': len(dataset_df)
+            }
         }
         
         if training_triggered:
             response_data['task_id'] = task_id
-            response_data['message'] += ' Auto-training has been triggered.'
+            response_data['message'] += ' 自动训练已触发。'
+
+        # Append deduplication info to the message
+        if duplicates_with_existing > 0:
+            response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
 
         return jsonify(response_data), 201
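
The upload-time deduplication against existing rows amounts to marking, within the concatenation of old and new data, every new row that already occurred earlier. A self-contained sketch of that idea with made-up frames (independent of the Flask route; it uses positional .iloc slicing and ignore_index to avoid index-label ambiguity):

import pandas as pd

def drop_rows_already_present(existing: pd.DataFrame, incoming: pd.DataFrame) -> pd.DataFrame:
    """Return only the incoming rows that are not already present in `existing`."""
    compare_columns = [col for col in incoming.columns if col in existing.columns]
    all_data = pd.concat([existing[compare_columns], incoming[compare_columns]], ignore_index=True)
    # keep='first' marks later occurrences, i.e. incoming rows that repeat existing ones
    duplicates_mask = all_data.duplicated(keep='first')
    incoming_is_duplicate = duplicates_mask.iloc[len(existing):].to_numpy()
    return incoming[~incoming_is_duplicate]

existing = pd.DataFrame({'a': [1, 2], 'b': [10, 20]})
incoming = pd.DataFrame({'a': [2, 3], 'b': [20, 30]})
print(drop_rows_already_present(existing, incoming))  # keeps only the row a=3, b=30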