"""
Cd flux removal calculation service.

@description: Provides Cd flux calculations for grain removal and straw removal
@author: AcidMap Team
@version: 1.0.0
"""

import logging
import math
import os
import shutil
import pandas as pd
from datetime import datetime
from typing import Dict, Any, List, Optional
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy import create_engine, and_
from ..database import SessionLocal, engine
from ..models.parameters import Parameters
from ..models.CropCd_output import CropCdOutputData
from ..models.farmland import FarmlandData
from ..utils.mapping_utils import MappingUtils
from .admin_boundary_service import get_boundary_geojson_by_name
import tempfile
import json


class CdFluxRemovalService:
    """
    Cd flux removal calculation service.

    @description: Computes grain-removal and straw-removal Cd flux from the
                  CropCd output rows and the per-area Parameters row, and
                  exports / visualises the results.
    """

    # Maximum number of farmland ids placed in a single SQL ``IN (...)`` clause.
    _ID_CHUNK_SIZE = 500

    # Sidecar files that together make up one ESRI shapefile.
    _SHAPEFILE_EXTENSIONS = (".shp", ".shx", ".dbf", ".prj", ".cpg")

    def __init__(self):
        """
        Initialise the service with a module-level logger.
        """
        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _flux_statistics(flux_values: List[float]) -> Dict[str, Any]:
        """
        Summarise a non-empty list of flux values.

        @param flux_values: Flux values (must contain at least one element)
        @returns: Dict with total_samples, mean_flux, max_flux and min_flux
        """
        return {
            "total_samples": len(flux_values),
            "mean_flux": sum(flux_values) / len(flux_values),
            "max_flux": max(flux_values),
            "min_flux": min(flux_values),
        }

    @staticmethod
    def _area_name_variants(area: str) -> List[str]:
        """
        Build the ordered, de-duplicated name variants tried for an area.

        The variants are: the stripped name, the name with '市' / '县'
        removed, and that base name with '市' appended.

        @param area: Area name
        @returns: List of candidate names, first match wins
        """
        norm_area = area.strip()
        base_name = norm_area.replace('市', '').replace('县', '')
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys([norm_area, base_name, f"{base_name}市"]))

    def _crop_cd_from_log(self, output: Any) -> Optional[float]:
        """
        Return EXP(LnCropCd) for one sample, or None when it cannot be used.

        A missing, non-numeric, overflowing or non-finite value is logged and
        reported as None so that one bad row does not abort the whole batch.

        @param output: A CropCd output row exposing ln_crop_cd, farmland_id
                       and sample_id
        @returns: The exponentiated value, or None if the sample must be skipped
        """
        ln_value = output.ln_crop_cd
        try:
            crop_cd_value = math.exp(ln_value)
        except (TypeError, ValueError, OverflowError):
            crop_cd_value = None

        if crop_cd_value is None or not math.isfinite(crop_cd_value):
            self.logger.warning(
                f"样点 {output.farmland_id}-{output.sample_id} 的 LnCropCd 值无效({ln_value!r}),跳过计算"
            )
            return None
        return crop_cd_value

    def _remove_file_quietly(self, path: Optional[str]) -> None:
        """
        Delete a regular file if it exists; failures are logged, never raised.

        @param path: File path, may be None or empty
        """
        if not path:
            return
        try:
            if os.path.isfile(path):
                os.remove(path)
        except OSError as err:
            # Cleanup is best-effort: report and carry on.
            self.logger.debug(f"删除中间文件失败: {path} ({err})")

    # ------------------------------------------------------------------
    # Flux calculations
    # ------------------------------------------------------------------

    def calculate_grain_removal_by_area(self, area: str) -> Dict[str, Any]:
        """
        Compute the grain-removal Cd flux for an area.

        Formula: grain removal (g/ha/a) = EXP(LnCropCd) * F11 * 0.5 * 15 / 1000

        The Parameters row is looked up by area; all CropCd output rows are
        used. Samples whose LnCropCd is unusable are skipped with a warning.

        @param area: Area name
        @returns: Result dict with "success", "message" and "data" keys; on any
                  failure "success" is False and "data" is None
        """
        try:
            with SessionLocal() as db:
                # Parameters for the requested area
                parameter = db.query(Parameters).filter(Parameters.area == area).first()
                if not parameter:
                    return {
                        "success": False,
                        "message": f"未找到地区 '{area}' 的参数数据",
                        "data": None
                    }

                # CropCd output rows
                crop_cd_outputs = db.query(CropCdOutputData).all()
                if not crop_cd_outputs:
                    return {
                        "success": False,
                        "message": f"未找到CropCd输出数据",
                        "data": None
                    }

                # Per-sample grain-removal flux
                results = []
                for output in crop_cd_outputs:
                    crop_cd_value = self._crop_cd_from_log(output)  # EXP(LnCropCd)
                    if crop_cd_value is None:
                        continue

                    grain_removal = crop_cd_value * parameter.f11 * 0.5 * 15 / 1000
                    results.append({
                        "farmland_id": output.farmland_id,
                        "sample_id": output.sample_id,
                        "ln_crop_cd": output.ln_crop_cd,
                        "crop_cd_value": crop_cd_value,
                        "f11_yield": parameter.f11,
                        "grain_removal_flux": grain_removal
                    })

                # Every sample was skipped: avoid dividing by zero in the mean.
                if not results:
                    return {
                        "success": False,
                        "message": "所有样点的LnCropCd值均无效,无法计算籽粒移除Cd通量",
                        "data": None
                    }

                statistics = self._flux_statistics(
                    [r["grain_removal_flux"] for r in results]
                )

                return {
                    "success": True,
                    "message": f"地区 '{area}' 的籽粒移除Cd通量计算成功",
                    "data": {
                        "area": area,
                        "calculation_type": "grain_removal",
                        "formula": "EXP(LnCropCd) * F11 * 0.5 * 15 / 1000",
                        "unit": "g/ha/a",
                        "results": results,
                        "statistics": statistics
                    }
                }

        except Exception as e:
            self.logger.error(f"计算地区 '{area}' 的籽粒移除Cd通量失败: {str(e)}")
            return {
                "success": False,
                "message": f"计算失败: {str(e)}",
                "data": None
            }

    def calculate_straw_removal_by_area(self, area: str) -> Dict[str, Any]:
        """
        Compute the straw-removal Cd flux for an area.

        Formula: straw removal (g/ha/a) =
                 [EXP(LnCropCd) / (EXP(LnCropCd) * 0.76 - 0.0034)] * F11 * 0.5 * 15 / 1000

        Samples with an unusable LnCropCd or a non-positive denominator are
        skipped with a warning.

        @param area: Area name
        @returns: Result dict with "success", "message" and "data" keys; on any
                  failure "success" is False and "data" is None
        """
        try:
            with SessionLocal() as db:
                # Parameters for the requested area
                parameter = db.query(Parameters).filter(Parameters.area == area).first()
                if not parameter:
                    return {
                        "success": False,
                        "message": f"未找到地区 '{area}' 的参数数据",
                        "data": None
                    }

                # CropCd output rows
                crop_cd_outputs = db.query(CropCdOutputData).all()
                if not crop_cd_outputs:
                    return {
                        "success": False,
                        "message": f"未找到CropCd输出数据",
                        "data": None
                    }

                # Per-sample straw-removal flux
                results = []
                for output in crop_cd_outputs:
                    crop_cd_value = self._crop_cd_from_log(output)  # EXP(LnCropCd)
                    if crop_cd_value is None:
                        continue

                    # Denominator: EXP(LnCropCd) * 0.76 - 0.0034
                    denominator = crop_cd_value * 0.76 - 0.0034

                    # Zero or negative denominators are rejected to avoid a
                    # division by zero / a sign-flipped result.
                    if denominator <= 0:
                        self.logger.warning(f"样点 {output.farmland_id}-{output.sample_id} 的分母值为 {denominator},跳过计算")
                        continue

                    straw_removal = (crop_cd_value / denominator) * parameter.f11 * 0.5 * 15 / 1000
                    results.append({
                        "farmland_id": output.farmland_id,
                        "sample_id": output.sample_id,
                        "ln_crop_cd": output.ln_crop_cd,
                        "crop_cd_value": crop_cd_value,
                        "denominator": denominator,
                        "f11_yield": parameter.f11,
                        "straw_removal_flux": straw_removal
                    })

                if not results:
                    return {
                        "success": False,
                        "message": "所有样点的计算都因分母值无效而失败",
                        "data": None
                    }

                statistics = self._flux_statistics(
                    [r["straw_removal_flux"] for r in results]
                )

                return {
                    "success": True,
                    "message": f"地区 '{area}' 的秸秆移除Cd通量计算成功",
                    "data": {
                        "area": area,
                        "calculation_type": "straw_removal",
                        "formula": "[EXP(LnCropCd)/(EXP(LnCropCd)*0.76-0.0034)] * F11 * 0.5 * 15 / 1000",
                        "unit": "g/ha/a",
                        "results": results,
                        "statistics": statistics
                    }
                }

        except Exception as e:
            self.logger.error(f"计算地区 '{area}' 的秸秆移除Cd通量失败: {str(e)}")
            return {
                "success": False,
                "message": f"计算失败: {str(e)}",
                "data": None
            }

    # ------------------------------------------------------------------
    # Export and coordinate lookup
    # ------------------------------------------------------------------

    def export_results_to_csv(self, results_data: Dict[str, Any],
                              output_dir: str = "app/static/cd_flux") -> str:
        """
        Export calculation results to a timestamped CSV file.

        @param results_data: The "data" payload of a calculation result
                             (reads "calculation_type", "area" and "results")
        @param output_dir: Output directory (created if missing)
        @returns: Path of the written CSV file
        @raises ValueError: When there are no results to export
        """
        try:
            # Make sure the output directory exists
            os.makedirs(output_dir, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            calculation_type = results_data.get("calculation_type", "flux_removal")
            area = results_data.get("area", "unknown")
            filename = f"{calculation_type}_{area}_{timestamp}.csv"
            csv_path = os.path.join(output_dir, filename)

            results = results_data.get("results", [])
            if not results:
                raise ValueError("没有结果数据可导出")

            # utf-8-sig writes a BOM (commonly needed for spreadsheet tools to
            # detect UTF-8).
            pd.DataFrame(results).to_csv(csv_path, index=False, encoding='utf-8-sig')

            self.logger.info(f"✓ 成功导出结果到: {csv_path}")
            return csv_path

        except Exception as e:
            self.logger.error(f"导出CSV文件失败: {str(e)}")
            raise

    def get_coordinates_for_results(self, results_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Attach farmland coordinates to each calculation result.

        Farmland rows are fetched in chunks by farmland_id (instead of one
        query per sample or a composite IN) and matched in memory on
        (farmland_id, sample_id). Results without a matching farmland row are
        dropped.

        @param results_data: The "data" payload of a calculation result
        @returns: List of result dicts extended with longitude, latitude and
                  flux_value
        """
        try:
            results = results_data.get("results", [])
            if not results:
                return []

            unique_farmland_ids = sorted({r["farmland_id"] for r in results})
            chunk_size = self._ID_CHUNK_SIZE

            with SessionLocal() as db:
                rows: List[FarmlandData] = []
                for start in range(0, len(unique_farmland_ids), chunk_size):
                    id_chunk = unique_farmland_ids[start:start + chunk_size]
                    rows.extend(
                        db.query(FarmlandData)
                        .filter(FarmlandData.farmland_id.in_(id_chunk))
                        .all()
                    )

                pair_to_farmland = {
                    (row.farmland_id, row.sample_id): row for row in rows
                }

                coordinates_results: List[Dict[str, Any]] = []
                for r in results:
                    farmland = pair_to_farmland.get((r["farmland_id"], r["sample_id"]))
                    if farmland is None:
                        continue

                    # Explicit None check: a flux of exactly 0.0 is a valid
                    # value and must not fall through to the other key.
                    flux_value = r.get("grain_removal_flux")
                    if flux_value is None:
                        flux_value = r.get("straw_removal_flux")

                    coord_result = {
                        "farmland_id": r["farmland_id"],
                        "sample_id": r["sample_id"],
                        "longitude": farmland.lon,
                        # The FarmlandData attribute read here is spelled "lan".
                        "latitude": farmland.lan,
                        "flux_value": flux_value
                    }
                    coord_result.update(r)
                    coordinates_results.append(coord_result)

            self.logger.info(f"✓ 成功获取 {len(coordinates_results)} 个样点的坐标信息(分片批量查询)")
            return coordinates_results

        except Exception as e:
            self.logger.error(f"获取坐标信息失败: {str(e)}")
            raise

    # ------------------------------------------------------------------
    # Visualisation
    # ------------------------------------------------------------------

    def create_flux_visualization(self, area: str, calculation_type: str,
                                  results_with_coords: List[Dict[str, Any]],
                                  output_dir: str = "app/static/cd_flux",
                                  template_raster: str = "app/static/cd_flux/meanTemp.tif",
                                  boundary_shp: str = None,
                                  colormap: str = "green_yellow_red_purple",
                                  resolution_factor: float = 4.0,
                                  enable_interpolation: bool = False,
                                  cleanup_intermediate: bool = True) -> Dict[str, str]:
        """
        Create the Cd flux removal map and histogram.

        @param area: Area name
        @param calculation_type: Calculation type (grain_removal or straw_removal)
        @param results_with_coords: Results carrying longitude / latitude / flux_value
        @param output_dir: Output directory (created if missing)
        @param template_raster: Template raster path; when the file does not
                                exist, raster map and histogram are skipped
        @param boundary_shp: Currently ignored - the boundary is always
                             resolved from the area name
        @param colormap: Colour scheme passed to the raster map
        @param resolution_factor: Resolution factor used when rasterising
        @param enable_interpolation: Whether rasterising uses spatial interpolation
        @param cleanup_intermediate: Delete CSV / shapefile / raster (and a
                                     temporary boundary directory) afterwards.
                                     Their paths remain in the returned dict.
        @returns: Dict mapping file kind ("csv", "shapefile", "raster", "map",
                  "histogram") to path
        @raises ValueError: When results_with_coords is empty
        """
        try:
            if not results_with_coords:
                raise ValueError("没有包含坐标的结果数据")

            # Make sure the output directory exists
            os.makedirs(output_dir, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Intermediate CSV consumed by the mapping utilities
            csv_filename = f"{calculation_type}_{area}_temp_{timestamp}.csv"
            csv_path = os.path.join(output_dir, csv_filename)

            plot_data = [
                {
                    "longitude": result["longitude"],
                    "latitude": result["latitude"],
                    "flux_value": result["flux_value"]
                }
                for result in results_with_coords
            ]
            pd.DataFrame(plot_data).to_csv(csv_path, index=False, encoding='utf-8-sig')

            mapper = MappingUtils()

            map_output = os.path.join(output_dir, f"{calculation_type}_{area}_map_{timestamp}")
            histogram_output = os.path.join(output_dir, f"{calculation_type}_{area}_histogram_{timestamp}")

            # Without a template raster only the CSV and shapefile are produced.
            if not os.path.exists(template_raster):
                self.logger.warning(f"模板栅格文件不存在: {template_raster}")
                template_raster = None

            # The boundary is always resolved from the area name; the
            # boundary_shp argument is overwritten here.
            boundary_shp = self._get_boundary_file_for_area(area)
            if not boundary_shp:
                self.logger.warning(f"未找到地区 '{area}' 的边界文件,将不使用边界裁剪")

            # Point shapefile built from the CSV
            shapefile_path = csv_path.replace('.csv', '_points.shp')
            mapper.csv_to_shapefile(csv_path, shapefile_path,
                                    lon_col='longitude', lat_col='latitude', value_col='flux_value')

            generated_files = {"csv": csv_path, "shapefile": shapefile_path}

            if template_raster:
                try:
                    # Rasterise the points
                    raster_path = csv_path.replace('.csv', '_raster.tif')
                    raster_path, _stats = mapper.vector_to_raster(
                        shapefile_path, template_raster, raster_path, 'flux_value',
                        resolution_factor=resolution_factor,
                        boundary_shp=boundary_shp,
                        interpolation_method='nearest',
                        enable_interpolation=enable_interpolation
                    )
                    generated_files["raster"] = raster_path

                    # English titles are used to avoid garbled CJK glyphs in
                    # the rendered images.
                    title_mapping = {
                        "grain_removal": "Grain Removal Cd Flux",
                        "straw_removal": "Straw Removal Cd Flux"
                    }
                    map_title = title_mapping.get(calculation_type, "Cd Flux Removal")

                    # The map step uses fixed resolution / interpolation
                    # settings, independent of this method's arguments.
                    map_file = mapper.create_raster_map(
                        boundary_shp if boundary_shp else None,
                        raster_path,
                        map_output,
                        colormap=colormap,
                        title=map_title,
                        output_size=12,
                        dpi=300,
                        resolution_factor=4.0,
                        enable_interpolation=False,
                        interpolation_method='nearest'
                    )
                    generated_files["map"] = map_file

                    histogram_title_mapping = {
                        "grain_removal": "Grain Removal Cd Flux Distribution",
                        "straw_removal": "Straw Removal Cd Flux Distribution"
                    }
                    histogram_title = histogram_title_mapping.get(calculation_type, "Cd Flux Distribution")

                    histogram_file = mapper.create_histogram(
                        raster_path,
                        f"{histogram_output}.jpg",
                        title=histogram_title,
                        xlabel='Cd Flux (g/ha/a)',
                        ylabel='Frequency Density'
                    )
                    generated_files["histogram"] = histogram_file

                except Exception as viz_error:
                    # A failed raster step still returns the files produced so far.
                    self.logger.warning(f"栅格可视化创建失败: {str(viz_error)}")

            # Remove intermediate files (on by default; keeps only final images)
            if cleanup_intermediate:
                try:
                    self._cleanup_intermediate_files(generated_files, boundary_shp)
                except Exception as cleanup_err:
                    self.logger.warning(f"中间文件清理失败: {str(cleanup_err)}")

            self.logger.info(f"✓ 成功创建 {calculation_type} 可视化,生成文件: {list(generated_files.keys())}")
            return generated_files

        except Exception as e:
            self.logger.error(f"创建可视化失败: {str(e)}")
            raise

    def _cleanup_intermediate_files(self, generated_files: Dict[str, str],
                                    boundary_shp: Optional[str]) -> None:
        """
        Remove intermediate artefacts of a visualisation run.

        Deletes the CSV, the raster TIFF and the shapefile with its sidecar
        files. When the boundary file lives in a sub-directory of the system
        temp directory, that sub-directory is deleted as well.

        @param generated_files: Mapping of file kind to path
        @param boundary_shp: Boundary shapefile path, or None
        """
        self._remove_file_quietly(generated_files.get("csv"))
        self._remove_file_quietly(generated_files.get("raster"))

        shp_path = generated_files.get("shapefile")
        if shp_path:
            base, _ = os.path.splitext(shp_path)
            for ext in self._SHAPEFILE_EXTENSIONS:
                self._remove_file_quietly(base + ext)

        if not boundary_shp:
            return

        # Resolve symlinks on both sides so that e.g. /var -> /private/var
        # does not defeat the containment test.
        temp_root = os.path.realpath(tempfile.gettempdir())
        temp_dir = os.path.dirname(os.path.realpath(boundary_shp))
        root_key = os.path.normcase(temp_root)
        dir_key = os.path.normcase(temp_dir)

        try:
            # commonpath compares whole path components; a character-wise
            # prefix test would treat "/tmpfoo" as being inside "/tmp".
            inside_temp = os.path.commonpath([dir_key, root_key]) == root_key
        except ValueError:
            # Raised e.g. for paths on different drives: not inside temp.
            inside_temp = False

        # Never delete the temp root itself, only a directory below it.
        if inside_temp and dir_key != root_key and os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)

    # ------------------------------------------------------------------
    # Boundary resolution
    # ------------------------------------------------------------------

    def _get_boundary_file_for_area(self, area: str) -> Optional[str]:
        """
        Resolve the boundary shapefile for an area.

        Area-specific static files under app/static/cd_flux are tried first
        (one per name variant); otherwise a temporary shapefile is built from
        the database. No default boundary file is used.

        @param area: Area name
        @returns: Boundary file path, or None when nothing was found
        """
        try:
            for name in self._area_name_variants(area):
                path = f"app/static/cd_flux/{name}.shp"
                if os.path.exists(path):
                    self.logger.info(f"找到边界文件: {path}")
                    return path

            # Fall back to the database (which matches name variants itself)
            boundary_path = self._create_boundary_from_database(area)
            if boundary_path:
                return boundary_path

            self.logger.warning(f"未找到地区 '{area}' 的专用边界文件,也无法从数据库获取")
            return None

        except Exception as e:
            self.logger.error(f"获取边界文件失败: {str(e)}")
            return None

    def _create_boundary_from_database(self, area: str) -> Optional[str]:
        """
        Build a temporary boundary shapefile from database boundary data.

        Each name variant is tried in turn; the first one that yields boundary
        GeoJSON and can be written wins. The file is written into a fresh
        temporary directory, which is removed again if writing fails.

        @param area: Area name
        @returns: Path of the temporary boundary shapefile, or None
        """
        try:
            with SessionLocal() as db:
                for candidate in self._area_name_variants(area):
                    temp_dir = None
                    try:
                        boundary_geojson = get_boundary_geojson_by_name(db, candidate, level="auto")
                        if not boundary_geojson:
                            continue

                        # Imported lazily so a missing geo stack only disables
                        # this fallback.
                        import geopandas as gpd
                        from shapely.geometry import shape

                        geometry = shape(boundary_geojson["geometry"])
                        gdf = gpd.GeoDataFrame(
                            [boundary_geojson["properties"]],
                            geometry=[geometry],
                            crs="EPSG:4326"
                        )

                        temp_dir = tempfile.mkdtemp()
                        boundary_path = os.path.join(temp_dir, f"{candidate}_boundary.shp")
                        gdf.to_file(boundary_path, driver="ESRI Shapefile")

                        self.logger.info(f"从数据库创建边界文件: {boundary_path}")
                        return boundary_path
                    except Exception as candidate_err:
                        # Do not leave an orphaned temp directory behind, then
                        # move on to the next candidate name.
                        if temp_dir:
                            shutil.rmtree(temp_dir, ignore_errors=True)
                        self.logger.debug(f"候选名称 '{candidate}' 的边界创建失败: {candidate_err}")
                        continue

        except Exception as e:
            self.logger.warning(f"从数据库创建边界文件失败: {str(e)}")

        return None