# Farmland Cd flux visualization and update services.
- import os
- import csv
- import logging
- import sys
- import pandas as pd
- import geopandas as gpd
- from shapely.geometry import Point
- from pathlib import Path
- from typing import Dict, Any
- from sqlalchemy import func
- from sqlalchemy.orm import Session
- from app.database import SessionLocal
- from app.models import FarmlandData, FluxCdOutputData, FluxCdInputData
- from app.utils.mapping_utils import MappingUtils, csv_to_raster_workflow
# Module-level logger configuration.
# Handlers are attached at most once so re-imports do not duplicate output.
logger = logging.getLogger(__name__)
if not logger.handlers:
    logger.setLevel(logging.INFO)
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(_handler)
    # Stop propagation so parent handlers do not emit the same records twice.
    logger.propagate = False
def get_base_dir():
    """Return the application base directory (consistent with the land-data helpers)."""
    # Bundled executables (sys.frozen) anchor on the executable itself;
    # plain script runs anchor on this source file.
    frozen = getattr(sys, 'frozen', False)
    anchor = sys.executable if frozen else os.path.abspath(__file__)
    return os.path.dirname(anchor)
def get_static_dir():
    """Return the static-resource directory used for Cd flux outputs."""
    return os.path.join(get_base_dir(), "..", "static", "cd_flux")
def get_default_boundary_shp():
    """Return the path of the default boundary shapefile, or None if absent."""
    # Candidate boundary files, checked in order of preference.
    candidates = [
        os.path.join(get_static_dir(), "lechang.shp"),
    ]
    return next((p for p in candidates if os.path.exists(p)), None)
