|
@@ -163,9 +163,6 @@ class EffectiveCdPredictor:
|
|
|
|
|
|
# 加载数据
|
|
|
data = self.load_data()
|
|
|
- # 保存原始数据,包括经纬度
|
|
|
- longitude = data["longitude"].values
|
|
|
- latitude = data["latitude"].values
|
|
|
|
|
|
# 预处理数据
|
|
|
X_tensor, special_feature_index = self.preprocess_data(data)
|
|
@@ -180,14 +177,9 @@ class EffectiveCdPredictor:
|
|
|
).squeeze().numpy()
|
|
|
|
|
|
|
|
|
+ # 指数还原预测结果并按与作物Cd模型一致的格式输出为单列CSV
|
|
|
cd_predictions = np.exp(predictions)
|
|
|
-
|
|
|
- # 创建包含经纬度和预测结果的数据框
|
|
|
- predictions_df = pd.DataFrame({
|
|
|
- "longitude": longitude,
|
|
|
- "latitude": latitude,
|
|
|
- "cd_prediction": cd_predictions
|
|
|
- })
|
|
|
+ predictions_df = pd.DataFrame(cd_predictions)
|
|
|
output_path = os.path.join(
|
|
|
config.DATA_PATHS["predictions_dir"],
|
|
|
self.model_config["output_file"]
|