cd_flux_service.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. import os
  2. import csv
  3. import logging
  4. import sys
  5. import pandas as pd
  6. import geopandas as gpd
  7. from shapely.geometry import Point
  8. from pathlib import Path
  9. from typing import Dict, Any
  10. from sqlalchemy import func
  11. from sqlalchemy.orm import Session
  12. from app.database import SessionLocal
  13. from app.models import FarmlandData, FluxCdOutputData, FluxCdInputData
  14. from app.utils.mapping_utils import MappingUtils, csv_to_raster_workflow
# Logging setup: module-level logger obtained from the project's logging helper.
from app.log.logger import get_logger
logger = get_logger(__name__)
  18. def get_base_dir():
  19. """获取基础目录路径(与土地数据处理函数一致)"""
  20. if getattr(sys, 'frozen', False):
  21. # 打包后的可执行文件
  22. return os.path.dirname(sys.executable)
  23. else:
  24. # 脚本运行模式
  25. return os.path.dirname(os.path.abspath(__file__))
  26. def get_static_dir():
  27. """获取静态资源目录"""
  28. base_dir = get_base_dir()
  29. return os.path.join(base_dir, "..", "static", "cd_flux")
  30. def get_default_boundary_shp():
  31. """获取默认边界SHP文件路径"""
  32. static_dir = get_static_dir()
  33. # 尝试几个可能的边界文件路径
  34. possible_paths = [
  35. os.path.join(static_dir, "lechang.shp"),
  36. ]
  37. for path in possible_paths:
  38. if os.path.exists(path):
  39. return path
  40. return None
  41. class FluxCdVisualizationService:
  42. """
  43. 农田Cd通量可视化服务类
  44. @description: 提供基于农田位置和Cd通量数据的空间分布图和直方图生成功能
  45. @version: 1.1.2 (修复绘图问题)
  46. """
  47. def __init__(self, db: Session = None):
  48. """
  49. 初始化可视化服务
  50. @param db: 可选的数据库会话对象
  51. """
  52. self.logger = logger
  53. self.db = db
  54. self.mapper = MappingUtils()
  55. def generate_cd_input_flux_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  56. """
  57. 生成输入镉通量(In_Cd)的空间分布图和直方图
  58. @param output_dir: 输出文件目录
  59. @param boundary_shp: 边界Shapefile文件路径
  60. @return: 包含输出文件路径的字典
  61. """
  62. return self._generate_cd_flux_map(field='in_cd', title_prefix="input",
  63. output_dir=output_dir, boundary_shp=boundary_shp)
  64. def generate_cd_output_flux_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  65. """
  66. 生成输出镉通量(Out_Cd)的空间分布图和直方图
  67. @param output_dir: 输出文件目录
  68. @param boundary_shp: 边界Shapefile文件路径
  69. @return: 包含输出文件路径的字典
  70. """
  71. return self._generate_cd_flux_map(field='out_cd', title_prefix="output",
  72. output_dir=output_dir, boundary_shp=boundary_shp)
  73. def generate_cd_net_flux_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  74. """
  75. 生成净镉通量(Net_Cd)的空间分布图和直方图
  76. @param output_dir: 输出文件目录
  77. @param boundary_shp: 边界Shapefile文件路径
  78. @return: 包含输出文件路径的字典
  79. """
  80. return self._generate_cd_flux_map(field='net_cd', title_prefix="net",
  81. output_dir=output_dir, boundary_shp=boundary_shp)
  82. def generate_end_cd_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  83. """
  84. 生成当年Cd浓度(End_Cd)的空间分布图和直方图
  85. """
  86. return self._generate_cd_flux_map(
  87. field='end_cd',
  88. title_prefix="End",
  89. output_dir=output_dir,
  90. boundary_shp=boundary_shp
  91. )
  92. def _generate_cd_flux_map(self, field: str, title_prefix: str,
  93. output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  94. """
  95. 通用的Cd通量地图生成方法(直接从数据库获取数据)
  96. """
  97. try:
  98. # 字段配置
  99. field_config = {
  100. 'in_cd': {"title": "Input", "unit": "g/ha/a", "label": "input Cd flux"},
  101. 'out_cd': {"title": "Output", "unit": "g/ha/a", "label": "output Cd flux"},
  102. 'net_cd': {"title": "Net", "unit": "g/ha/a", "label": "net Cd flux"},
  103. 'end_cd': {"title": "End", "unit": "mg/kg", "label": "end Cd concentration"}
  104. }
  105. config = field_config.get(field, field_config['in_cd'])
  106. # 获取数据库会话
  107. db = self.db if self.db else SessionLocal()
  108. should_close = self.db is None
  109. # 从数据库查询数据
  110. query = db.query(
  111. FarmlandData.lon,
  112. FarmlandData.lan,
  113. getattr(FluxCdOutputData, field)
  114. ).join(
  115. FluxCdOutputData,
  116. (FarmlandData.farmland_id == FluxCdOutputData.farmland_id) &
  117. (FarmlandData.sample_id == FluxCdOutputData.sample_id)
  118. ).filter(
  119. getattr(FluxCdOutputData, field) != None # 排除空值
  120. )
  121. data = query.all()
  122. if not data:
  123. if should_close:
  124. db.close()
  125. return {
  126. "success": False,
  127. "message": f"数据库中未找到任何{config['title']}Cd数据",
  128. "data": None
  129. }
  130. # 设置输出目录
  131. static_dir = get_static_dir()
  132. if output_dir is None:
  133. output_dir = static_dir
  134. os.makedirs(output_dir, exist_ok=True)
  135. # 设置边界SHP
  136. if boundary_shp is None:
  137. boundary_shp = get_default_boundary_shp()
  138. # 生成CSV文件
  139. base_name = f"fluxcd_{field}"
  140. csv_path = os.path.join(output_dir, f"{base_name}.csv")
  141. self._generate_csv(data, csv_path, field_names=['lon', 'lan', field])
  142. # 使用模板TIFF
  143. template_tif = os.path.join(static_dir, "meanTemp.tif")
  144. if not os.path.exists(template_tif):
  145. raise FileNotFoundError("未找到模板TIFF文件")
  146. # 转换为栅格
  147. raster_path = os.path.join(output_dir, f"{base_name}_raster.tif")
  148. workflow_result = csv_to_raster_workflow(
  149. csv_file=csv_path,
  150. template_tif=template_tif,
  151. output_dir=output_dir,
  152. resolution_factor=4.0,
  153. interpolation_method='linear',
  154. field_name=field,
  155. lon_col=0,
  156. lat_col=1,
  157. value_col=2,
  158. enable_interpolation=True
  159. )
  160. raster_path = workflow_result['raster']
  161. stats = workflow_result['statistics']
  162. # 创建地图
  163. map_path = os.path.join(output_dir, f"{base_name}_map")
  164. map_file = self.mapper.create_raster_map(
  165. shp_path=boundary_shp,
  166. tif_path=raster_path,
  167. output_path=map_path,
  168. colormap='green_yellow_red_purple',
  169. title=f"{config['title']} Cd {field.split('_')[0]} map",
  170. output_size=12,
  171. dpi=300,
  172. enable_interpolation=False
  173. )
  174. # 创建直方图
  175. histogram_path = self.mapper.create_histogram(
  176. raster_path,
  177. save_path=os.path.join(output_dir, f"{base_name}_histogram.jpg"),
  178. xlabel=f"{config['label']} ({config['unit']})",
  179. ylabel='frequency',
  180. title=f"{config['title']} Cd {field.split('_')[0]} histogram",
  181. bins=100
  182. )
  183. result = {
  184. "success": True,
  185. "message": f"成功生成{config['title']}Cd数据可视化结果",
  186. "data": {
  187. "csv": csv_path,
  188. "raster": raster_path,
  189. "map": map_file,
  190. "histogram": histogram_path,
  191. "statistics": stats,
  192. "boundary_used": boundary_shp if boundary_shp else "无"
  193. }
  194. }
  195. if should_close:
  196. db.close()
  197. return result
  198. except Exception as e:
  199. self.logger.error(f"生成Cd数据可视化结果时发生错误: {str(e)}", exc_info=True)
  200. return {
  201. "success": False,
  202. "message": f"生成失败: {str(e)}",
  203. "data": None
  204. }
  205. def update_from_csv(self, csv_file_path: str):
  206. """包装更新服务方法"""
  207. update_service = FluxCdUpdateService(db=self.db)
  208. return update_service.update_from_csv(csv_file_path)
  209. def _fetch_fluxcd_data(self, db: Session, field: str = 'in_cd') -> list:
  210. """
  211. 从数据库查询需要的数据
  212. @param db: 数据库会话
  213. @param field: 要查询的字段('in_cd', 'out_cd'或'net_cd')
  214. @returns: 查询结果列表
  215. """
  216. try:
  217. # 查询Farmland_data和FluxCd_output_data的联合数据
  218. query = db.query(
  219. FarmlandData.lon,
  220. FarmlandData.lan,
  221. getattr(FluxCdOutputData, field)
  222. ).join(
  223. FluxCdOutputData,
  224. (FarmlandData.farmland_id == FluxCdOutputData.farmland_id) &
  225. (FarmlandData.sample_id == FluxCdOutputData.sample_id)
  226. )
  227. return query.all()
  228. except Exception as e:
  229. self.logger.error(f"查询农田Cd通量数据时发生错误: {str(e)}")
  230. return []
  231. def _generate_csv(self, data: list, output_path: str, field_names: list = None):
  232. """
  233. 将查询结果生成CSV文件
  234. @param data: 查询结果列表
  235. @param output_path: 输出CSV文件路径
  236. @param field_names: CSV列名列表
  237. """
  238. if field_names is None:
  239. field_names = ['lon', 'lan', 'In_Cd']
  240. try:
  241. with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
  242. writer = csv.writer(csvfile)
  243. # 写入表头
  244. writer.writerow(field_names)
  245. # 写入数据
  246. for row in data:
  247. writer.writerow(row)
  248. self.logger.info(f"CSV文件已生成: {output_path}")
  249. except Exception as e:
  250. self.logger.error(f"生成CSV文件时发生错误: {str(e)}")
  251. raise
  252. class FluxCdUpdateService:
  253. """
  254. 农田Cd通量数据更新服务类
  255. 处理用户上传的CSV文件并更新数据库中的通量值
  256. """
  257. def __init__(self, db: Session = None):
  258. self.db = db if db else SessionLocal()
  259. self.logger = logging.getLogger(__name__)
  260. def update_from_csv(self, csv_file_path: str) -> Dict[str, Any]:
  261. """
  262. 从CSV文件更新Cd通量数据
  263. @param csv_file_path: CSV文件路径
  264. @return: 更新结果字典
  265. """
  266. try:
  267. # 读取CSV文件
  268. df = self._read_csv(csv_file_path)
  269. # 验证CSV格式
  270. self._validate_csv(df)
  271. # 更新数据库
  272. update_count = 0
  273. create_count = 0
  274. error_count = 0
  275. for _, row in df.iterrows():
  276. try:
  277. # 查找匹配的农田样点
  278. farmland = self._find_farmland(row['lon'], row['lan'])
  279. if farmland:
  280. # 更新现有记录
  281. input_updated = self._update_input_data(farmland, row)
  282. if input_updated:
  283. self._update_output_data(farmland)
  284. update_count += 1
  285. else:
  286. # 创建新记录
  287. if self._create_new_records(row):
  288. create_count += 1
  289. else:
  290. error_count += 1
  291. except Exception as e:
  292. self.logger.error(f"处理记录失败: {str(e)}")
  293. error_count += 1
  294. self.db.commit()
  295. return {
  296. "success": True,
  297. "message": f"成功更新 {update_count} 条记录,创建 {create_count} 条新记录,{error_count} 条失败",
  298. "updated_count": update_count,
  299. "created_count": create_count,
  300. "error_count": error_count
  301. }
  302. except Exception as e:
  303. self.db.rollback()
  304. self.logger.error(f"更新失败: {str(e)}", exc_info=True)
  305. return {
  306. "success": False,
  307. "message": f"更新失败: {str(e)}"
  308. }
  309. finally:
  310. self.db.close()
  311. def _create_new_records(self, row: pd.Series):
  312. """为新的经纬度创建完整的记录"""
  313. try:
  314. self.logger.info(f"创建新农田记录: 经度={row['lon']}, 纬度={row['lan']}")
  315. # 获取下一个可用的 Farmland_ID
  316. max_id = self.db.query(func.max(FarmlandData.farmland_id)).scalar()
  317. new_farmland_id = (max_id or 0) + 1
  318. # 创建新的农田样点记录
  319. farmland = FarmlandData(
  320. farmland_id=new_farmland_id,
  321. lon=row['lon'],
  322. lan=row['lan'],
  323. type=0.0, # 默认为旱地
  324. geom=f"POINT({row['lon']} {row['lan']})"
  325. )
  326. self.db.add(farmland)
  327. self.db.flush() # 确保ID被分配
  328. self.logger.info(f"农田记录创建成功: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  329. # 创建新的输入通量记录
  330. input_data = FluxCdInputData(
  331. farmland_id=farmland.farmland_id,
  332. sample_id=farmland.sample_id,
  333. atmospheric_deposition=row['DQCJ_Cd'],
  334. irrigation_input=row['GGS_Cd'],
  335. agro_chemicals_input=row['NCP_Cd'],
  336. # 设置合理的默认值
  337. initial_cd=0.0,
  338. groundwater_leaching=row.get('DX_Cd', 0.023), # 添加默认值处理
  339. surface_runoff=row.get('DB_Cd', 0.368),
  340. grain_removal=row.get('ZL_Cd', 0.0),
  341. straw_removal=row.get('JG_Cd', 0.0)
  342. )
  343. self.db.add(input_data)
  344. self.logger.info("输入通量记录创建成功")
  345. # 计算输入总通量
  346. in_cd = row['DQCJ_Cd'] + row['GGS_Cd'] + row['NCP_Cd']
  347. # 计算输出总通量
  348. out_cd = (input_data.groundwater_leaching +
  349. input_data.surface_runoff +
  350. input_data.grain_removal +
  351. input_data.straw_removal)
  352. # 创建新的输出通量记录
  353. output_data = FluxCdOutputData(
  354. farmland_id=farmland.farmland_id,
  355. sample_id=farmland.sample_id,
  356. in_cd=in_cd,
  357. out_cd=out_cd,
  358. net_cd=in_cd - out_cd,
  359. end_cd=0.0 # 默认当年Cd浓度
  360. )
  361. self.db.add(output_data)
  362. self.logger.info("输出通量记录创建成功")
  363. self.logger.info(f"所有记录创建完成: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  364. return True
  365. except Exception as e:
  366. self.logger.error(f"创建新记录失败: {str(e)}")
  367. return False
  368. def _read_csv(self, file_path: str) -> pd.DataFrame:
  369. """读取CSV文件并验证基本格式"""
  370. try:
  371. df = pd.read_csv(file_path)
  372. self.logger.info(f"成功读取CSV文件: {file_path}")
  373. return df
  374. except Exception as e:
  375. raise ValueError(f"CSV文件读取失败: {str(e)}")
  376. def _validate_csv(self, df: pd.DataFrame):
  377. """验证CSV文件包含必要的列"""
  378. required_columns = {'lon', 'lan', 'DQCJ_Cd', 'GGS_Cd', 'NCP_Cd'}
  379. # 添加输出通量可选字段
  380. output_fields = {'DX_Cd', 'DB_Cd', 'ZL_Cd', 'JG_Cd'}
  381. if not required_columns.issubset(df.columns):
  382. missing = required_columns - set(df.columns)
  383. raise ValueError(f"CSV缺少必要列: {', '.join(missing)}")
  384. # 检查是否包含输出通量字段(可选)
  385. if not output_fields.issubset(df.columns):
  386. missing_output = output_fields - set(df.columns)
  387. self.logger.warning(f"CSV缺少输出通量字段: {', '.join(missing_output)},将使用默认值")
  388. def _find_farmland(self, lon: float, lan: float) -> FarmlandData:
  389. """根据经纬度查找农田样点"""
  390. # 使用容差匹配(0.001度≈100米)
  391. tol = 0.001
  392. return self.db.query(FarmlandData).filter(
  393. FarmlandData.lon.between(lon - tol, lon + tol),
  394. FarmlandData.lan.between(lan - tol, lan + tol)
  395. ).first()
  396. def _update_input_data(self, farmland: FarmlandData, row: pd.Series) -> bool:
  397. """更新输入通量数据,返回是否有更新"""
  398. input_data = self.db.query(FluxCdInputData).filter(
  399. FluxCdInputData.farmland_id == farmland.farmland_id,
  400. FluxCdInputData.sample_id == farmland.sample_id
  401. ).first()
  402. if not input_data:
  403. self.logger.warning(f"未找到输入数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  404. return False
  405. # 检查是否需要更新
  406. updated = False
  407. if input_data.atmospheric_deposition != row['DQCJ_Cd']:
  408. input_data.atmospheric_deposition = row['DQCJ_Cd']
  409. updated = True
  410. if input_data.irrigation_input != row['GGS_Cd']:
  411. input_data.irrigation_input = row['GGS_Cd']
  412. updated = True
  413. if input_data.agro_chemicals_input != row['NCP_Cd']:
  414. input_data.agro_chemicals_input = row['NCP_Cd']
  415. updated = True
  416. if updated:
  417. self.logger.info(f"更新输入通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  418. return updated
  419. def _update_output_data(self, farmland: FarmlandData):
  420. """更新输出通量数据"""
  421. output_data = self.db.query(FluxCdOutputData).filter(
  422. FluxCdOutputData.farmland_id == farmland.farmland_id,
  423. FluxCdOutputData.sample_id == farmland.sample_id
  424. ).first()
  425. if not output_data:
  426. self.logger.warning(f"未找到输出数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  427. return
  428. # 重新获取更新后的输入数据
  429. input_data = self.db.query(FluxCdInputData).filter(
  430. FluxCdInputData.farmland_id == farmland.farmland_id,
  431. FluxCdInputData.sample_id == farmland.sample_id
  432. ).first()
  433. # 计算输出总通量
  434. out_cd = (input_data.groundwater_leaching + # DX_Cd
  435. input_data.surface_runoff + # DB_Cd
  436. input_data.grain_removal + # ZL_Cd
  437. input_data.straw_removal) # JG_Cd
  438. # 更新输出通量记录
  439. output_data.in_cd = input_data.input_flux()
  440. output_data.out_cd = out_cd
  441. output_data.net_cd = output_data.in_cd - out_cd
  442. # 新增:计算并更新当年Cd浓度
  443. try:
  444. output_data.calculate_end_cd(self.db)
  445. self.logger.info(f"成功更新当年Cd浓度: {output_data.end_cd} mg/kg")
  446. except Exception as e:
  447. self.logger.error(f"更新当年Cd浓度失败: {str(e)}")
  448. self.logger.info(f"更新输出通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  449. # 测试主函数
  450. # 测试主函数
  451. if __name__ == "__main__":
  452. import tempfile
  453. import shutil
  454. import json
  455. # 创建数据库会话
  456. db = SessionLocal()
  457. try:
  458. # 初始化服务
  459. service = FluxCdVisualizationService(db=db)
  460. print("=" * 50)
  461. print("测试Cd通量可视化服务")
  462. print("=" * 50)
  463. # 测试生成可视化结果
  464. print("\n>>> 测试生成Cd输入通量可视化地图")
  465. input_result = service.generate_cd_input_flux_map()
  466. print(json.dumps(input_result, indent=2, ensure_ascii=False))
  467. print("\n>>> 测试生成Cd输出通量可视化地图")
  468. output_result = service.generate_cd_output_flux_map()
  469. print(json.dumps(output_result, indent=2, ensure_ascii=False))
  470. print("\n>>> 测试生成Cd净通量可视化地图")
  471. net_result = service.generate_cd_net_flux_map()
  472. print(json.dumps(net_result, indent=2, ensure_ascii=False))
  473. # 测试更新服务
  474. print("\n>>> 测试从CSV更新Cd通量数据")
  475. # 创建临时目录和CSV文件
  476. temp_dir = tempfile.mkdtemp()
  477. test_csv_path = os.path.join(temp_dir, "test_update.csv")
  478. # 创建测试CSV文件
  479. test_data = [
  480. "lon,lan,DQCJ_Cd,GGS_Cd,NCP_Cd,DX_Cd,DB_Cd,ZL_Cd,JG_Cd",
  481. "113.123,25.456,1.24,4.56,7.89,0.1,0.2,0.05,0.03",
  482. "113.125,25.457,2.35,5.67,8.90,0.15,0.25,0.06,0.04"
  483. ]
  484. with open(test_csv_path, 'w', encoding='utf-8') as f:
  485. f.write("\n".join(test_data))
  486. # 调用更新服务
  487. update_result = service.update_from_csv(test_csv_path)
  488. print(json.dumps(update_result, indent=2, ensure_ascii=False))
  489. # 清理临时文件
  490. shutil.rmtree(temp_dir)
  491. print(f"\n清理临时目录: {temp_dir}")
  492. print("\n测试完成!")
  493. except Exception as e:
  494. logger.error(f"测试过程中发生错误: {str(e)}")
  495. import traceback
  496. traceback.print_exc()
  497. finally:
  498. # 关闭数据库会话
  499. db.close()