class FluxCdVisualizationService:
    """
    Farmland Cd flux visualization service.

    Generates spatial-distribution maps and histograms from farmland
    locations and Cd flux data stored in the database.

    @version: 1.1.3 (fix: locally-created DB sessions no longer leak on errors)
    """

    def __init__(self, db: Session = None):
        """
        Initialize the visualization service.

        @param db: optional externally managed database session; when None,
                   each operation opens (and closes) its own session.
        """
        self.logger = logger
        self.db = db
        self.mapper = MappingUtils()

    def generate_cd_input_flux_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
        """
        Generate the spatial-distribution map and histogram for input Cd flux (In_Cd).

        @param output_dir: output directory (defaults to the static dir)
        @param boundary_shp: boundary shapefile path (defaults to the bundled boundary)
        @return: dict with "success", "message" and "data" (output file paths)
        """
        # Use the injected session when available; otherwise create one we own.
        db = self.db if self.db else SessionLocal()
        should_close = self.db is None
        try:
            # 1. Query the joined farmland / flux data.
            data = self._fetch_fluxcd_data(db)
            if not data:
                return {
                    "success": False,
                    "message": "数据库中未找到任何农田Cd通量数据",
                    "data": None
                }

            # 2. Resolve the output directory.
            static_dir = get_static_dir()
            if output_dir is None:
                output_dir = static_dir
            os.makedirs(output_dir, exist_ok=True)

            # 3. Resolve the boundary shapefile (fall back to the bundled default).
            if boundary_shp is None:
                boundary_shp = get_default_boundary_shp()
                if boundary_shp:
                    self.logger.info(f"使用默认边界文件: {boundary_shp}")
                else:
                    self.logger.warning("未找到默认边界文件,将不使用边界裁剪")

            # 4. Dump the query result to CSV.
            csv_path = os.path.join(output_dir, "fluxcd_input.csv")
            self._generate_csv(data, csv_path)

            # 5. Locate a template TIFF (reuse the land-data template when present).
            template_tif = os.path.join(static_dir, "meanTemp.tif")
            if not os.path.exists(template_tif):
                template_tif = os.path.join(static_dir, "template.tif")
                if not os.path.exists(template_tif):
                    raise FileNotFoundError("未找到模板TIFF文件")

            # 6. Convert the CSV into a raster via the shared workflow.
            #    (The raster path is taken from the workflow result; the old
            #    pre-computed path was dead code.)
            base_name = "fluxcd_input"
            workflow_result = csv_to_raster_workflow(
                csv_file=csv_path,
                template_tif=template_tif,
                output_dir=output_dir,
                resolution_factor=4.0,
                interpolation_method='linear',
                field_name='In_Cd',
                lon_col=0,   # longitude column index in the CSV
                lat_col=1,   # latitude column index in the CSV
                value_col=2, # value column index in the CSV
                enable_interpolation=True
            )
            raster_path = workflow_result['raster']
            stats = workflow_result['statistics']

            # 7. Render the raster map, clipped by the boundary when available.
            map_path = os.path.join(output_dir, f"{base_name}_map")
            map_file = self.mapper.create_raster_map(
                shp_path=boundary_shp,   # boundary file (may be None)
                tif_path=raster_path,    # rasterized flux values
                output_path=map_path,
                colormap='green_yellow_red_purple',
                title="input Cd flux map",
                output_size=12,
                dpi=300,
                enable_interpolation=False,
                interpolation_method='linear'
            )

            # 8. Render the histogram of the rasterized values.
            histogram_path = self.mapper.create_histogram(
                raster_path,
                save_path=os.path.join(output_dir, f"{base_name}_histogram.jpg"),
                xlabel='input Cd flux(g/ha/a)',
                ylabel='frequency',
                title='input Cd flux histogram',
                bins=100
            )

            return {
                "success": True,
                "message": "成功生成Cd通量可视化结果",
                "data": {
                    "csv": csv_path,
                    "raster": raster_path,
                    "map": map_file,
                    "histogram": histogram_path,
                    "statistics": stats,
                    "boundary_used": boundary_shp if boundary_shp else "无"
                }
            }
        except Exception as e:
            self.logger.error(f"生成Cd通量可视化结果时发生错误: {str(e)}", exc_info=True)
            return {
                "success": False,
                "message": f"生成失败: {str(e)}",
                "data": None
            }
        finally:
            # Fixed: the session previously leaked when an exception occurred,
            # because it was only closed on the success / empty-data paths.
            if should_close:
                db.close()

    def update_from_csv(self, csv_file_path: str):
        """Delegate CSV-driven updates to FluxCdUpdateService."""
        update_service = FluxCdUpdateService(db=self.db)
        return update_service.update_from_csv(csv_file_path)

    def _fetch_fluxcd_data(self, db: Session) -> list:
        """
        Query (lon, lan, in_cd) rows joining farmland points with flux outputs.

        @param db: database session
        @returns: list of result rows (empty list on error)
        """
        try:
            query = db.query(
                FarmlandData.lon,
                FarmlandData.lan,
                FluxCdOutputData.in_cd
            ).join(
                FluxCdOutputData,
                (FarmlandData.farmland_id == FluxCdOutputData.farmland_id) &
                (FarmlandData.sample_id == FluxCdOutputData.sample_id)
            )
            return query.all()
        except Exception as e:
            self.logger.error(f"查询农田Cd通量数据时发生错误: {str(e)}")
            return []

    def _generate_csv(self, data: list, output_path: str):
        """
        Write the query result to a CSV file with the header lon,lan,In_Cd.

        @param data: iterable of (lon, lan, in_cd) rows
        @param output_path: destination CSV path
        @raises: re-raises any I/O error after logging it
        """
        try:
            with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(['lon', 'lan', 'In_Cd'])
                writer.writerows(data)
            self.logger.info(f"CSV文件已生成: {output_path}")
        except Exception as e:
            self.logger.error(f"生成CSV文件时发生错误: {str(e)}")
            raise
