from fastapi import HTTPException, UploadFile from sqlalchemy.orm import Session from ..models.raster import RasterData import os from datetime import datetime import subprocess import tempfile import shutil from sqlalchemy import text import rasterio from rasterio.io import MemoryFile import numpy as np import zipfile from typing import List def get_raster_data(db: Session, raster_id: int): """通过ID获取一条栅格数据记录""" # 获取栅格数据 query = text(""" SELECT id, rast FROM raster_table WHERE id = :raster_id """) result = db.execute(query, {"raster_id": raster_id}).first() if not result: raise HTTPException(status_code=404, detail="栅格数据不存在") # 将Row对象转换为字典 return dict(result._mapping) async def import_raster_data(file: UploadFile, db: Session) -> dict: """导入栅格数据到数据库""" try: # 创建临时目录 temp_dir = tempfile.mkdtemp() temp_file_path = os.path.join(temp_dir, file.filename) try: # 保存上传的文件 with open(temp_file_path, "wb") as buffer: content = await file.read() buffer.write(content) # 检查文件是否存在且不为空 if not os.path.exists(temp_file_path): raise Exception("临时文件创建失败") if os.path.getsize(temp_file_path) == 0: raise Exception("上传的文件为空") # 使用raster2pgsql命令导入数据 cmd = [ 'raster2pgsql', '-s', '4490', # 空间参考系统 '-I', # 创建空间索引 '-M', # 创建空间索引 '-a', # 追加模式 temp_file_path, 'raster_table' ] # 执行命令 process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = process.communicate() if process.returncode != 0: error_msg = stderr.decode() raise Exception(f"raster2pgsql命令执行失败: {error_msg}") # 检查导入的SQL是否为空 if not stdout: raise Exception("raster2pgsql没有生成任何SQL语句") # 执行生成的SQL sql_commands = stdout.decode().split(';') for sql in sql_commands: if sql.strip(): db.execute(text(sql)) db.commit() # 获取最后插入的记录的ID result = db.execute(text(""" SELECT id, ST_IsEmpty(rast) as is_empty, ST_Width(rast) as width FROM raster_table ORDER BY id DESC LIMIT 1 """)).first() if not result: raise Exception("无法获取导入的栅格数据ID") if result.is_empty or result.width is None: raise Exception("导入的栅格数据为空或无效") return { "message": "栅格数据导入成功", "raster_id": result.id, "file_path": temp_file_path } finally: # 清理临时目录 shutil.rmtree(temp_dir) except Exception as e: raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}") def export_raster_data(db: Session, raster_id: int): """导出指定ID的栅格数据为TIFF文件""" try: # 创建临时目录 temp_dir = tempfile.mkdtemp() temp_file_path = os.path.join(temp_dir, f"raster_{raster_id}.tif") # 从数据库获取栅格数据 query = text(""" SELECT ST_AsBinary(rast) as raster_data, ST_Width(rast) as width, ST_Height(rast) as height, ST_NumBands(rast) as num_bands, ST_UpperLeftX(rast) as upper_left_x, ST_UpperLeftY(rast) as upper_left_y, ST_ScaleX(rast) as scale_x, ST_ScaleY(rast) as scale_y, ST_SkewX(rast) as skew_x, ST_SkewY(rast) as skew_y, ST_SRID(rast) as srid, ST_BandPixelType(rast, 1) as pixel_type FROM raster_table WHERE id = :raster_id """) result = db.execute(query, {"raster_id": raster_id}).first() if not result: shutil.rmtree(temp_dir) raise HTTPException(status_code=404, detail="栅格数据不存在") # 根据像素类型选择合适的数据类型 dtype_map = { '8BUI': np.uint8, '16BUI': np.uint16, '32BUI': np.uint32, '8BSI': np.int8, '16BSI': np.int16, '32BSI': np.int32, '32BF': np.float32, '64BF': np.float64 } dtype = dtype_map.get(result.pixel_type, np.float32) # 计算预期的数据大小 expected_size = result.width * result.height * result.num_bands * np.dtype(dtype).itemsize # 检查二进制数据大小 if len(result.raster_data) < expected_size: shutil.rmtree(temp_dir) raise Exception(f"数据大小不足: 预期至少 {expected_size} 字节,实际 {len(result.raster_data)} 字节") # 跳过头部信息,只取实际数据部分 # PostGIS 二进制格式的头部通常是66字节 header_size = 66 actual_data = result.raster_data[header_size:header_size + expected_size] # 将二进制数据转换为numpy数组 raster_data = np.frombuffer(actual_data, dtype=dtype) # 重塑数组为正确的形状 (bands, height, width) raster_data = raster_data.reshape((result.num_bands, result.height, result.width)) # 创建内存文件 with MemoryFile() as memfile: # 创建新的栅格数据集 with memfile.open( driver='GTiff', height=result.height, width=result.width, count=result.num_bands, dtype=dtype, crs=f'EPSG:{result.srid}', transform=rasterio.transform.from_origin( result.upper_left_x, result.upper_left_y, result.scale_x, result.scale_y ) ) as dataset: # 写入数据 dataset.write(raster_data) # 将内存文件写入磁盘 with open(temp_file_path, 'wb') as f: f.write(memfile.read()) # 检查文件是否成功创建 if not os.path.exists(temp_file_path): shutil.rmtree(temp_dir) raise Exception("文件创建失败") return { "message": "栅格数据导出成功", "file_path": temp_file_path, "temp_dir": temp_dir # 返回临时目录路径,以便后续清理 } except Exception as e: if os.path.exists(temp_dir): shutil.rmtree(temp_dir) raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}") def export_raster_data_batch(db: Session, raster_ids: List[int]): """批量导出栅格数据为TIFF文件""" try: # 创建临时目录 temp_dir = tempfile.mkdtemp() zip_file_path = os.path.join(temp_dir, "raster_batch.zip") # 创建ZIP文件 with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf: for raster_id in raster_ids: try: # 导出单个栅格数据 result = export_raster_data(db, raster_id) if result and os.path.exists(result["file_path"]): # 将文件添加到ZIP中 zipf.write( result["file_path"], arcname=f"raster_{raster_id}.tif" ) except Exception as e: print(f"导出栅格数据 {raster_id} 失败: {str(e)}") continue # 检查ZIP文件是否成功创建 if not os.path.exists(zip_file_path): shutil.rmtree(temp_dir) raise Exception("批量导出文件创建失败") return { "message": "栅格数据批量导出成功", "file_path": zip_file_path, "temp_dir": temp_dir } except Exception as e: if os.path.exists(temp_dir): shutil.rmtree(temp_dir) raise HTTPException(status_code=500, detail=f"批量导出失败: {str(e)}")