@@ -0,0 +1,116 @@
+"""
+Data cleaning module: provides various data cleaning and preprocessing utilities.
+"""
+import pandas as pd
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+import logging
+
+logger = logging.getLogger(__name__)
+
+def remove_duplicates(df):
+    """
+    Remove duplicate rows from a DataFrame.
+
+    Args:
+        df: Input DataFrame.
+
+    Returns:
+        tuple: (cleaned DataFrame, number of duplicate rows removed)
+    """
+    original_count = len(df)
+    df_clean = df.drop_duplicates()
+    duplicates_removed = original_count - len(df_clean)
+    logger.info(f"Removed {duplicates_removed} duplicate samples")
+    return df_clean, duplicates_removed
+
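+# Illustrative example (a sketch; the frame below is hypothetical, not part of the module):
+#
+#     df = pd.DataFrame({"a": [1, 1, 2], "b": [3, 3, 4]})
+#     cleaned, n_removed = remove_duplicates(df)   # cleaned keeps 2 rows, n_removed == 1
+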
+def remove_outliers(df, method='iqr', threshold=1.5):
+    """
+    Detect and remove outliers using the specified method.
+
+    Args:
+        df: Input DataFrame.
+        method: Outlier detection method ('iqr' or 'zscore').
+        threshold: Threshold used to flag outliers.
+
+    Returns:
+        tuple: (cleaned DataFrame, number of outliers removed)
+    """
+    original_count = len(df)
+
+    if method == 'iqr':
+        Q1 = df.quantile(0.25)
+        Q3 = df.quantile(0.75)
+        IQR = Q3 - Q1
+        outlier_mask = ~((df < (Q1 - threshold * IQR)) | (df > (Q3 + threshold * IQR))).any(axis=1)
+        df_clean = df[outlier_mask]
+    elif method == 'zscore':
+        from scipy import stats
+        z_scores = stats.zscore(df)
+        outlier_mask = ~(np.abs(z_scores) > threshold).any(axis=1)
+        df_clean = df[outlier_mask]
+    else:
+        raise ValueError(f"Unknown outlier detection method: {method}")
+
+    outliers_removed = original_count - len(df_clean)
+    logger.info(f"Removed {outliers_removed} outliers using the {method} method")
+    return df_clean, outliers_removed
+
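+# Note on the thresholds (hypothetical numbers, for illustration only): with the default
+# threshold of 1.5, a column with Q1 = 10 and Q3 = 20 has IQR = 10, so rows whose value in
+# that column falls outside [10 - 1.5*10, 20 + 1.5*10] = [-5, 35] are dropped. The 'zscore'
+# method instead drops rows where any column's absolute z-score exceeds the threshold.
+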
+def clean_dataset(df, target_column=None, remove_dups=False, handle_outliers=False,
+                  outlier_method='iqr', outlier_threshold=1.5, normalize=False):
+    """
+    Comprehensive data-cleaning function.
+
+    Args:
+        df: Input DataFrame.
+        target_column: Column name or index of the target variable.
+        remove_dups: Whether to remove duplicate rows.
+        handle_outliers: Whether to detect and remove outliers.
+        outlier_method: Outlier detection method.
+        outlier_threshold: Threshold used to flag outliers.
+        normalize: Whether to standardize the features.
+
+    Returns:
+        tuple: (feature DataFrame, target variable, cleaning statistics)
+    """
+    stats = {'original_count': len(df)}
+
+    # Separate features and target variable
+    if target_column is not None:
+        if isinstance(target_column, str):
+            X = df.drop(columns=[target_column])
+            y = df[target_column]
+        else:
+            X = df.drop(df.columns[target_column], axis=1)
+            y = df.iloc[:, target_column]
+    else:
+        X = df
+        y = None
+
+    # Remove duplicate rows
+    if remove_dups:
+        if y is not None:
+            combined = pd.concat([X, y], axis=1)
+            combined, stats['duplicates_removed'] = remove_duplicates(combined)
+            X = combined.iloc[:, :-1] if isinstance(target_column, int) else combined.drop(columns=[target_column])
+            y = combined.iloc[:, -1] if isinstance(target_column, int) else combined[target_column]
+        else:
+            X, stats['duplicates_removed'] = remove_duplicates(X)
+
+    # Handle outliers
+    if handle_outliers:
+        if y is not None:
+            combined = pd.concat([X, y], axis=1)
+            combined, stats['outliers_removed'] = remove_outliers(combined, method=outlier_method, threshold=outlier_threshold)
+            X = combined.iloc[:, :-1] if isinstance(target_column, int) else combined.drop(columns=[target_column])
+            y = combined.iloc[:, -1] if isinstance(target_column, int) else combined[target_column]
+        else:
+            X, stats['outliers_removed'] = remove_outliers(X, method=outlier_method, threshold=outlier_threshold)
+
+    # Standardize features
+    if normalize:
+        scaler = StandardScaler()
+        X = pd.DataFrame(scaler.fit_transform(X), columns=X.columns, index=X.index)
+        stats['normalized'] = True
+
+    stats['final_count'] = len(X)
+    return X, y, stats
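+
+# Example usage (a minimal sketch with synthetic data; the column names and flag values
+# below are illustrative assumptions, not part of this module's API):
+#
+#     df = pd.DataFrame({
+#         "height": [150, 152, 151, 151, 300],   # 300 is an obvious outlier
+#         "weight": [50, 52, 51, 51, 55],
+#         "label":  [0, 1, 0, 0, 1],
+#     })
+#     X, y, info = clean_dataset(
+#         df,
+#         target_column="label",
+#         handle_outliers=True,
+#         outlier_method="iqr",
+#         normalize=True,
+#     )
+#     # info holds 'original_count', 'outliers_removed', 'normalized' and 'final_count'.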