Sfoglia il codice sorgente

上传数据集时滤除测试集;使用全部数据作为训练集

drggboy 5 mesi fa
parent
commit
6a83718389
2 ha cambiato i file con 86 aggiunte e 3 eliminazioni
  1. 69 2
      api/app/model.py
  2. 17 1
      api/app/routes.py

+ 69 - 2
api/app/model.py

@@ -33,8 +33,72 @@ def predict(session, input_data: pd.DataFrame, model_id):
     predictions = ML_model.predict(input_data)
     return predictions.tolist()
 
def check_dataset_overlap_with_test(dataset_df, data_type):
    """
    Check whether a candidate dataset overlaps with the held-out test set.

    Loads the persisted test split (features + labels) for the given data
    type, then finds rows of ``dataset_df`` whose values on the columns
    shared with the test set also appear there.

    Args:
        dataset_df (DataFrame): dataset to check.
        data_type (str): dataset type, either 'reflux' or 'reduce'.

    Returns:
        tuple: ``(overlap_count, overlap_indices)`` where ``overlap_indices``
        are row labels of ``dataset_df``; ``(0, [])`` when there is no
        overlap or no shared columns.

    Raises:
        ValueError: if ``data_type`` is not a supported type.
    """
    # Map each supported data type to its persisted test-split files
    # (replaces the duplicated if/elif path literals).
    test_files = {
        'reflux': ('uploads/data/X_test_reflux.csv', 'uploads/data/Y_test_reflux.csv'),
        'reduce': ('uploads/data/X_test_reduce.csv', 'uploads/data/Y_test_reduce.csv'),
    }
    if data_type not in test_files:
        raise ValueError(f"不支持的数据类型: {data_type}")

    x_path, y_path = test_files[data_type]
    # Recombine features and labels into one frame so that label columns
    # also take part in the overlap comparison. (The original had an
    # if/else here whose two branches were identical — collapsed.)
    test_df = pd.concat([pd.read_csv(x_path), pd.read_csv(y_path)], axis=1)

    # Only columns present in both frames can be compared.
    compare_columns = [col for col in dataset_df.columns if col in test_df.columns]
    if not compare_columns:
        return 0, []

    # Tag each dataset row with its original index, then inner-merge against
    # the (de-duplicated) test rows. This recovers the overlapping indices in
    # one O(n + m) merge instead of the previous per-overlapping-row rescan
    # of the whole dataset, while keeping pandas merge key semantics
    # (e.g. NaN keys matching NaN keys) identical.
    # NOTE(review): assumes no real data column is named '__row_id__'.
    keyed = dataset_df[compare_columns].copy()
    keyed['__row_id__'] = dataset_df.index
    merged = keyed.merge(
        test_df[compare_columns].drop_duplicates(),
        on=compare_columns,
        how='inner',
    )
    overlap_indices = merged['__row_id__'].unique().tolist()
    return len(overlap_indices), overlap_indices
+
 # 计算模型评分
 def calculate_model_score(model_info):
+    """
+    计算模型评分
+    
+    Args:
+        model_info: 模型信息对象
+        
+    Returns:
+        float: 模型的R²评分
+    """
     # 加载模型
     with open(model_info.ModelFilePath, 'rb') as f:
         ML_model = pickle.load(f)
@@ -163,8 +227,11 @@ def data_type_table_mapping(data_type):
 
 def train_model_by_type(X, y, model_type):
     # 划分数据集
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-
+    # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+    
+    # 使用全部数据作为训练集
+    X_train, y_train = X, y
+    
     if model_type == 'RandomForest':
         # 随机森林的参数优化
         return train_random_forest(X_train, y_train)

+ 17 - 1
api/app/routes.py

@@ -6,7 +6,7 @@ from flask import Blueprint, request, jsonify, current_app, send_file
 from werkzeug.security import check_password_hash, generate_password_hash
 from werkzeug.utils import secure_filename
 
-from .model import predict, train_and_save_model, calculate_model_score
+from .model import predict, train_and_save_model, calculate_model_score, check_dataset_overlap_with_test
 import pandas as pd
 from . import db  # 从 app 包导入 db 实例
 from sqlalchemy.engine.reflection import Inspector
@@ -171,6 +171,17 @@ def upload_dataset():
             
             logger.info(f"原始数据: {original_df_len}, 与现有数据重复: {duplicates_with_existing}, 保留: {len(dataset_df)}")
 
+        # 检查与测试集的重叠
+        test_overlap_count, test_overlap_indices = check_dataset_overlap_with_test(dataset_df, dataset_type)
+        
+        # 如果有与测试集重叠的数据,从数据集中移除
+        if test_overlap_count > 0:
+            # 创建一个布尔掩码,标记不在重叠索引中的行
+            mask = ~dataset_df.index.isin(test_overlap_indices)
+            # 应用掩码,只保留不重叠的行
+            dataset_df = dataset_df[mask]
+            logger.warning(f"移除了 {test_overlap_count} 行与测试集重叠的数据")
+
         # 根据 dataset_type 决定插入到哪个已有表
         if dataset_type == 'reduce':
             insert_data_into_existing_table(session, dataset_df, CurrentReduce)
@@ -191,6 +202,7 @@ def upload_dataset():
                 'original_count': original_count,
                 'duplicates_in_file': duplicates_in_file,
                 'duplicates_with_existing': duplicates_with_existing,
+                'test_overlap_count': test_overlap_count,
                 'final_count': len(dataset_df)
             }
         }
@@ -202,6 +214,10 @@ def upload_dataset():
         # 添加去重信息到消息中
         if duplicates_with_existing > 0:
             response_data['message'] += f' 已移除 {duplicates_with_existing} 个与现有数据重复的项。'
+            
+        # 添加测试集重叠信息到消息中
+        if test_overlap_count > 0:
+            response_data['message'] += f' 已移除 {test_overlap_count} 个与测试集重叠的项。'
 
         return jsonify(response_data), 201