class FluxCdUpdateService:
    """
    Farmland Cd flux data update service.

    Processes user-uploaded CSV files and updates the flux values in the
    database, creating full record sets for previously unseen points.
    """

    # Default output-flux components used when creating brand-new records
    # (the uploaded CSV only carries input components).
    DEFAULT_GROUNDWATER_LEACHING = 0.023
    DEFAULT_SURFACE_RUNOFF = 0.368

    def __init__(self, db: Session = None):
        """
        @param db: optional externally managed session; when None a private
                   session is created and closed by this service.
        """
        # Track ownership: only sessions created here may be closed here.
        # Fixed: update_from_csv used to close injected sessions too, breaking
        # the caller's subsequent use of its own session.
        self._owns_db = db is None
        self.db = db if db else SessionLocal()
        self.logger = logging.getLogger(__name__)

    def update_from_csv(self, csv_file_path: str) -> Dict[str, Any]:
        """
        Update Cd flux data from a CSV file.

        @param csv_file_path: path of the uploaded CSV file
        @return: result dict with success flag, message and per-row counters
        """
        try:
            df = self._read_csv(csv_file_path)
            self._validate_csv(df)

            update_count = 0
            create_count = 0
            error_count = 0
            for _, row in df.iterrows():
                try:
                    # Match an existing farmland point by coordinates.
                    farmland = self._find_farmland(row['lon'], row['lan'])
                    if farmland:
                        # Update the existing input record; recompute outputs
                        # only when something actually changed.
                        if self._update_input_data(farmland, row):
                            self._update_output_data(farmland)
                            update_count += 1
                    else:
                        # Unknown point: create farmland + input + output rows.
                        if self._create_new_records(row):
                            create_count += 1
                        else:
                            error_count += 1
                except Exception as e:
                    self.logger.error(f"处理记录失败: {str(e)}")
                    error_count += 1

            self.db.commit()
            return {
                "success": True,
                "message": f"成功更新 {update_count} 条记录,创建 {create_count} 条新记录,{error_count} 条失败",
                "updated_count": update_count,
                "created_count": create_count,
                "error_count": error_count
            }
        except Exception as e:
            self.db.rollback()
            self.logger.error(f"更新失败: {str(e)}", exc_info=True)
            return {
                "success": False,
                "message": f"更新失败: {str(e)}"
            }
        finally:
            # Close only sessions this service created itself.
            if self._owns_db:
                self.db.close()

    def _create_new_records(self, row: pd.Series):
        """Create farmland, input-flux and output-flux rows for a new point."""
        try:
            self.logger.info(f"创建新农田记录: 经度={row['lon']}, 纬度={row['lan']}")
            # Allocate the next available Farmland_ID.
            max_id = self.db.query(func.max(FarmlandData.farmland_id)).scalar()
            new_farmland_id = (max_id or 0) + 1

            farmland = FarmlandData(
                farmland_id=new_farmland_id,
                lon=row['lon'],
                lan=row['lan'],
                type=0.0,  # default: dry land
                geom=f"POINT({row['lon']} {row['lan']})"
            )
            self.db.add(farmland)
            self.db.flush()  # ensure IDs are assigned before they are referenced
            self.logger.info(f"农田记录创建成功: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")

            input_data = FluxCdInputData(
                farmland_id=farmland.farmland_id,
                sample_id=farmland.sample_id,
                atmospheric_deposition=row['DQCJ_Cd'],
                irrigation_input=row['GGS_Cd'],
                agro_chemicals_input=row['NCP_Cd'],
                # Reasonable defaults for components the CSV does not carry.
                initial_cd=0.0,
                groundwater_leaching=self.DEFAULT_GROUNDWATER_LEACHING,
                surface_runoff=self.DEFAULT_SURFACE_RUNOFF,
                grain_removal=0.0,
                straw_removal=0.0
            )
            self.db.add(input_data)
            self.logger.info("输入通量记录创建成功")

            # Total input flux from the uploaded components.
            in_cd = row['DQCJ_Cd'] + row['GGS_Cd'] + row['NCP_Cd']
            # Total output flux from the defaults (grain/straw removal are 0).
            out_cd = self.DEFAULT_GROUNDWATER_LEACHING + self.DEFAULT_SURFACE_RUNOFF

            output_data = FluxCdOutputData(
                farmland_id=farmland.farmland_id,
                sample_id=farmland.sample_id,
                in_cd=in_cd,
                out_cd=out_cd,
                net_cd=in_cd - out_cd,
                end_cd=0.0  # default current-year Cd concentration
            )
            self.db.add(output_data)
            self.logger.info("输出通量记录创建成功")
            self.logger.info(f"所有记录创建完成: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
            return True
        except Exception as e:
            self.logger.error(f"创建新记录失败: {str(e)}")
            return False

    def _read_csv(self, file_path: str) -> pd.DataFrame:
        """Read the CSV file, raising ValueError with context on failure."""
        try:
            df = pd.read_csv(file_path)
            self.logger.info(f"成功读取CSV文件: {file_path}")
            return df
        except Exception as e:
            # Chain the original cause for easier debugging.
            raise ValueError(f"CSV文件读取失败: {str(e)}") from e

    def _validate_csv(self, df: pd.DataFrame):
        """Ensure the CSV contains all required columns, else raise ValueError."""
        required_columns = {'lon', 'lan', 'DQCJ_Cd', 'GGS_Cd', 'NCP_Cd'}
        missing = required_columns - set(df.columns)
        if missing:
            raise ValueError(f"CSV缺少必要列: {', '.join(missing)}")

    def _find_farmland(self, lon: float, lan: float) -> FarmlandData:
        """Find a farmland point near (lon, lan) within a ~100 m tolerance."""
        tol = 0.001  # 0.001 degree ≈ 100 m
        return self.db.query(FarmlandData).filter(
            FarmlandData.lon.between(lon - tol, lon + tol),
            FarmlandData.lan.between(lan - tol, lan + tol)
        ).first()

    def _update_input_data(self, farmland: FarmlandData, row: pd.Series) -> bool:
        """Update input-flux fields from the CSV row; return True if anything changed."""
        input_data = self.db.query(FluxCdInputData).filter(
            FluxCdInputData.farmland_id == farmland.farmland_id,
            FluxCdInputData.sample_id == farmland.sample_id
        ).first()
        if not input_data:
            self.logger.warning(f"未找到输入数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
            return False

        # Model attribute -> CSV column, so each field is compared exactly once.
        field_map = {
            'atmospheric_deposition': 'DQCJ_Cd',
            'irrigation_input': 'GGS_Cd',
            'agro_chemicals_input': 'NCP_Cd',
        }
        updated = False
        for attr, col in field_map.items():
            if getattr(input_data, attr) != row[col]:
                setattr(input_data, attr, row[col])
                updated = True
        if updated:
            self.logger.info(f"更新输入通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
        return updated

    def _update_output_data(self, farmland: FarmlandData):
        """Recompute in_cd and net_cd after the input record changed."""
        output_data = self.db.query(FluxCdOutputData).filter(
            FluxCdOutputData.farmland_id == farmland.farmland_id,
            FluxCdOutputData.sample_id == farmland.sample_id
        ).first()
        if not output_data:
            self.logger.warning(f"未找到输出数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
            return

        # Re-read the (now updated) input record and recompute the totals.
        input_data = self.db.query(FluxCdInputData).filter(
            FluxCdInputData.farmland_id == farmland.farmland_id,
            FluxCdInputData.sample_id == farmland.sample_id
        ).first()
        output_data.in_cd = input_data.input_flux()
        # out_cd is not supplied by the uploaded CSV, so it keeps its stored
        # value; only the net flux is recomputed.
        output_data.net_cd = output_data.in_cd - output_data.out_cd
        self.logger.info(f"更新输出通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
# Manual test entry point (runs only when executed as a script).
if __name__ == "__main__":
    import tempfile
    import shutil
    import json

    db = SessionLocal()
    temp_dir = None
    try:
        service = FluxCdVisualizationService(db=db)
        print("=" * 50)
        print("测试Cd通量可视化服务")
        print("=" * 50)

        print("\n>>> 测试生成Cd输入通量可视化地图")

        print("\n>>> 测试从CSV更新Cd通量数据")
        # Build a small CSV fixture in a temporary directory.
        temp_dir = tempfile.mkdtemp()
        test_csv_path = os.path.join(temp_dir, "test_update.csv")
        test_data = [
            "lon,lan,DQCJ_Cd,GGS_Cd,NCP_Cd",
            "113.123,25.456,1.24,4.56,7.89",
            "113.125,25.457,2.35,5.67,8.90"
        ]
        with open(test_csv_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(test_data))

        update_result = service.update_from_csv(test_csv_path)
        print(json.dumps(update_result, indent=2, ensure_ascii=False))
        print("\n测试完成!")
    except Exception as e:
        logger.error(f"测试过程中发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        # Fixed: remove the temporary fixture even when the test fails
        # (previously cleanup was skipped on any exception).
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
            print(f"\n清理临时目录: {temp_dir}")
        db.close()
|