# data_cleaner.py
"""
Data cleaning module: utilities for cleaning and preprocessing datasets.
"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import logging

logger = logging.getLogger(__name__)
  9. def remove_duplicates(df):
  10. """
  11. 移除数据框中的重复行
  12. Args:
  13. df: 输入数据框
  14. Returns:
  15. tuple: (清理后的数据框, 移除的重复项数量)
  16. """
  17. original_count = len(df)
  18. df_clean = df.drop_duplicates()
  19. duplicates_removed = original_count - len(df_clean)
  20. logger.info(f"移除了 {duplicates_removed} 个重复样本")
  21. return df_clean, duplicates_removed
  22. def remove_outliers(df, method='iqr', threshold=1.5):
  23. """
  24. 使用指定方法检测和移除异常值
  25. Args:
  26. df: 输入数据框
  27. method: 异常值检测方法 ('iqr', 'zscore')
  28. threshold: 异常值判定阈值
  29. Returns:
  30. tuple: (清理后的数据框, 移除的异常值数量)
  31. """
  32. original_count = len(df)
  33. if method == 'iqr':
  34. Q1 = df.quantile(0.25)
  35. Q3 = df.quantile(0.75)
  36. IQR = Q3 - Q1
  37. outlier_mask = ~((df < (Q1 - threshold * IQR)) | (df > (Q3 + threshold * IQR))).any(axis=1)
  38. df_clean = df[outlier_mask]
  39. elif method == 'zscore':
  40. from scipy import stats
  41. z_scores = stats.zscore(df)
  42. outlier_mask = ~(np.abs(z_scores) > threshold).any(axis=1)
  43. df_clean = df[outlier_mask]
  44. outliers_removed = original_count - len(df_clean)
  45. logger.info(f"使用 {method} 方法移除了 {outliers_removed} 个异常值")
  46. return df_clean, outliers_removed
  47. def clean_dataset(df, target_column=None, remove_dups=False, handle_outliers=False,
  48. outlier_method='iqr', outlier_threshold=1.5, normalize=False):
  49. """
  50. 综合数据清理函数
  51. Args:
  52. df: 输入数据框
  53. target_column: 目标变量列名或索引
  54. remove_dups: 是否移除重复项
  55. handle_outliers: 是否处理异常值
  56. outlier_method: 异常值检测方法
  57. outlier_threshold: 异常值判定阈值
  58. normalize: 是否标准化特征
  59. Returns:
  60. tuple: (特征数据框, 目标变量, 清理统计信息)
  61. """
  62. stats = {'original_count': len(df)}
  63. # 分离特征和目标变量
  64. if target_column is not None:
  65. if isinstance(target_column, str):
  66. X = df.drop(columns=[target_column])
  67. y = df[target_column]
  68. else:
  69. X = df.drop(df.columns[target_column], axis=1)
  70. y = df.iloc[:, target_column]
  71. else:
  72. X = df
  73. y = None
  74. # 移除重复项
  75. if remove_dups:
  76. if y is not None:
  77. combined = pd.concat([X, y], axis=1)
  78. combined, stats['duplicates_removed'] = remove_duplicates(combined)
  79. X = combined.iloc[:, :-1] if isinstance(target_column, int) else combined.drop(columns=[target_column])
  80. y = combined.iloc[:, -1] if isinstance(target_column, int) else combined[target_column]
  81. else:
  82. X, stats['duplicates_removed'] = remove_duplicates(X)
  83. # 处理异常值
  84. if handle_outliers:
  85. if y is not None:
  86. combined = pd.concat([X, y], axis=1)
  87. combined, stats['outliers_removed'] = remove_outliers(combined, method=outlier_method, threshold=outlier_threshold)
  88. X = combined.iloc[:, :-1] if isinstance(target_column, int) else combined.drop(columns=[target_column])
  89. y = combined.iloc[:, -1] if isinstance(target_column, int) else combined[target_column]
  90. else:
  91. X, stats['outliers_removed'] = remove_outliers(X, method=outlier_method, threshold=outlier_threshold)
  92. # 标准化特征
  93. if normalize:
  94. scaler = StandardScaler()
  95. X = pd.DataFrame(scaler.fit_transform(X), columns=X.columns, index=X.index)
  96. stats['normalized'] = True
  97. stats['final_count'] = len(X)
  98. return X, y, stats