# cd_flux_removal_service.py
  1. """
  2. Cd通量移除计算服务
  3. @description: 提供籽粒移除和秸秆移除的Cd通量计算功能
  4. @author: AcidMap Team
  5. @version: 1.0.0
  6. """
  7. import logging
  8. import math
  9. import os
  10. import pandas as pd
  11. from datetime import datetime
  12. from typing import Dict, Any, List, Optional
  13. from sqlalchemy.orm import sessionmaker, Session
  14. from sqlalchemy import create_engine, and_
  15. from ..database import SessionLocal, engine
  16. from ..models.parameters import Parameters
  17. from ..models.CropCd_output import CropCdOutputData
  18. from ..models.farmland import FarmlandData
  19. from ..utils.mapping_utils import MappingUtils
  20. from .admin_boundary_service import get_boundary_geojson_by_name
  21. import tempfile
  22. import json
class CdFluxRemovalService:
    """
    Cd flux removal calculation service.

    Provides grain-removal and straw-removal Cd flux calculations based on
    data from the CropCd_output_data and Parameters tables, plus CSV export
    and map/histogram visualization of the results.
    """

    def __init__(self):
        """
        Initialize the service with a module-level logger.
        """
        self.logger = logging.getLogger(__name__)
  33. def calculate_grain_removal_by_area(self, area: str) -> Dict[str, Any]:
  34. """
  35. 根据地区计算籽粒移除Cd通量
  36. @param area: 地区名称
  37. @returns: 计算结果字典
  38. 计算公式:籽粒移除(g/ha/a) = EXP(LnCropCd) * F11 * 0.5 * 15 / 1000
  39. """
  40. try:
  41. with SessionLocal() as db:
  42. # 查询指定地区的参数
  43. parameter = db.query(Parameters).filter(Parameters.area == area).first()
  44. if not parameter:
  45. return {
  46. "success": False,
  47. "message": f"未找到地区 '{area}' 的参数数据",
  48. "data": None
  49. }
  50. # 查询CropCd输出数据
  51. crop_cd_outputs = db.query(CropCdOutputData).all()
  52. if not crop_cd_outputs:
  53. return {
  54. "success": False,
  55. "message": f"未找到CropCd输出数据",
  56. "data": None
  57. }
  58. # 计算每个样点的籽粒移除Cd通量
  59. results = []
  60. for output in crop_cd_outputs:
  61. crop_cd_value = math.exp(output.ln_crop_cd) # EXP(LnCropCd)
  62. grain_removal = crop_cd_value * parameter.f11 * 0.5 * 15 / 1000
  63. results.append({
  64. "farmland_id": output.farmland_id,
  65. "sample_id": output.sample_id,
  66. "ln_crop_cd": output.ln_crop_cd,
  67. "crop_cd_value": crop_cd_value,
  68. "f11_yield": parameter.f11,
  69. "grain_removal_flux": grain_removal
  70. })
  71. # 计算统计信息
  72. flux_values = [r["grain_removal_flux"] for r in results]
  73. statistics = {
  74. "total_samples": len(results),
  75. "mean_flux": sum(flux_values) / len(flux_values),
  76. "max_flux": max(flux_values),
  77. "min_flux": min(flux_values)
  78. }
  79. return {
  80. "success": True,
  81. "message": f"地区 '{area}' 的籽粒移除Cd通量计算成功",
  82. "data": {
  83. "area": area,
  84. "calculation_type": "grain_removal",
  85. "formula": "EXP(LnCropCd) * F11 * 0.5 * 15 / 1000",
  86. "unit": "g/ha/a",
  87. "results": results,
  88. "statistics": statistics
  89. }
  90. }
  91. except Exception as e:
  92. self.logger.error(f"计算地区 '{area}' 的籽粒移除Cd通量失败: {str(e)}")
  93. return {
  94. "success": False,
  95. "message": f"计算失败: {str(e)}",
  96. "data": None
  97. }
  98. def calculate_straw_removal_by_area(self, area: str) -> Dict[str, Any]:
  99. """
  100. 根据地区计算秸秆移除Cd通量
  101. @param area: 地区名称
  102. @returns: 计算结果字典
  103. 计算公式:秸秆移除(g/ha/a) = [EXP(LnCropCd)/(EXP(LnCropCd)*0.76-0.0034)] * F11 * 0.5 * 15 / 1000
  104. """
  105. try:
  106. with SessionLocal() as db:
  107. # 查询指定地区的参数
  108. parameter = db.query(Parameters).filter(Parameters.area == area).first()
  109. if not parameter:
  110. return {
  111. "success": False,
  112. "message": f"未找到地区 '{area}' 的参数数据",
  113. "data": None
  114. }
  115. # 查询CropCd输出数据
  116. crop_cd_outputs = db.query(CropCdOutputData).all()
  117. if not crop_cd_outputs:
  118. return {
  119. "success": False,
  120. "message": f"未找到CropCd输出数据",
  121. "data": None
  122. }
  123. # 计算每个样点的秸秆移除Cd通量
  124. results = []
  125. for output in crop_cd_outputs:
  126. crop_cd_value = math.exp(output.ln_crop_cd) # EXP(LnCropCd)
  127. # 计算分母:EXP(LnCropCd)*0.76-0.0034
  128. denominator = crop_cd_value * 0.76 - 0.0034
  129. # 检查分母是否为零或负数,避免除零错误
  130. if denominator <= 0:
  131. self.logger.warning(f"样点 {output.farmland_id}-{output.sample_id} 的分母值为 {denominator},跳过计算")
  132. continue
  133. # 计算秸秆移除Cd通量
  134. straw_removal = (crop_cd_value / denominator) * parameter.f11 * 0.5 * 15 / 1000
  135. results.append({
  136. "farmland_id": output.farmland_id,
  137. "sample_id": output.sample_id,
  138. "ln_crop_cd": output.ln_crop_cd,
  139. "crop_cd_value": crop_cd_value,
  140. "denominator": denominator,
  141. "f11_yield": parameter.f11,
  142. "straw_removal_flux": straw_removal
  143. })
  144. if not results:
  145. return {
  146. "success": False,
  147. "message": "所有样点的计算都因分母值无效而失败",
  148. "data": None
  149. }
  150. # 计算统计信息
  151. flux_values = [r["straw_removal_flux"] for r in results]
  152. statistics = {
  153. "total_samples": len(results),
  154. "mean_flux": sum(flux_values) / len(flux_values),
  155. "max_flux": max(flux_values),
  156. "min_flux": min(flux_values)
  157. }
  158. return {
  159. "success": True,
  160. "message": f"地区 '{area}' 的秸秆移除Cd通量计算成功",
  161. "data": {
  162. "area": area,
  163. "calculation_type": "straw_removal",
  164. "formula": "[EXP(LnCropCd)/(EXP(LnCropCd)*0.76-0.0034)] * F11 * 0.5 * 15 / 1000",
  165. "unit": "g/ha/a",
  166. "results": results,
  167. "statistics": statistics
  168. }
  169. }
  170. except Exception as e:
  171. self.logger.error(f"计算地区 '{area}' 的秸秆移除Cd通量失败: {str(e)}")
  172. return {
  173. "success": False,
  174. "message": f"计算失败: {str(e)}",
  175. "data": None
  176. }
  177. def export_results_to_csv(self, results_data: Dict[str, Any], output_dir: str = "app/static/cd_flux") -> str:
  178. """
  179. 将计算结果导出为CSV文件
  180. @param results_data: 计算结果数据
  181. @param output_dir: 输出目录
  182. @returns: CSV文件路径
  183. """
  184. try:
  185. # 确保输出目录存在
  186. os.makedirs(output_dir, exist_ok=True)
  187. # 生成时间戳
  188. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  189. # 生成文件名
  190. calculation_type = results_data.get("calculation_type", "flux_removal")
  191. area = results_data.get("area", "unknown")
  192. filename = f"{calculation_type}_{area}_{timestamp}.csv"
  193. csv_path = os.path.join(output_dir, filename)
  194. # 转换为DataFrame
  195. results = results_data.get("results", [])
  196. if not results:
  197. raise ValueError("没有结果数据可导出")
  198. df = pd.DataFrame(results)
  199. # 保存CSV文件
  200. df.to_csv(csv_path, index=False, encoding='utf-8-sig')
  201. self.logger.info(f"✓ 成功导出结果到: {csv_path}")
  202. return csv_path
  203. except Exception as e:
  204. self.logger.error(f"导出CSV文件失败: {str(e)}")
  205. raise
  206. def get_coordinates_for_results(self, results_data: Dict[str, Any]) -> List[Dict[str, Any]]:
  207. """
  208. 获取结果数据对应的坐标信息
  209. @param results_data: 计算结果数据
  210. @returns: 包含坐标的结果列表
  211. """
  212. try:
  213. results = results_data.get("results", [])
  214. if not results:
  215. return []
  216. # 提取成对键,避免 N 次数据库查询
  217. farmland_sample_pairs = [(r["farmland_id"], r["sample_id"]) for r in results]
  218. with SessionLocal() as db:
  219. # 使用 farmland_id 分片查询,避免复合 IN 导致的兼容性与参数数量问题
  220. wanted_pairs = set(farmland_sample_pairs)
  221. unique_farmland_ids = sorted({fid for fid, _ in wanted_pairs})
  222. def chunk_list(items: List[int], chunk_size: int = 500) -> List[List[int]]:
  223. return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
  224. rows: List[FarmlandData] = []
  225. for id_chunk in chunk_list(unique_farmland_ids, 500):
  226. rows.extend(
  227. db.query(FarmlandData)
  228. .filter(FarmlandData.farmland_id.in_(id_chunk))
  229. .all()
  230. )
  231. pair_to_farmland = {
  232. (row.farmland_id, row.sample_id): row for row in rows
  233. }
  234. coordinates_results: List[Dict[str, Any]] = []
  235. for r in results:
  236. key = (r["farmland_id"], r["sample_id"])
  237. farmland = pair_to_farmland.get(key)
  238. if farmland is None:
  239. continue
  240. coord_result = {
  241. "farmland_id": r["farmland_id"],
  242. "sample_id": r["sample_id"],
  243. "longitude": farmland.lon,
  244. "latitude": farmland.lan,
  245. "flux_value": r.get("grain_removal_flux") or r.get("straw_removal_flux")
  246. }
  247. coord_result.update(r)
  248. coordinates_results.append(coord_result)
  249. self.logger.info(f"✓ 成功获取 {len(coordinates_results)} 个样点的坐标信息(分片批量查询)")
  250. return coordinates_results
  251. except Exception as e:
  252. self.logger.error(f"获取坐标信息失败: {str(e)}")
  253. raise
  254. def create_flux_visualization(self, area: str, calculation_type: str,
  255. results_with_coords: List[Dict[str, Any]],
  256. output_dir: str = "app/static/cd_flux",
  257. template_raster: str = "app/static/cd_flux/meanTemp.tif",
  258. boundary_shp: str = None,
  259. colormap: str = "green_yellow_red_purple",
  260. resolution_factor: float = 4.0,
  261. enable_interpolation: bool = False,
  262. cleanup_intermediate: bool = True) -> Dict[str, str]:
  263. """
  264. 创建Cd通量移除可视化图表
  265. @param area: 地区名称
  266. @param calculation_type: 计算类型(grain_removal 或 straw_removal)
  267. @param results_with_coords: 包含坐标的结果数据
  268. @param output_dir: 输出目录
  269. @param template_raster: 模板栅格文件路径
  270. @param boundary_shp: 边界shapefile路径
  271. @param colormap: 色彩方案
  272. @param resolution_factor: 分辨率因子
  273. @param enable_interpolation: 是否启用空间插值
  274. @returns: 生成的图片文件路径字典
  275. """
  276. try:
  277. if not results_with_coords:
  278. raise ValueError("没有包含坐标的结果数据")
  279. # 确保输出目录存在
  280. os.makedirs(output_dir, exist_ok=True)
  281. # 生成时间戳
  282. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  283. # 创建CSV文件用于绘图
  284. csv_filename = f"{calculation_type}_{area}_temp_{timestamp}.csv"
  285. csv_path = os.path.join(output_dir, csv_filename)
  286. # 准备绘图数据
  287. plot_data = []
  288. for result in results_with_coords:
  289. plot_data.append({
  290. "longitude": result["longitude"],
  291. "latitude": result["latitude"],
  292. "flux_value": result["flux_value"]
  293. })
  294. # 保存为CSV
  295. df = pd.DataFrame(plot_data)
  296. df.to_csv(csv_path, index=False, encoding='utf-8-sig')
  297. # 初始化绘图工具
  298. mapper = MappingUtils()
  299. # 生成输出文件路径
  300. map_output = os.path.join(output_dir, f"{calculation_type}_{area}_map_{timestamp}")
  301. histogram_output = os.path.join(output_dir, f"{calculation_type}_{area}_histogram_{timestamp}")
  302. # 检查模板文件是否存在
  303. if not os.path.exists(template_raster):
  304. self.logger.warning(f"模板栅格文件不存在: {template_raster}")
  305. template_raster = None
  306. # 动态获取边界文件
  307. boundary_shp = self._get_boundary_file_for_area(area)
  308. if not boundary_shp:
  309. self.logger.warning(f"未找到地区 '{area}' 的边界文件,将不使用边界裁剪")
  310. # 创建shapefile
  311. shapefile_path = csv_path.replace('.csv', '_points.shp')
  312. mapper.csv_to_shapefile(csv_path, shapefile_path,
  313. lon_col='longitude', lat_col='latitude', value_col='flux_value')
  314. generated_files = {"csv": csv_path, "shapefile": shapefile_path}
  315. # 如果有模板栅格文件,创建栅格地图
  316. if template_raster:
  317. try:
  318. # 创建栅格
  319. raster_path = csv_path.replace('.csv', '_raster.tif')
  320. raster_path, stats = mapper.vector_to_raster(
  321. shapefile_path, template_raster, raster_path, 'flux_value',
  322. resolution_factor=resolution_factor, boundary_shp=boundary_shp,
  323. interpolation_method='nearest', enable_interpolation=enable_interpolation
  324. )
  325. generated_files["raster"] = raster_path
  326. # 创建栅格地图 - 使用英文标题避免中文乱码
  327. title_mapping = {
  328. "grain_removal": "Grain Removal Cd Flux",
  329. "straw_removal": "Straw Removal Cd Flux"
  330. }
  331. map_title = title_mapping.get(calculation_type, "Cd Flux Removal")
  332. map_file = mapper.create_raster_map(
  333. boundary_shp if boundary_shp else None,
  334. raster_path,
  335. map_output,
  336. colormap=colormap,
  337. title=map_title,
  338. output_size=12,
  339. dpi=300,
  340. resolution_factor=4.0,
  341. enable_interpolation=False,
  342. interpolation_method='nearest'
  343. )
  344. generated_files["map"] = map_file
  345. # 创建直方图 - 使用英文标题避免中文乱码
  346. histogram_title_mapping = {
  347. "grain_removal": "Grain Removal Cd Flux Distribution",
  348. "straw_removal": "Straw Removal Cd Flux Distribution"
  349. }
  350. histogram_title = histogram_title_mapping.get(calculation_type, "Cd Flux Distribution")
  351. histogram_file = mapper.create_histogram(
  352. raster_path,
  353. f"{histogram_output}.jpg",
  354. title=histogram_title,
  355. xlabel='Cd Flux (g/ha/a)',
  356. ylabel='Frequency Density'
  357. )
  358. generated_files["histogram"] = histogram_file
  359. except Exception as viz_error:
  360. self.logger.warning(f"栅格可视化创建失败: {str(viz_error)}")
  361. # 即使栅格可视化失败,也返回已生成的文件
  362. # 清理中间文件(默认开启,仅保留最终可视化)
  363. if cleanup_intermediate:
  364. try:
  365. self._cleanup_intermediate_files(generated_files, boundary_shp)
  366. except Exception as cleanup_err:
  367. self.logger.warning(f"中间文件清理失败: {str(cleanup_err)}")
  368. self.logger.info(f"✓ 成功创建 {calculation_type} 可视化,生成文件: {list(generated_files.keys())}")
  369. return generated_files
  370. except Exception as e:
  371. self.logger.error(f"创建可视化失败: {str(e)}")
  372. raise
  373. def _cleanup_intermediate_files(self, generated_files: Dict[str, str], boundary_shp: Optional[str]) -> None:
  374. """
  375. 清理中间文件:CSV、Shapefile 及其配套文件、栅格TIFF;若边界为临时目录,则一并删除
  376. """
  377. import shutil
  378. import tempfile
  379. def _safe_remove(path: str) -> None:
  380. try:
  381. if path and os.path.exists(path) and os.path.isfile(path):
  382. os.remove(path)
  383. except Exception:
  384. pass
  385. # 删除 CSV
  386. _safe_remove(generated_files.get("csv"))
  387. # 删除栅格
  388. _safe_remove(generated_files.get("raster"))
  389. # 删除 Shapefile 全家桶
  390. shp_path = generated_files.get("shapefile")
  391. if shp_path:
  392. base, _ = os.path.splitext(shp_path)
  393. for ext in (".shp", ".shx", ".dbf", ".prj", ".cpg"):
  394. _safe_remove(base + ext)
  395. # 如果边界文件来自系统临时目录,删除其所在目录
  396. if boundary_shp:
  397. temp_root = tempfile.gettempdir()
  398. try:
  399. if os.path.commonprefix([os.path.abspath(boundary_shp), temp_root]) == temp_root:
  400. temp_dir = os.path.dirname(os.path.abspath(boundary_shp))
  401. if os.path.isdir(temp_dir):
  402. shutil.rmtree(temp_dir, ignore_errors=True)
  403. except Exception:
  404. pass
  405. def _get_boundary_file_for_area(self, area: str) -> Optional[str]:
  406. """
  407. 为指定地区获取边界文件
  408. @param area: 地区名称
  409. @returns: 边界文件路径或None
  410. """
  411. try:
  412. # 首先尝试静态文件路径(只查找该地区专用的边界文件)
  413. norm_area = area.strip()
  414. base_name = norm_area.replace('市', '').replace('县', '')
  415. name_variants = list(dict.fromkeys([
  416. norm_area,
  417. base_name,
  418. f"{base_name}市",
  419. ]))
  420. static_boundary_paths = []
  421. for name in name_variants:
  422. static_boundary_paths.append(f"app/static/cd_flux/{name}.shp")
  423. for path in static_boundary_paths:
  424. if os.path.exists(path):
  425. self.logger.info(f"找到边界文件: {path}")
  426. return path
  427. # 优先从数据库获取边界数据(对名称进行多变体匹配,如 “韶关/韶关市”)
  428. boundary_path = self._create_boundary_from_database(area)
  429. if boundary_path:
  430. return boundary_path
  431. # 如果都没有找到,记录警告但不使用默认文件
  432. self.logger.warning(f"未找到地区 '{area}' 的专用边界文件,也无法从数据库获取")
  433. return None
  434. except Exception as e:
  435. self.logger.error(f"获取边界文件失败: {str(e)}")
  436. return None
    def _create_boundary_from_database(self, area: str) -> Optional[str]:
        """
        Fetch boundary geometry from the database and materialize it as a
        temporary ESRI shapefile.

        @param area: area name; name variants ("韶关" / "韶关市") are tried
                     in turn until one resolves
        @returns: path to the temporary boundary shapefile, or None when no
                  candidate matched or an error occurred
        """
        try:
            with SessionLocal() as db:
                # Generate name variants to make matching more robust.
                norm_area = area.strip()
                base_name = norm_area.replace('市', '').replace('县', '')
                candidates = list(dict.fromkeys([
                    norm_area,
                    base_name,
                    f"{base_name}市",
                ]))
                for candidate in candidates:
                    try:
                        boundary_geojson = get_boundary_geojson_by_name(db, candidate, level="auto")
                        if boundary_geojson:
                            # Build a temporary shapefile from the GeoJSON
                            # feature (WGS84 / EPSG:4326).
                            import geopandas as gpd
                            from shapely.geometry import shape
                            geometry = shape(boundary_geojson["geometry"])
                            gdf = gpd.GeoDataFrame([boundary_geojson["properties"]], geometry=[geometry], crs="EPSG:4326")
                            temp_dir = tempfile.mkdtemp()
                            boundary_path = os.path.join(temp_dir, f"{candidate}_boundary.shp")
                            gdf.to_file(boundary_path, driver="ESRI Shapefile")
                            self.logger.info(f"从数据库创建边界文件: {boundary_path}")
                            return boundary_path
                    except Exception as _:
                        # Try the next candidate name.
                        continue
                # Falls through with implicit None when no candidate matched.
        except Exception as e:
            self.logger.warning(f"从数据库创建边界文件失败: {str(e)}")
            return None