# learning_rate.py - plot a learning curve for a regressor on a configurable dataset.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import learning_curve
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor as XGBR
  8. # 定义数据集配置
  9. DATASET_CONFIGS = {
  10. 'soil_acid_9features': {
  11. 'file_path': 'model_optimize/data/data_filt.xlsx',
  12. 'x_columns': range(1, 10), # 9个特征(包含delta_ph)
  13. 'y_column': -1, # 105_day_ph
  14. 'feature_names': [
  15. 'organic_matter', # OM g/kg
  16. 'chloride', # CL g/kg
  17. 'cec', # CEC cmol/kg
  18. 'h_concentration', # H+ cmol/kg
  19. 'hn', # HN mg/kg
  20. 'al_concentration', # Al3+ cmol/kg
  21. 'free_alumina', # Free alumina g/kg
  22. 'free_iron', # Free iron oxides g/kg
  23. 'delta_ph' # ΔpH
  24. ],
  25. 'target_name': 'target_ph'
  26. },
  27. 'soil_acid_8features': {
  28. 'file_path': 'model_optimize/data/data_filt - 副本.xlsx',
  29. 'x_columns': range(1, 9), # 8个特征
  30. 'y_column': -2, # delta_ph
  31. 'feature_names': [
  32. 'organic_matter', # OM g/kg
  33. 'chloride', # CL g/kg
  34. 'cec', # CEC cmol/kg
  35. 'h_concentration', # H+ cmol/kg
  36. 'hn', # HN mg/kg
  37. 'al_concentration', # Al3+ cmol/kg
  38. 'free_alumina', # Free alumina g/kg
  39. 'free_iron', # Free iron oxides g/kg
  40. ],
  41. 'target_name': 'target_ph'
  42. },
  43. 'soil_acid_8features_original': {
  44. 'file_path': 'model_optimize/data/data_filt.xlsx',
  45. 'x_columns': range(1, 9), # 8个特征
  46. 'y_column': -2, # delta_ph
  47. 'feature_names': [
  48. 'organic_matter', # OM g/kg
  49. 'chloride', # CL g/kg
  50. 'cec', # CEC cmol/kg
  51. 'h_concentration', # H+ cmol/kg
  52. 'hn', # HN mg/kg
  53. 'al_concentration', # Al3+ cmol/kg
  54. 'free_alumina', # Free alumina g/kg
  55. 'free_iron', # Free iron oxides g/kg
  56. ],
  57. 'target_name': 'target_ph'
  58. },
  59. 'soil_acid_6features': {
  60. 'file_path': 'model_optimize/data/data_reflux2.xlsx',
  61. 'x_columns': range(0, 6), # 6个特征
  62. 'y_column': -1, # delta_ph
  63. 'feature_names': [
  64. 'organic_matter', # OM g/kg
  65. 'chloride', # CL g/kg
  66. 'cec', # CEC cmol/kg
  67. 'h_concentration', # H+ cmol/kg
  68. 'hn', # HN mg/kg
  69. 'al_concentration', # Al3+ cmol/kg
  70. ],
  71. 'target_name': 'delta_ph'
  72. },
  73. 'acidity_reduce': {
  74. 'file_path': 'model_optimize/data/Acidity_reduce.xlsx',
  75. 'x_columns': range(1, 6), # 5个特征
  76. 'y_column': 0, # 1/b
  77. 'feature_names': [
  78. 'pH',
  79. 'OM',
  80. 'CL',
  81. 'H',
  82. 'Al'
  83. ],
  84. 'target_name': 'target'
  85. },
  86. 'acidity_reduce_new': {
  87. 'file_path': 'model_optimize/data/Acidity_reduce_new.xlsx',
  88. 'x_columns': range(1, 6), # 5个特征
  89. 'y_column': 0, # 1/b
  90. 'feature_names': [
  91. 'pH',
  92. 'OM',
  93. 'CL',
  94. 'H',
  95. 'Al'
  96. ],
  97. 'target_name': 'target'
  98. }
  99. }
  100. def load_dataset(dataset_name):
  101. """
  102. 加载指定的数据集
  103. Args:
  104. dataset_name: 数据集配置名称
  105. Returns:
  106. x: 特征数据
  107. y: 目标数据
  108. """
  109. if dataset_name not in DATASET_CONFIGS:
  110. raise ValueError(f"未知的数据集名称: {dataset_name}")
  111. config = DATASET_CONFIGS[dataset_name]
  112. data = pd.read_excel(config['file_path'])
  113. x = data.iloc[:, config['x_columns']]
  114. y = data.iloc[:, config['y_column']]
  115. # 设置列名
  116. x.columns = config['feature_names']
  117. y.name = config['target_name']
  118. return x, y
  119. # 选择要使用的数据集
  120. # dataset_name = 'soil_acid_9features' # 土壤反酸数据:64个样本,9个特征(包含delta_ph),目标 105_day_ph
  121. # dataset_name = 'soil_acid_8features_original' # 土壤反酸数据:64个样本,8个特征,目标 delta_ph
  122. # dataset_name = 'soil_acid_8features' # 土壤反酸数据:60个样本(去除异常点),8个特征,目标 delta_ph
  123. # dataset_name = 'soil_acid_6features' # 土壤反酸数据:34个样本,6个特征,目标 delta_ph
  124. dataset_name = 'acidity_reduce' # 精准降酸数据:54个样本,5个特征,目标是1/b
  125. # dataset_name = 'acidity_reduce_new' # 精准降酸数据(数据更新):54个样本,5个特征,目标是1/b
  126. x, y = load_dataset(dataset_name)
  127. print("特征数据:")
  128. print(x)
  129. print("\n目标数据:")
  130. print(y)
  131. ## 数据集划分
  132. Xtrain, Xtest, Ytrain, Ytest = train_test_split(x, y, test_size=0.2, random_state=42)
  133. # 模型:使用 RandomForestRegressor 举例
  134. rfc = RandomForestRegressor(random_state=1)
  135. XGB = XGBR(random_state=1)
  136. # 计算学习曲线
  137. train_sizes, train_scores, test_scores = learning_curve(
  138. rfc, # 使用的模型
  139. Xtrain, # 训练特征
  140. Ytrain, # 训练目标
  141. cv=5, # 交叉验证折数
  142. n_jobs=-1, # 使用所有可用的CPU核心进行并行计算
  143. train_sizes=np.linspace(0.1, 1.0, 10) # 训练集大小,从10%到100%,共10个点
  144. )
  145. # 获取 test_scores(交叉验证测试集的得分)
  146. print("test_scores: \n", test_scores)
  147. print("train_scores: \n", train_scores)
  148. # 绘制学习曲线
  149. plt.figure(figsize=(8, 6))
  150. # 将训练样本数转换为百分比
  151. train_sizes_pct = train_sizes / len(Xtrain) * 100
  152. # 绘制训练误差和测试误差
  153. plt.plot(train_sizes_pct, np.mean(train_scores, axis=1), label="Training score", color="r")
  154. plt.plot(train_sizes_pct, np.mean(test_scores, axis=1), label="Cross-validation score", color="g")
  155. # 绘制图形的细节
  156. plt.title("Learning Curve (Random Forest Regressor)")
  157. plt.xlabel("Training Size (%)")
  158. plt.ylabel("Score (R²)")
  159. plt.legend(loc="best")
  160. plt.grid(True)
  161. # 设置x轴刻度为10的整数倍
  162. plt.xticks(np.arange(0, 101, 10))
  163. plt.show()