""" 基于数据库的Cd预测服务类 @description: 从数据库表中读取数据进行预测并更新结果表 @author: AcidMap Team """ import os import logging import asyncio from datetime import datetime from typing import Dict, Any, Optional, List, Tuple import pandas as pd import numpy as np from sqlalchemy.orm import Session from sqlalchemy import and_ from ..database import SessionLocal from ..models.farmland import FarmlandData from ..models.EffCd_input import EffCdInputData from ..models.EffCd_output import EffCdOutputData from ..models.CropCd_input import CropCdInputData from ..models.CropCd_output import CropCdOutputData from .cd_prediction_service_v3 import CdPredictionServiceV3 from .admin_boundary_service import get_boundary_geojson_by_name from ..log.logger import get_logger import tempfile import json class CdPredictionDatabaseService: """ 基于数据库的Cd预测服务类 从数据库表读取输入数据,执行预测,并将结果保存回数据库 """ def __init__(self): """初始化数据库预测服务""" self.logger = get_logger(__name__) self.prediction_service = CdPredictionServiceV3() self.logger.info("数据库Cd预测服务初始化完成") def _get_database_session(self) -> Session: """获取数据库会话""" return SessionLocal() def _query_effective_cd_input_data(self, db: Session) -> pd.DataFrame: """ 从EffCd_input_data和Farmland_data表查询输入数据 @param db: 数据库会话 @returns: 包含经纬度和环境因子的DataFrame """ try: # 构建查询,联接EffCd_input_data和Farmland_data表 query = db.query( FarmlandData.lon, FarmlandData.lan, EffCdInputData.oc_fe_0_30, EffCdInputData.silt_content, EffCdInputData.sand_content, EffCdInputData.gravel_content, EffCdInputData.available_potassium, EffCdInputData.available_phosphorus, EffCdInputData.electrical_conductivity, EffCdInputData.slow_available_potassium, EffCdInputData.total_aluminum, EffCdInputData.total_calcium, EffCdInputData.total_cadmium, # 新增 EffCdInputData.soluble_salts, # 新增 EffCdInputData.exchangeable_acidity, # 新增 EffCdInputData.total_iron, EffCdInputData.total_potassium, # 新增 EffCdInputData.total_magnesium, EffCdInputData.total_manganese, EffCdInputData.total_nitrogen, EffCdInputData.total_phosphorus, # 新增 EffCdInputData.total_sulfur, EffCdInputData.cd_solution, EffCdInputData.farmland_id, EffCdInputData.sample_id ).join( FarmlandData, and_( EffCdInputData.farmland_id == FarmlandData.farmland_id, EffCdInputData.sample_id == FarmlandData.sample_id ) ) # 执行查询并转换为DataFrame results = query.all() if not results: raise ValueError("未找到符合条件的EffCd输入数据") # 转换为DataFrame df = pd.DataFrame([ { 'lon': row.lon, 'lan': row.lan, 'OC-Fe_0-30': row.oc_fe_0_30, '002_0002IDW': row.silt_content, '02_002IDW': row.sand_content, '2_02IDW': row.gravel_content, 'AvaK_IDW': row.available_potassium, 'AvaP_IDW': row.available_phosphorus, 'EC_IDW': row.electrical_conductivity, 'SAvaK_IDW': row.slow_available_potassium, 'TAl_IDW': row.total_aluminum, 'TCa_IDW': row.total_calcium, 'TCd_IDW': row.total_cadmium, # 新增 'TEB_IDW': row.soluble_salts, # 新增 'TExH_IDW': row.exchangeable_acidity, # 新增 'TFe_IDW': row.total_iron, 'TK_IDW': row.total_potassium, # 新增 'TMg_IDW': row.total_magnesium, 'TMn_IDW': row.total_manganese, 'TN_IDW': row.total_nitrogen, 'TP_IDW': row.total_phosphorus, # 新增 'TS_IDW': row.total_sulfur, 'Cdsolution': row.cd_solution, # 修正列名 'farmland_id': row.farmland_id, 'sample_id': row.sample_id } for row in results ]) self.logger.info(f"成功查询到{len(df)}条EffCd输入数据记录") return df except Exception as e: self.logger.error(f"查询EffCd输入数据失败: {str(e)}") raise def _query_crop_cd_input_data(self, db: Session) -> pd.DataFrame: """ 从CropCd_input_data和Farmland_data表查询输入数据 @param db: 数据库会话 @returns: 包含经纬度和环境因子的DataFrame """ try: # 构建查询,联接CropCd_input_data和Farmland_data表 
            query = db.query(
                FarmlandData.lon,
                FarmlandData.lan,
                CropCdInputData.silt_content,              # 002_0002IDW
                CropCdInputData.sand_content,              # 02_002IDW
                CropCdInputData.gravel_content,            # 2_02IDW
                CropCdInputData.available_phosphorus,      # AvaP_IDW (note: corresponds to AvaP in the CSV)
                CropCdInputData.available_potassium,       # AvaK_IDW (note: corresponds to AvaK in the CSV)
                CropCdInputData.slow_available_potassium,  # SAvaK_IDW
                CropCdInputData.total_aluminum,            # TAl_IDW
                CropCdInputData.total_calcium,             # TCa_IDW
                CropCdInputData.total_iron,                # TFe_IDW
                CropCdInputData.total_magnesium,           # TMg_IDW
                CropCdInputData.total_manganese,           # TMn_IDW
                CropCdInputData.total_nitrogen,            # TN_IDW
                CropCdInputData.total_sulfur,              # TS_IDW
                CropCdInputData.ln_cd_solution,            # solution
                CropCdInputData.farmland_id,
                CropCdInputData.sample_id
            ).join(
                FarmlandData,
                and_(
                    CropCdInputData.farmland_id == FarmlandData.farmland_id,
                    CropCdInputData.sample_id == FarmlandData.sample_id
                )
            )

            # Run the query and convert to a DataFrame
            results = query.all()

            if not results:
                raise ValueError("No matching CropCd input data found")

            # Convert to DataFrame
            df = pd.DataFrame([
                {
                    'lon': row.lon,
                    'lan': row.lan,
                    '002_0002IDW': row.silt_content,
                    '02_002IDW': row.sand_content,
                    '2_02IDW': row.gravel_content,
                    'AvaP': row.available_phosphorus,  # note: AvaP in the CSV maps to the AvaP_IDW database field
                    'AvaK_IDW': row.available_potassium,
                    'SAvaK_IDW': row.slow_available_potassium,
                    'TAl_IDW': row.total_aluminum,
                    'TCa_IDW': row.total_calcium,
                    'TFe_IDW': row.total_iron,
                    'TMg_IDW': row.total_magnesium,
                    'TMn_IDW': row.total_manganese,
                    'TN_IDW': row.total_nitrogen,
                    'TS_IDW': row.total_sulfur,
                    'solution': row.ln_cd_solution,
                    'farmland_id': row.farmland_id,
                    'sample_id': row.sample_id
                }
                for row in results
            ])

            self.logger.info(f"Queried {len(df)} CropCd input data records")
            return df

        except Exception as e:
            self.logger.error(f"Failed to query CropCd input data: {str(e)}")
            raise

    def _prepare_prediction_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Prepare prediction data by separating coordinates from environmental factors.

        @param df: DataFrame containing all data
        @returns: (coordinates DataFrame, environmental factors DataFrame)
        """
        try:
            # Extract coordinate data (first two columns)
            coordinates_df = df[['lon', 'lan']].copy()
            coordinates_df.columns = ['longitude', 'latitude']

            # Extract environmental factor data (drop coordinate and ID columns)
            exclude_cols = ['lon', 'lan', 'farmland_id', 'sample_id']
            environmental_cols = [col for col in df.columns if col not in exclude_cols]
            environmental_df = df[environmental_cols].copy()

            # Check data completeness
            missing_coords = coordinates_df.isnull().sum().sum()
            missing_env = environmental_df.isnull().sum().sum()

            if missing_coords > 0:
                self.logger.warning(f"Coordinate data contains {missing_coords} missing values")
            if missing_env > 0:
                self.logger.warning(f"Environmental factor data contains {missing_env} missing values")

            self.logger.info(
                f"Prediction data prepared: coordinates {coordinates_df.shape}, "
                f"environmental factors {environmental_df.shape}"
            )
            return coordinates_df, environmental_df

        except Exception as e:
            self.logger.error(f"Failed to prepare prediction data: {str(e)}")
            raise

    def _save_predictions_to_database(self, db: Session, df: pd.DataFrame, predictions: np.ndarray) -> int:
        """
        Save prediction results to the EffCd_output_data table in bulk (upsert via delete-then-insert).

        @param db: database session
        @param df: original data DataFrame (contains farmland_id and sample_id)
        @param predictions: array of prediction results
        @returns: number of records written
        """
        try:
            self.logger.info(f"Saving {len(predictions)} prediction results in bulk...")

            # First delete all existing records for these farmlands (blunt but efficient)
            farmland_ids = df['farmland_id'].unique().tolist()
            delete_count = db.query(EffCdOutputData).filter(
                EffCdOutputData.farmland_id.in_(farmland_ids)
            ).delete(synchronize_session=False)

            if delete_count > 0:
                self.logger.info(f"Cleared {delete_count} old records")

            # Prepare the bulk-insert payload (vectorized, no per-row loop)
            batch_data = [
                {
                    'farmland_id': int(farmland_id),
                    'sample_id': int(sample_id),
                    'ln_eff_cd': float(prediction)
                }
                for farmland_id, sample_id, prediction in zip(
                    df['farmland_id'].values,
                    df['sample_id'].values,
                    predictions
                )
            ]

            # Insert new records in batches (avoid inserting too many rows in one statement)
            batch_size = 5000  # records per batch
            total_inserted = 0

            for i in range(0, len(batch_data), batch_size):
                batch_chunk = batch_data[i:i + batch_size]
                db.bulk_insert_mappings(EffCdOutputData, batch_chunk)
                total_inserted += len(batch_chunk)
                self.logger.info(f"Inserted {total_inserted}/{len(batch_data)} records")

            # Commit the transaction
            db.commit()
            self.logger.info(f"Saved {total_inserted} prediction results to the EffCd_output_data table")
            return total_inserted

        except Exception as e:
            db.rollback()
            self.logger.error(f"Failed to bulk-save prediction results to the database: {str(e)}")
            raise

    def _save_crop_cd_predictions_to_database(self, db: Session, df: pd.DataFrame, predictions: np.ndarray) -> int:
        """
        Save crop Cd prediction results to the CropCd_output_data table in bulk.

        @param db: database session
        @param df: original data DataFrame (contains farmland_id and sample_id)
        @param predictions: array of prediction results
        @returns: number of records written
        """
        try:
            self.logger.info(f"Saving {len(predictions)} crop Cd prediction results in bulk...")

            # First delete all existing records for these farmlands (blunt but efficient)
            farmland_ids = df['farmland_id'].unique().tolist()
            delete_count = db.query(CropCdOutputData).filter(
                CropCdOutputData.farmland_id.in_(farmland_ids)
            ).delete(synchronize_session=False)

            if delete_count > 0:
                self.logger.info(f"Cleared {delete_count} old crop Cd records")

            # Prepare the bulk-insert payload (vectorized, no per-row loop)
            batch_data = [
                {
                    'farmland_id': int(farmland_id),
                    'sample_id': int(sample_id),
                    'ln_crop_cd': float(prediction)
                }
                for farmland_id, sample_id, prediction in zip(
                    df['farmland_id'].values,
                    df['sample_id'].values,
                    predictions
                )
            ]

            # Insert new records in batches (avoid inserting too many rows in one statement)
            batch_size = 5000  # records per batch
            total_inserted = 0

            for i in range(0, len(batch_data), batch_size):
                batch_chunk = batch_data[i:i + batch_size]
                db.bulk_insert_mappings(CropCdOutputData, batch_chunk)
                total_inserted += len(batch_chunk)
                self.logger.info(f"Inserted {total_inserted}/{len(batch_data)} crop Cd prediction records")

            # Commit the transaction
            db.commit()
            self.logger.info(f"Saved {total_inserted} crop Cd prediction results to the CropCd_output_data table")
            return total_inserted

        except Exception as e:
            db.rollback()
            self.logger.error(f"Failed to bulk-save crop Cd prediction results to the database: {str(e)}")
            raise

    async def generate_effective_cd_prediction_from_database(
        self,
        area: str,
        level: str,
        raster_config_override: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Generate an effective Cd prediction from database data.

        @param area: area name
        @param level: administrative level
        @param raster_config_override: raster configuration overrides
        @returns: prediction result information
        """
        db = None
        try:
            self.logger.info(f"Starting effective Cd prediction from database data: {area} ({level})")

            # Get a database session
            db = self._get_database_session()

            # Query input data
            input_df = self._query_effective_cd_input_data(db)

            if len(input_df) == 0:
                raise ValueError("No valid input data found")

            # Prepare prediction data
            coordinates_df, environmental_df = self._prepare_prediction_data(input_df)

            # Merge coordinates and environmental factors for the prediction service
            prediction_input_df = pd.concat([coordinates_df, environmental_df], axis=1)

            # Save a temporary data file for the prediction service
            temp_file_path = self.prediction_service.save_temp_data(prediction_input_df, area)

            # Get boundary data
            boundary_gdf = self._get_boundary_geojson(area, level)

            # Run prediction and visualization directly with the prediction engine
            prediction_result = await self._run_effective_cd_prediction_with_boundary(
                prediction_input_df, area, boundary_gdf, raster_config_override
            )

            # Extract prediction values from the result,
            # preferring the DataFrame returned by the engine
            final_data_df = prediction_result.get('final_data_df')
            if final_data_df is not None and 'Prediction' in final_data_df.columns:
                predictions = final_data_df['Prediction'].values
            else:
                # If no DataFrame was returned directly, try reading from the result file
                final_data_file = prediction_result.get('final_data_file')
                if final_data_file and os.path.exists(final_data_file):
                    final_df = pd.read_csv(final_data_file)
                    predictions = final_df['Prediction'].values
                else:
                    raise ValueError("Unable to obtain prediction result data")

            # Save prediction results to the database
            updated_count = self._save_predictions_to_database(db, input_df, predictions)

            result = {
                'success': True,
                'area': area,
                'level': level,
                'processed_records': len(input_df),
                'updated_records': updated_count,
                'map_path': prediction_result.get('map_path'),
                'histogram_path': prediction_result.get('histogram_path'),
                'raster_path': prediction_result.get('raster_path'),
                'timestamp': prediction_result.get('timestamp'),
                'validation': prediction_result.get('validation', {})
            }

            self.logger.info(
                f"Effective Cd prediction from database data completed: {area} ({level}), "
                f"processed {len(input_df)} records"
            )
            return result

        except Exception as e:
            self.logger.error(f"Failed to generate effective Cd prediction from database data: {str(e)}")
            raise
        finally:
            if db:
                db.close()

    def get_effective_cd_results_from_database(
        self,
        limit: Optional[int] = None
    ) -> pd.DataFrame:
        """
        Fetch effective Cd prediction results from the database.

        @param limit: optional limit on the number of results
        @returns: DataFrame containing the prediction results
        """
        db = None
        try:
            db = self._get_database_session()

            # Build the query, joining the output table with the farmland table to get coordinates
            query = db.query(
                EffCdOutputData.farmland_id,
                EffCdOutputData.sample_id,
                EffCdOutputData.ln_eff_cd,
                FarmlandData.lon,
                FarmlandData.lan
            ).join(
                FarmlandData,
                and_(
                    EffCdOutputData.farmland_id == FarmlandData.farmland_id,
                    EffCdOutputData.sample_id == FarmlandData.sample_id
                )
            )

            # Apply the result limit
            if limit:
                query = query.limit(limit)

            # Run the query
            results = query.all()

            if not results:
                return pd.DataFrame()

            # Convert to DataFrame
            df = pd.DataFrame([
                {
                    'farmland_id': row.farmland_id,
                    'sample_id': row.sample_id,
                    'longitude': row.lon,
                    'latitude': row.lan,
                    'LnEffCd': row.ln_eff_cd,
                    'EffCd': np.exp(row.ln_eff_cd)  # convert to the actual effective Cd concentration
                }
                for row in results
            ])

            self.logger.info(f"Queried {len(df)} effective Cd prediction results")
            return df

        except Exception as e:
            self.logger.error(f"Failed to query effective Cd prediction results: {str(e)}")
            raise
        finally:
            if db:
                db.close()

    async def generate_crop_cd_prediction_from_database(
        self,
        area: str,
        level: str,
        raster_config_override: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Generate a crop Cd prediction from database data.

        @param area: area name
        @param level: administrative level
        @param raster_config_override: raster configuration overrides
        @returns: prediction result information
        """
        db = None
        try:
            self.logger.info(f"Starting crop Cd prediction from database data: {area} ({level})")

            # Get a database session
            db = self._get_database_session()

            # Query input data
            input_df = self._query_crop_cd_input_data(db)

            if len(input_df) == 0:
                raise ValueError("No valid crop Cd input data found")

            # Prepare prediction data
            coordinates_df, environmental_df = self._prepare_prediction_data(input_df)

            # Merge coordinates and environmental factors for the prediction service
            prediction_input_df = pd.concat([coordinates_df, environmental_df], axis=1)

            # Save a temporary data file for the prediction service
            temp_file_path = self.prediction_service.save_temp_data(prediction_input_df, area)

            # Get boundary data
            boundary_gdf = self._get_boundary_geojson(area, level)

            # Run prediction and visualization directly with the prediction engine
            prediction_result = await self._run_crop_cd_prediction_with_boundary(
                prediction_input_df, area, boundary_gdf, raster_config_override
            )

            # Extract prediction values from the result,
            # preferring the DataFrame returned by the engine
            final_data_df = prediction_result.get('final_data_df')
            if final_data_df is not None and 'Prediction' in final_data_df.columns:
                predictions = final_data_df['Prediction'].values
            else:
                # If no DataFrame was returned directly, try reading from the result file
                final_data_file = prediction_result.get('final_data_file')
                if final_data_file and os.path.exists(final_data_file):
                    final_df = pd.read_csv(final_data_file)
                    predictions = final_df['Prediction'].values
                else:
                    raise ValueError("Unable to obtain crop Cd prediction result data")

            # Save prediction results to the database
            updated_count = self._save_crop_cd_predictions_to_database(db, input_df, predictions)

            result = {
                'success': True,
                'area': area,
                'level': level,
                'processed_records': len(input_df),
                'updated_records': updated_count,
                'map_path': prediction_result.get('map_path'),
                'histogram_path': prediction_result.get('histogram_path'),
                'raster_path': prediction_result.get('raster_path'),
                'timestamp': prediction_result.get('timestamp'),
                'validation': prediction_result.get('validation', {})
            }

            self.logger.info(
                f"Crop Cd prediction from database data completed: {area} ({level}), "
                f"processed {len(input_df)} records"
            )
            return result

        except Exception as e:
            self.logger.error(f"Failed to generate crop Cd prediction from database data: {str(e)}")
            raise
        finally:
            if db:
                db.close()

    async def _run_crop_cd_prediction_with_boundary(
        self,
        input_data: pd.DataFrame,
        area: str,
        boundary_gdf,
        raster_config_override: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run the crop Cd prediction using boundary data.

        @param input_data: input DataFrame
        @param area: area name
        @param boundary_gdf: boundary GeoDataFrame
        @param raster_config_override: raster configuration overrides
        @returns: prediction result
        """
        try:
            # Run the prediction in a thread pool
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                None,
                self._run_crop_cd_prediction_sync,
                input_data,
                area,
                boundary_gdf,
                raster_config_override
            )
            return result

        except Exception as e:
            self.logger.error(f"Failed to run crop Cd prediction: {str(e)}")
            raise

    def _run_crop_cd_prediction_sync(
        self,
        input_data: pd.DataFrame,
        area: str,
        boundary_gdf,
        raster_config_override: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run the crop Cd prediction synchronously (for use in the thread pool).

        @param input_data: input DataFrame
        @param area: area name
        @param boundary_gdf: boundary GeoDataFrame
        @param raster_config_override: raster configuration overrides
        @returns: prediction result
        """
        try:
            # Use the prediction engine's predict_and_visualize method
            result = self.prediction_service.engine.predict_and_visualize(
                input_data=input_data,
                model_type="crop_cd",                          # use the crop Cd model
                county_name=area,                              # use area as county_name
                boundary_gdf=boundary_gdf,
                raster_config_override=raster_config_override,
                save_raster=False                              # skip the raster file to save storage
            )
            return result

        except Exception as e:
            self.logger.error(f"Synchronous crop Cd prediction failed: {str(e)}")
            raise

    def _get_boundary_geojson(self, area: str, level: str) -> Optional[object]:
        """
        Get the boundary GeoDataFrame for the given area.

        @param area: area name
        @param level: administrative level
        @returns: GeoDataFrame object, or None
        """
        db = None
        try:
            db = self._get_database_session()
            feature = get_boundary_geojson_by_name(db, area, level)
            if feature:
                # Convert the feature into a GeoDataFrame
                import geopandas as gpd

                # Write a temporary GeoJSON file
                tmp_dir = tempfile.mkdtemp()
                tmp_geojson = os.path.join(tmp_dir, "boundary.geojson")
                fc = {"type": "FeatureCollection", "features": [feature]}
                with open(tmp_geojson, 'w', encoding='utf-8') as f:
                    json.dump(fc, f, ensure_ascii=False)

                # Read it back as a GeoDataFrame
                boundary_gdf = gpd.read_file(tmp_geojson)

                # Clean up the temporary files
                import shutil
                shutil.rmtree(tmp_dir, ignore_errors=True)

                self.logger.info(f"Boundary data retrieved: {area} ({level})")
                return boundary_gdf
            return None
        except Exception as e:
            self.logger.warning(f"Failed to retrieve boundary data: {str(e)}")
            return None
        finally:
            if db is not None:
                try:
                    db.close()
                except Exception:
                    pass

    async def _run_effective_cd_prediction_with_boundary(
        self,
        input_data: pd.DataFrame,
        area: str,
        boundary_gdf,
        raster_config_override: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run the effective Cd prediction using boundary data.

        @param input_data: input DataFrame
        @param area: area name
        @param boundary_gdf: boundary GeoDataFrame
        @param raster_config_override: raster configuration overrides
        @returns: prediction result
        """
        try:
            # Run the prediction in a thread pool
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                None,
                self._run_prediction_sync,
                input_data,
                area,
                boundary_gdf,
                raster_config_override
            )
            return result

        except Exception as e:
            self.logger.error(f"Failed to run effective Cd prediction: {str(e)}")
            raise

    def _run_prediction_sync(
        self,
        input_data: pd.DataFrame,
        area: str,
        boundary_gdf,
        raster_config_override: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run the prediction synchronously (for use in the thread pool).

        @param input_data: input DataFrame
        @param area: area name
        @param boundary_gdf: boundary GeoDataFrame
        @param raster_config_override: raster configuration overrides
        @returns: prediction result
        """
        try:
            # Use the prediction engine's predict_and_visualize method
            result = self.prediction_service.engine.predict_and_visualize(
                input_data=input_data,
                model_type="effective_cd",
                county_name=area,                              # use area as county_name
                boundary_gdf=boundary_gdf,
                raster_config_override=raster_config_override,
                save_raster=False                              # skip the raster file to save storage
            )
            return result

        except Exception as e:
            self.logger.error(f"Synchronous prediction failed: {str(e)}")
            raise
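

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the service API): a minimal
# example of driving the service from a script. It assumes the database
# configured in ..database is reachable and populated, and "ExampleCounty" /
# "county" are placeholder values for an area name and administrative level
# that exist in the admin boundary and farmland tables. Because of the
# relative imports, run it as a module from the package root (python -m ...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        service = CdPredictionDatabaseService()

        # Run the effective Cd pipeline end to end:
        # query inputs -> predict -> write EffCd_output_data
        eff_result = await service.generate_effective_cd_prediction_from_database(
            area="ExampleCounty",  # placeholder area name
            level="county"         # placeholder administrative level
        )
        print(f"Effective Cd: processed {eff_result['processed_records']} records")

        # Read the stored results back (first 10 rows) as a quick sanity check
        results_df = service.get_effective_cd_results_from_database(limit=10)
        print(results_df.head())

    asyncio.run(_demo())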