engine.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  """
  Cd prediction engine v3.0

  @description: Fully self-contained prediction engine that does not depend on any external integration system
  @version: 3.0.0
  """
  6. import os
  7. import logging
  8. import tempfile
  9. import shutil
  10. from datetime import datetime
  11. from typing import Dict, Any, Optional, Tuple
  12. import pandas as pd
  13. import numpy as np
  14. from .predictors import CropCdPredictor, EffectiveCdPredictor, DataProcessor
  15. from .config import get_raster_config, get_template_tif_path, VISUALIZATION_CONFIG, ensure_directories
  16. from ...utils.mapping_utils import dataframe_to_raster_workflow, MappingUtils
  class CdPredictionEngine:
      """
      Cd prediction engine v3.0 - fully self-contained version.

      Wraps the crop-Cd and effective-Cd predictors, writes the combined
      coordinate/prediction table to CSV, and renders a raster map and a
      histogram through MappingUtils.
      """

      # Map titles keyed by model type; unknown types get a generic title.
      _MAP_TITLES = {
          "crop_cd": "Crop Cd Prediction",
          "effective_cd": "Effective Cd Prediction",
      }

      # (histogram title, x-axis label) keyed by model type.
      _HISTOGRAM_LABELS = {
          "crop_cd": ("Crop Cd Prediction Frequency", "Crop Cd Content (mg/kg)"),
          "effective_cd": ("Effective Cd Prediction Frequency", "Effective Cd Content (mg/kg)"),
      }

      def __init__(self, output_base_dir: str):
          """
          Initialise the prediction engine.

          @param {str} output_base_dir - base output directory
          """
          self.output_base_dir = output_base_dir
          self.logger = logging.getLogger(__name__)

          # Make sure the output directories exist.
          ensure_directories(output_base_dir)

          # Output locations derived from the base directory.
          self.output_paths = {
              "figures": os.path.join(output_base_dir, "figures"),
              "raster": os.path.join(output_base_dir, "raster"),
              "data": os.path.join(output_base_dir, "data"),
              "temp": os.path.join(output_base_dir, "data", "temp"),
              "final": os.path.join(output_base_dir, "data", "final"),
          }

          # Predictors and helpers are created lazily on first access.
          self._crop_predictor = None
          self._effective_predictor = None
          self._data_processor = None
          self._mapping_utils = None

          self.logger.info(f"Cd预测引擎v3.0初始化完成,输出目录: {output_base_dir}")

      @property
      def crop_predictor(self) -> "CropCdPredictor":
          """Return the crop-Cd predictor, creating it on first access."""
          if self._crop_predictor is None:
              self._crop_predictor = CropCdPredictor()
          return self._crop_predictor

      @property
      def effective_predictor(self) -> "EffectiveCdPredictor":
          """Return the effective-Cd predictor, creating it on first access."""
          if self._effective_predictor is None:
              self._effective_predictor = EffectiveCdPredictor()
          return self._effective_predictor

      @property
      def data_processor(self) -> "DataProcessor":
          """Return the data processor, creating it on first access."""
          if self._data_processor is None:
              self._data_processor = DataProcessor()
          return self._data_processor

      @property
      def mapping_utils(self) -> "MappingUtils":
          """Return the mapping utilities, creating them on first access."""
          if self._mapping_utils is None:
              self._mapping_utils = MappingUtils()
          return self._mapping_utils

      def _run_prediction(self, predictor_attr: str, label: str,
                          environmental_data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
          """
          Run one predictor and validate its output (shared by both public predict methods).

          @param {str} predictor_attr - name of the lazy predictor property to use
          @param {str} label - model label inserted into the log messages
          @param {pd.DataFrame} environmental_data - environmental factor data
          @returns {Tuple[np.ndarray, pd.DataFrame]} predictions and validation info
          """
          try:
              self.logger.info(f"开始{label}预测...")

              # Resolved inside the try block so that a predictor construction
              # failure is logged like any other prediction failure.
              predictions = getattr(self, predictor_attr).predict(environmental_data)

              # The validator is given a frame with placeholder (0, 0)
              # coordinates because real coordinates are not available here.
              placeholder = [0] * len(predictions)
              temp_df = pd.DataFrame({
                  'longitude': placeholder,
                  'latitude': placeholder,
                  'Prediction': predictions,
              })
              validation_result = self.data_processor.validate_final_data(temp_df)

              self.logger.info(f"{label}预测完成")
              return predictions, validation_result
          except Exception as e:
              self.logger.error(f"{label}预测失败: {e}")
              raise

      def predict_crop_cd(self, environmental_data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
          """
          Run the crop-Cd prediction.

          @param {pd.DataFrame} environmental_data - environmental factor data
          @returns {Tuple[np.ndarray, pd.DataFrame]} predictions and validation info
          """
          return self._run_prediction("crop_predictor", "作物Cd", environmental_data)

      def predict_effective_cd(self, environmental_data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
          """
          Run the effective-Cd prediction.

          @param {pd.DataFrame} environmental_data - environmental factor data
          @returns {Tuple[np.ndarray, pd.DataFrame]} predictions and validation info
          """
          return self._run_prediction("effective_predictor", "有效态Cd", environmental_data)

      def _write_final_csv(self, final_data: pd.DataFrame, model_type: str) -> str:
          """
          Write the final data to a timestamped CSV in the "final" output directory.

          @param {pd.DataFrame} final_data - combined coordinates and predictions
          @param {str} model_type - model type, embedded in the file name
          @returns {str} path of the written CSV file
          """
          timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
          filename = f"Final_predictions_{model_type}_{timestamp}.csv"
          final_path = os.path.join(self.output_paths["final"], filename)
          # utf-8-sig writes a leading BOM; presumably for spreadsheet tools - confirm.
          final_data.to_csv(final_path, index=False, encoding='utf-8-sig')
          return final_path

      def create_final_dataset(self, coordinates: pd.DataFrame, predictions: np.ndarray,
                               model_type: str) -> str:
          """
          Combine coordinates with predictions and save the result as CSV.

          @param {pd.DataFrame} coordinates - coordinate data
          @param {np.ndarray} predictions - prediction results
          @param {str} model_type - model type
          @returns {str} path of the final data file
          """
          try:
              final_data = self.data_processor.combine_predictions_with_coordinates(
                  coordinates, predictions
              )
              final_path = self._write_final_csv(final_data, model_type)
              self.logger.info(f"最终数据集已保存: {final_path}")
              return final_path
          except Exception as e:
              self.logger.error(f"创建最终数据集失败: {e}")
              raise

      def _run_raster_workflow(self, final_data_df: pd.DataFrame, output_dir: str,
                               boundary_gdf, raster_config: Dict[str, Any]) -> str:
          """
          Rasterise the final data into output_dir and return the raster path.

          @param {pd.DataFrame} final_data_df - final data DataFrame
          @param {str} output_dir - directory the raster is generated in
          @param boundary_gdf - boundary GeoDataFrame
          @param {Dict[str, Any]} raster_config - resolved raster configuration
          @returns {str} value of the 'raster' entry of the workflow result
          """
          column_names = raster_config['coordinate_columns']
          workflow_result = dataframe_to_raster_workflow(
              df=final_data_df,
              template_tif=get_template_tif_path(),
              output_dir=output_dir,
              boundary_gdf=boundary_gdf,
              resolution_factor=raster_config['resolution_factor'],
              interpolation_method=raster_config['interpolation_method'],
              field_name=raster_config['field_name'],
              lon_col=column_names['longitude'],
              lat_col=column_names['latitude'],
              value_col=column_names['value'],
              enable_interpolation=raster_config['enable_interpolation'],
          )
          return workflow_result['raster']

      def create_visualization(self, final_data_df: pd.DataFrame, model_type: str,
                               county_name: str, boundary_gdf=None,
                               raster_config_override: Optional[Dict[str, Any]] = None,
                               save_raster: bool = False) -> Dict[str, Any]:
          """
          Create the raster map and histogram for a set of predictions.

          @param {pd.DataFrame} final_data_df - final data DataFrame
          @param {str} model_type - model type
          @param {str} county_name - county/city name
          @param boundary_gdf - boundary GeoDataFrame
          @param {Optional[Dict[str, Any]]} raster_config_override - raster config override
          @param {bool} save_raster - whether to keep the raster file, default False
          @returns {Dict[str, Any]} 'raster' (None unless save_raster), 'map', 'histogram'
          """
          try:
              timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
              raster_config = get_raster_config(raster_config_override)

              # Set only when a throw-away raster directory is created, so the
              # finally clause below knows whether there is anything to remove.
              temp_dir = None
              try:
                  if save_raster:
                      self.logger.info("开始生成栅格文件(保存到磁盘)...")
                      raster_path = self._run_raster_workflow(
                          final_data_df, self.output_paths["raster"], boundary_gdf, raster_config
                      )

                      # Rename the raster to the timestamped final name.
                      final_raster_name = f"prediction_{model_type}_{county_name}_{timestamp}.tif"
                      final_raster_path = os.path.join(self.output_paths["raster"], final_raster_name)
                      if raster_path != final_raster_path:
                          shutil.move(raster_path, final_raster_path)
                  else:
                      self.logger.info("生成临时栅格数据(仅用于可视化,不保存文件)...")
                      # The raster is only needed as input for the two figures.
                      temp_dir = tempfile.mkdtemp()
                      final_raster_path = self._run_raster_workflow(
                          final_data_df, temp_dir, boundary_gdf, raster_config
                      )

                  # Render the map.
                  self.logger.info("开始生成地图可视化...")
                  map_title = self._get_map_title(model_type)
                  # No file extension on the map output name.
                  map_filename = f"prediction_map_{model_type}_{county_name}_{timestamp}"
                  map_path = os.path.join(self.output_paths["figures"], map_filename)
                  map_result = self.mapping_utils.create_raster_map(
                      shp_path=None,  # no shapefile path is used
                      tif_path=final_raster_path,
                      output_path=map_path,
                      title=map_title,
                      colormap=VISUALIZATION_CONFIG['default_colormap'],
                      figsize=VISUALIZATION_CONFIG['figure_size'],  # figsize rather than output_size
                      dpi=VISUALIZATION_CONFIG['dpi'],
                      resolution_factor=1.0,
                      enable_interpolation=False,
                      interpolation_method='nearest',
                      boundary_gdf=boundary_gdf,  # boundary supplied as a GeoDataFrame
                  )

                  # Render the histogram.
                  self.logger.info("开始生成直方图...")
                  hist_title, hist_xlabel = self._get_histogram_labels(model_type)
                  hist_filename = f"prediction_histogram_{model_type}_{county_name}_{timestamp}.jpg"
                  hist_path = os.path.join(self.output_paths["figures"], hist_filename)
                  hist_result = self.mapping_utils.create_histogram(
                      file_path=final_raster_path,
                      save_path=hist_path,
                      figsize=(6, 6),
                      xlabel=hist_xlabel,
                      ylabel='Frequency',
                      title=hist_title,
                      dpi=VISUALIZATION_CONFIG['dpi'],
                  )
              finally:
                  # Remove the temporary raster directory on success and on
                  # failure alike, so a rendering error cannot leak it.
                  if temp_dir is not None:
                      shutil.rmtree(temp_dir, ignore_errors=True)
                      self.logger.info("临时栅格文件已清理")

              result = {
                  # The temporary raster no longer exists, so no path is reported for it.
                  'raster': final_raster_path if save_raster else None,
                  'map': map_result,
                  'histogram': hist_result,
              }
              self.logger.info("可视化创建完成")
              return result
          except Exception as e:
              self.logger.error(f"创建可视化失败: {e}")
              raise

      def predict_and_visualize(self, input_data: pd.DataFrame, model_type: str,
                                county_name: str, boundary_gdf=None,
                                raster_config_override: Optional[Dict[str, Any]] = None,
                                save_raster: bool = False) -> Dict[str, Any]:
          """
          Run the complete prediction and visualisation pipeline.

          @param {pd.DataFrame} input_data - input data (first two columns are longitude and
              latitude, the remaining columns are environmental factors)
          @param {str} model_type - model type ("crop_cd" or "effective_cd")
          @param {str} county_name - county/city name
          @param boundary_gdf - boundary GeoDataFrame (optional)
          @param {Optional[Dict[str, Any]]} raster_config_override - raster config override
          @param {bool} save_raster - whether to keep the raster file, default False
              (only the map and the histogram are generated)
          @returns {Dict[str, Any]} complete result
          @raises {ValueError} unsupported model_type, or fewer than two input columns
          """
          try:
              self.logger.info(f"开始{model_type}模型的完整预测流程(使用统一绘图接口)...")

              # Reject an unknown model type before any data is touched.
              predict_by_type = {
                  "crop_cd": self.predict_crop_cd,
                  "effective_cd": self.predict_effective_cd,
              }
              predict = predict_by_type.get(model_type)
              if predict is None:
                  raise ValueError(f"不支持的模型类型: {model_type}")

              # Without two coordinate columns the renaming below would fail
              # with an opaque pandas length-mismatch ValueError.
              column_count = input_data.shape[1]
              if column_count < 2:
                  raise ValueError(f"输入数据至少需要两列坐标(经度、纬度),实际列数: {column_count}")

              # Split coordinates from environmental factors.
              coordinates = input_data.iloc[:, :2].copy()
              coordinates.columns = ['longitude', 'latitude']
              environmental_data = input_data.iloc[:, 2:].copy()

              predictions, validation = predict(environmental_data)

              # Merge coordinates and predictions into the final DataFrame.
              final_data_df = self.data_processor.combine_predictions_with_coordinates(
                  coordinates, predictions
              )

              # The CSV is still written (kept for compatibility).
              final_data_file = self._write_final_csv(final_data_df, model_type)

              # The visualisation consumes the DataFrame directly.
              visualization_result = self.create_visualization(
                  final_data_df, model_type, county_name, boundary_gdf,
                  raster_config_override, save_raster
              )

              result = {
                  'model_type': model_type,
                  'county_name': county_name,
                  'final_data_file': final_data_file,
                  'final_data_df': final_data_df,
                  'raster_path': visualization_result['raster'],
                  'map_path': visualization_result['map'],
                  'histogram_path': visualization_result['histogram'],
                  'validation': validation,
                  'timestamp': datetime.now().isoformat(),
              }
              self.logger.info(f"{model_type}模型完整预测流程完成")
              return result
          except Exception as e:
              self.logger.error(f"{model_type}模型完整预测流程失败: {e}")
              raise

      def _get_map_title(self, model_type: str) -> str:
          """Return the map title for a model type."""
          return self._MAP_TITLES.get(model_type, f"{model_type} Prediction")

      def _get_histogram_labels(self, model_type: str) -> Tuple[str, str]:
          """Return the (title, x-axis label) pair for the histogram of a model type."""
          fallback = (f"{model_type} Prediction Frequency", f"{model_type} Content")
          return self._HISTOGRAM_LABELS.get(model_type, fallback)

      def cleanup_temp_files(self):
          """Delete the regular files directly inside the temp output directory (best effort)."""
          try:
              temp_dir = self.output_paths["temp"]
              if os.path.exists(temp_dir):
                  for entry_name in os.listdir(temp_dir):
                      entry_path = os.path.join(temp_dir, entry_name)
                      # Sub-directories are left in place; only files are removed.
                      if os.path.isfile(entry_path):
                          os.remove(entry_path)
                          self.logger.debug(f"已删除临时文件: {entry_name}")
              self.logger.info("临时文件清理完成")
          except Exception as e:
              # Cleanup failures are logged as warnings and never propagated.
              self.logger.warning(f"清理临时文件失败: {e}")

      def get_model_info(self) -> Dict[str, Any]:
          """
          Return information about the engine and its model files.

          @returns {Dict[str, Any]} model information
          """
          from .config import validate_model_files

          template_tif = get_template_tif_path()
          return {
              "version": "3.0.0",
              "output_base_dir": self.output_base_dir,
              "output_paths": self.output_paths,
              "crop_cd_files": validate_model_files("crop_cd"),
              "effective_cd_files": validate_model_files("effective_cd"),
              "template_tif": template_tif,
              "template_tif_exists": os.path.exists(template_tif),
          }