cd_flux_service.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. import os
  2. import csv
  3. import logging
  4. import sys
  5. import pandas as pd
  6. import geopandas as gpd
  7. from shapely.geometry import Point
  8. from pathlib import Path
  9. from typing import Dict, Any
  10. from sqlalchemy import func
  11. from sqlalchemy.orm import Session
  12. from app.database import SessionLocal
  13. from app.models import FarmlandData, FluxCdOutputData, FluxCdInputData
  14. from app.utils.mapping_utils import MappingUtils, csv_to_raster_workflow
  15. # 配置日志 - 避免重复日志输出
  16. logger = logging.getLogger(__name__)
  17. # 避免重复添加处理器导致重复日志
  18. if not logger.handlers:
  19. logger.setLevel(logging.INFO)
  20. handler = logging.StreamHandler()
  21. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  22. handler.setFormatter(formatter)
  23. logger.addHandler(handler)
  24. # 关闭日志传播,避免与父级日志处理器重复输出
  25. logger.propagate = False
  26. def get_base_dir():
  27. """获取基础目录路径(与土地数据处理函数一致)"""
  28. if getattr(sys, 'frozen', False):
  29. # 打包后的可执行文件
  30. return os.path.dirname(sys.executable)
  31. else:
  32. # 脚本运行模式
  33. return os.path.dirname(os.path.abspath(__file__))
  34. def get_static_dir():
  35. """获取静态资源目录"""
  36. base_dir = get_base_dir()
  37. return os.path.join(base_dir, "..", "static", "cd_flux")
  38. def get_default_boundary_shp():
  39. """获取默认边界SHP文件路径"""
  40. static_dir = get_static_dir()
  41. # 尝试几个可能的边界文件路径
  42. possible_paths = [
  43. os.path.join(static_dir, "lechang.shp"),
  44. ]
  45. for path in possible_paths:
  46. if os.path.exists(path):
  47. return path
  48. return None
  49. class FluxCdVisualizationService:
  50. """
  51. 农田Cd通量可视化服务类
  52. @description: 提供基于农田位置和Cd通量数据的空间分布图和直方图生成功能
  53. @version: 1.1.2 (修复绘图问题)
  54. """
  55. def __init__(self, db: Session = None):
  56. """
  57. 初始化可视化服务
  58. @param db: 可选的数据库会话对象
  59. """
  60. self.logger = logger
  61. self.db = db
  62. self.mapper = MappingUtils()
  63. def generate_cd_input_flux_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  64. """
  65. 生成输入镉通量(In_Cd)的空间分布图和直方图
  66. @param output_dir: 输出文件目录
  67. @param boundary_shp: 边界Shapefile文件路径
  68. @return: 包含输出文件路径的字典
  69. """
  70. try:
  71. # 如果未提供数据库会话,则创建新的会话
  72. db = self.db if self.db else SessionLocal()
  73. should_close = self.db is None
  74. # 1. 从数据库查询数据
  75. data = self._fetch_fluxcd_data(db)
  76. if not data:
  77. if should_close:
  78. db.close()
  79. return {
  80. "success": False,
  81. "message": "数据库中未找到任何农田Cd通量数据",
  82. "data": None
  83. }
  84. # 2. 设置输出目录
  85. static_dir = get_static_dir()
  86. if output_dir is None:
  87. output_dir = static_dir
  88. os.makedirs(output_dir, exist_ok=True)
  89. # 3. 设置边界SHP(如果未提供则使用默认)
  90. if boundary_shp is None:
  91. boundary_shp = get_default_boundary_shp()
  92. if boundary_shp:
  93. self.logger.info(f"使用默认边界文件: {boundary_shp}")
  94. else:
  95. self.logger.warning("未找到默认边界文件,将不使用边界裁剪")
  96. # 4. 生成CSV文件
  97. csv_path = os.path.join(output_dir, "fluxcd_input.csv")
  98. self._generate_csv(data, csv_path)
  99. # 5. 设置模板TIFF路径(使用土地数据处理中的模板)
  100. template_tif = os.path.join(static_dir, "meanTemp.tif")
  101. if not os.path.exists(template_tif):
  102. # 尝试其他可能的模板路径
  103. template_tif = os.path.join(static_dir, "template.tif")
  104. if not os.path.exists(template_tif):
  105. raise FileNotFoundError(f"未找到模板TIFF文件")
  106. # 6. 使用csv_to_raster_workflow将CSV转换为栅格
  107. base_name = "fluxcd_input"
  108. raster_path = os.path.join(output_dir, f"{base_name}_raster.tif")
  109. # 关键修改:确保传递边界文件
  110. workflow_result = csv_to_raster_workflow(
  111. csv_file=csv_path,
  112. template_tif=template_tif,
  113. output_dir=output_dir,
  114. resolution_factor=4.0,
  115. interpolation_method='linear',
  116. field_name='In_Cd',
  117. lon_col=0, # CSV中经度列索引
  118. lat_col=1, # CSV中纬度列索引
  119. value_col=2, # CSV中数值列索引
  120. enable_interpolation=True
  121. )
  122. # 获取栅格路径和统计信息
  123. raster_path = workflow_result['raster']
  124. stats = workflow_result['statistics']
  125. # 7. 创建栅格地图 - 关键修改:使用边界文件裁剪
  126. map_path = os.path.join(output_dir, f"{base_name}_map")
  127. map_file = self.mapper.create_raster_map(
  128. shp_path=boundary_shp, # 边界文件
  129. tif_path=raster_path, # 栅格文件
  130. output_path=map_path,
  131. colormap='green_yellow_red_purple',
  132. title="input Cd flux map",
  133. output_size=12,
  134. dpi=300,
  135. enable_interpolation=False,
  136. interpolation_method='linear'
  137. )
  138. # 8. 创建直方图
  139. histogram_path = self.mapper.create_histogram(
  140. raster_path,
  141. save_path=os.path.join(output_dir, f"{base_name}_histogram.jpg"),
  142. xlabel='input Cd flux(g/ha/a)',
  143. ylabel='frequency',
  144. title='input Cd flux histogram',
  145. bins=100
  146. )
  147. result = {
  148. "success": True,
  149. "message": "成功生成Cd通量可视化结果",
  150. "data": {
  151. "csv": csv_path,
  152. "raster": raster_path,
  153. "map": map_file,
  154. "histogram": histogram_path,
  155. "statistics": stats,
  156. "boundary_used": boundary_shp if boundary_shp else "无"
  157. }
  158. }
  159. if should_close:
  160. db.close()
  161. return result
  162. except Exception as e:
  163. self.logger.error(f"生成Cd通量可视化结果时发生错误: {str(e)}", exc_info=True)
  164. return {
  165. "success": False,
  166. "message": f"生成失败: {str(e)}",
  167. "data": None
  168. }
  169. def update_from_csv(self, csv_file_path: str):
  170. """包装更新服务方法"""
  171. update_service = FluxCdUpdateService(db=self.db)
  172. return update_service.update_from_csv(csv_file_path)
  173. def _fetch_fluxcd_data(self, db: Session) -> list:
  174. """
  175. 从数据库查询需要的数据
  176. @param db: 数据库会话
  177. @returns: 查询结果列表
  178. """
  179. try:
  180. # 查询Farmland_data和FluxCd_output_data的联合数据
  181. query = db.query(
  182. FarmlandData.lon,
  183. FarmlandData.lan,
  184. FluxCdOutputData.in_cd
  185. ).join(
  186. FluxCdOutputData,
  187. (FarmlandData.farmland_id == FluxCdOutputData.farmland_id) &
  188. (FarmlandData.sample_id == FluxCdOutputData.sample_id)
  189. )
  190. return query.all()
  191. except Exception as e:
  192. self.logger.error(f"查询农田Cd通量数据时发生错误: {str(e)}")
  193. return []
  194. def _generate_csv(self, data: list, output_path: str):
  195. """
  196. 将查询结果生成CSV文件
  197. @param data: 查询结果列表
  198. @param output_path: 输出CSV文件路径
  199. """
  200. try:
  201. with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
  202. writer = csv.writer(csvfile)
  203. # 写入表头
  204. writer.writerow(['lon', 'lan', 'In_Cd'])
  205. # 写入数据
  206. for row in data:
  207. writer.writerow(row)
  208. self.logger.info(f"CSV文件已生成: {output_path}")
  209. except Exception as e:
  210. self.logger.error(f"生成CSV文件时发生错误: {str(e)}")
  211. raise
  212. class FluxCdUpdateService:
  213. """
  214. 农田Cd通量数据更新服务类
  215. 处理用户上传的CSV文件并更新数据库中的通量值
  216. """
  217. def __init__(self, db: Session = None):
  218. self.db = db if db else SessionLocal()
  219. self.logger = logging.getLogger(__name__)
  220. def update_from_csv(self, csv_file_path: str) -> Dict[str, Any]:
  221. """
  222. 从CSV文件更新Cd通量数据
  223. @param csv_file_path: CSV文件路径
  224. @return: 更新结果字典
  225. """
  226. try:
  227. # 读取CSV文件
  228. df = self._read_csv(csv_file_path)
  229. # 验证CSV格式
  230. self._validate_csv(df)
  231. # 更新数据库
  232. update_count = 0
  233. create_count = 0
  234. error_count = 0
  235. for _, row in df.iterrows():
  236. try:
  237. # 查找匹配的农田样点
  238. farmland = self._find_farmland(row['lon'], row['lan'])
  239. if farmland:
  240. # 更新现有记录
  241. input_updated = self._update_input_data(farmland, row)
  242. if input_updated:
  243. self._update_output_data(farmland)
  244. update_count += 1
  245. else:
  246. # 创建新记录
  247. if self._create_new_records(row):
  248. create_count += 1
  249. else:
  250. error_count += 1
  251. except Exception as e:
  252. self.logger.error(f"处理记录失败: {str(e)}")
  253. error_count += 1
  254. self.db.commit()
  255. return {
  256. "success": True,
  257. "message": f"成功更新 {update_count} 条记录,创建 {create_count} 条新记录,{error_count} 条失败",
  258. "updated_count": update_count,
  259. "created_count": create_count,
  260. "error_count": error_count
  261. }
  262. except Exception as e:
  263. self.db.rollback()
  264. self.logger.error(f"更新失败: {str(e)}", exc_info=True)
  265. return {
  266. "success": False,
  267. "message": f"更新失败: {str(e)}"
  268. }
  269. finally:
  270. self.db.close()
  271. def _create_new_records(self, row: pd.Series):
  272. """为新的经纬度创建完整的记录"""
  273. try:
  274. self.logger.info(f"创建新农田记录: 经度={row['lon']}, 纬度={row['lan']}")
  275. # 获取下一个可用的 Farmland_ID
  276. max_id = self.db.query(func.max(FarmlandData.farmland_id)).scalar()
  277. new_farmland_id = (max_id or 0) + 1
  278. # 创建新的农田样点记录
  279. farmland = FarmlandData(
  280. farmland_id=new_farmland_id,
  281. lon=row['lon'],
  282. lan=row['lan'],
  283. type=0.0, # 默认为旱地
  284. geom=f"POINT({row['lon']} {row['lan']})"
  285. )
  286. self.db.add(farmland)
  287. self.db.flush() # 确保ID被分配
  288. self.logger.info(f"农田记录创建成功: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  289. # 创建新的输入通量记录
  290. input_data = FluxCdInputData(
  291. farmland_id=farmland.farmland_id,
  292. sample_id=farmland.sample_id,
  293. atmospheric_deposition=row['DQCJ_Cd'],
  294. irrigation_input=row['GGS_Cd'],
  295. agro_chemicals_input=row['NCP_Cd'],
  296. # 设置合理的默认值
  297. initial_cd=0.0,
  298. groundwater_leaching=0.023,
  299. surface_runoff=0.368,
  300. grain_removal=0.0,
  301. straw_removal=0.0
  302. )
  303. self.db.add(input_data)
  304. self.logger.info("输入通量记录创建成功")
  305. # 计算输入总通量
  306. in_cd = row['DQCJ_Cd'] + row['GGS_Cd'] + row['NCP_Cd']
  307. # 计算输出总通量(假设默认值)
  308. out_cd = 0.023 + 0.368 + 0.0 + 0.0
  309. # 创建新的输出通量记录
  310. output_data = FluxCdOutputData(
  311. farmland_id=farmland.farmland_id,
  312. sample_id=farmland.sample_id,
  313. in_cd=in_cd,
  314. out_cd=out_cd,
  315. net_cd=in_cd - out_cd,
  316. end_cd=0.0 # 默认当年Cd浓度
  317. )
  318. self.db.add(output_data)
  319. self.logger.info("输出通量记录创建成功")
  320. self.logger.info(f"所有记录创建完成: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  321. return True
  322. except Exception as e:
  323. self.logger.error(f"创建新记录失败: {str(e)}")
  324. return False
  325. def _read_csv(self, file_path: str) -> pd.DataFrame:
  326. """读取CSV文件并验证基本格式"""
  327. try:
  328. df = pd.read_csv(file_path)
  329. self.logger.info(f"成功读取CSV文件: {file_path}")
  330. return df
  331. except Exception as e:
  332. raise ValueError(f"CSV文件读取失败: {str(e)}")
  333. def _validate_csv(self, df: pd.DataFrame):
  334. """验证CSV文件包含必要的列"""
  335. required_columns = {'lon', 'lan', 'DQCJ_Cd', 'GGS_Cd', 'NCP_Cd'}
  336. if not required_columns.issubset(df.columns):
  337. missing = required_columns - set(df.columns)
  338. raise ValueError(f"CSV缺少必要列: {', '.join(missing)}")
  339. def _find_farmland(self, lon: float, lan: float) -> FarmlandData:
  340. """根据经纬度查找农田样点"""
  341. # 使用容差匹配(0.001度≈100米)
  342. tol = 0.001
  343. return self.db.query(FarmlandData).filter(
  344. FarmlandData.lon.between(lon - tol, lon + tol),
  345. FarmlandData.lan.between(lan - tol, lan + tol)
  346. ).first()
  347. def _update_input_data(self, farmland: FarmlandData, row: pd.Series) -> bool:
  348. """更新输入通量数据,返回是否有更新"""
  349. input_data = self.db.query(FluxCdInputData).filter(
  350. FluxCdInputData.farmland_id == farmland.farmland_id,
  351. FluxCdInputData.sample_id == farmland.sample_id
  352. ).first()
  353. if not input_data:
  354. self.logger.warning(f"未找到输入数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  355. return False
  356. # 检查是否需要更新
  357. updated = False
  358. if input_data.atmospheric_deposition != row['DQCJ_Cd']:
  359. input_data.atmospheric_deposition = row['DQCJ_Cd']
  360. updated = True
  361. if input_data.irrigation_input != row['GGS_Cd']:
  362. input_data.irrigation_input = row['GGS_Cd']
  363. updated = True
  364. if input_data.agro_chemicals_input != row['NCP_Cd']:
  365. input_data.agro_chemicals_input = row['NCP_Cd']
  366. updated = True
  367. if updated:
  368. self.logger.info(f"更新输入通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  369. return updated
  370. def _update_output_data(self, farmland: FarmlandData):
  371. """更新输出通量数据"""
  372. output_data = self.db.query(FluxCdOutputData).filter(
  373. FluxCdOutputData.farmland_id == farmland.farmland_id,
  374. FluxCdOutputData.sample_id == farmland.sample_id
  375. ).first()
  376. if not output_data:
  377. self.logger.warning(f"未找到输出数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  378. return
  379. # 重新获取更新后的输入数据
  380. input_data = self.db.query(FluxCdInputData).filter(
  381. FluxCdInputData.farmland_id == farmland.farmland_id,
  382. FluxCdInputData.sample_id == farmland.sample_id
  383. ).first()
  384. # 重新计算并更新
  385. output_data.in_cd = input_data.input_flux()
  386. # 注意:输出总通量out_cd不会由用户上传的CSV更新,所以我们保持原值
  387. # 重新计算净通量
  388. output_data.net_cd = output_data.in_cd - output_data.out_cd
  389. self.logger.info(f"更新输出通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  390. # 测试主函数
  391. # 测试主函数
  392. if __name__ == "__main__":
  393. import tempfile
  394. import shutil
  395. import json
  396. # 创建数据库会话
  397. db = SessionLocal()
  398. try:
  399. # 初始化服务
  400. service = FluxCdVisualizationService(db=db)
  401. print("=" * 50)
  402. print("测试Cd通量可视化服务")
  403. print("=" * 50)
  404. # 测试生成可视化结果
  405. print("\n>>> 测试生成Cd输入通量可视化地图")
  406. # 测试更新服务
  407. print("\n>>> 测试从CSV更新Cd通量数据")
  408. # 创建临时目录和CSV文件
  409. temp_dir = tempfile.mkdtemp()
  410. test_csv_path = os.path.join(temp_dir, "test_update.csv")
  411. # 创建测试CSV文件
  412. test_data = [
  413. "lon,lan,DQCJ_Cd,GGS_Cd,NCP_Cd",
  414. "113.123,25.456,1.24,4.56,7.89",
  415. "113.125,25.457,2.35,5.67,8.90"
  416. ]
  417. with open(test_csv_path, 'w', encoding='utf-8') as f:
  418. f.write("\n".join(test_data))
  419. # 调用更新服务
  420. update_result = service.update_from_csv(test_csv_path)
  421. print(json.dumps(update_result, indent=2, ensure_ascii=False))
  422. # 清理临时文件
  423. shutil.rmtree(temp_dir)
  424. print(f"\n清理临时目录: {temp_dir}")
  425. print("\n测试完成!")
  426. except Exception as e:
  427. logger.error(f"测试过程中发生错误: {str(e)}")
  428. import traceback
  429. traceback.print_exc()
  430. finally:
  431. # 关闭数据库会话
  432. db.close()