123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- from fastapi import HTTPException, UploadFile
- from sqlalchemy.orm import Session
- from ..models.raster import RasterData
- import os
- from datetime import datetime
- import subprocess
- import tempfile
- import shutil
- from sqlalchemy import text
- import rasterio
- from rasterio.io import MemoryFile
- import numpy as np
- import zipfile
- from typing import List
- def get_raster_data(db: Session, raster_id: int):
- """通过ID获取一条栅格数据记录"""
- # 获取栅格数据
- query = text("""
- SELECT id, rast
- FROM raster_table
- WHERE id = :raster_id
- """)
-
- result = db.execute(query, {"raster_id": raster_id}).first()
-
- if not result:
- raise HTTPException(status_code=404, detail="栅格数据不存在")
-
- # 将Row对象转换为字典
- return dict(result._mapping)
- async def import_raster_data(file: UploadFile, db: Session) -> dict:
- """导入栅格数据到数据库"""
- try:
- # 创建临时目录
- temp_dir = tempfile.mkdtemp()
- temp_file_path = os.path.join(temp_dir, file.filename)
-
- try:
- # 保存上传的文件
- with open(temp_file_path, "wb") as buffer:
- content = await file.read()
- buffer.write(content)
-
- # 检查文件是否存在且不为空
- if not os.path.exists(temp_file_path):
- raise Exception("临时文件创建失败")
- if os.path.getsize(temp_file_path) == 0:
- raise Exception("上传的文件为空")
-
- # 使用raster2pgsql命令导入数据
- cmd = [
- 'raster2pgsql',
- '-s', '4490', # 空间参考系统
- '-I', # 创建空间索引
- '-M', # 创建空间索引
- '-a', # 追加模式
- temp_file_path,
- 'raster_table'
- ]
-
- # 执行命令
- process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE
- )
- stdout, stderr = process.communicate()
-
- if process.returncode != 0:
- error_msg = stderr.decode()
- raise Exception(f"raster2pgsql命令执行失败: {error_msg}")
-
- # 检查导入的SQL是否为空
- if not stdout:
- raise Exception("raster2pgsql没有生成任何SQL语句")
-
- # 执行生成的SQL
- sql_commands = stdout.decode().split(';')
- for sql in sql_commands:
- if sql.strip():
- db.execute(text(sql))
- db.commit()
-
- # 获取最后插入的记录的ID
- result = db.execute(text("""
- SELECT id, ST_IsEmpty(rast) as is_empty, ST_Width(rast) as width
- FROM raster_table
- ORDER BY id DESC
- LIMIT 1
- """)).first()
-
- if not result:
- raise Exception("无法获取导入的栅格数据ID")
-
- if result.is_empty or result.width is None:
- raise Exception("导入的栅格数据为空或无效")
-
- return {
- "message": "栅格数据导入成功",
- "raster_id": result.id,
- "file_path": temp_file_path
- }
-
- finally:
- # 清理临时目录
- shutil.rmtree(temp_dir)
-
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"导入失败: {str(e)}")
- def export_raster_data(db: Session, raster_id: int):
- """导出指定ID的栅格数据为TIFF文件"""
- try:
- # 创建临时目录
- temp_dir = tempfile.mkdtemp()
- temp_file_path = os.path.join(temp_dir, f"raster_{raster_id}.tif")
-
- # 从数据库获取栅格数据
- query = text("""
- SELECT
- ST_AsBinary(rast) as raster_data,
- ST_Width(rast) as width,
- ST_Height(rast) as height,
- ST_NumBands(rast) as num_bands,
- ST_UpperLeftX(rast) as upper_left_x,
- ST_UpperLeftY(rast) as upper_left_y,
- ST_ScaleX(rast) as scale_x,
- ST_ScaleY(rast) as scale_y,
- ST_SkewX(rast) as skew_x,
- ST_SkewY(rast) as skew_y,
- ST_SRID(rast) as srid,
- ST_BandPixelType(rast, 1) as pixel_type
- FROM raster_table
- WHERE id = :raster_id
- """)
-
- result = db.execute(query, {"raster_id": raster_id}).first()
-
- if not result:
- shutil.rmtree(temp_dir)
- raise HTTPException(status_code=404, detail="栅格数据不存在")
-
- # 根据像素类型选择合适的数据类型
- dtype_map = {
- '8BUI': np.uint8,
- '16BUI': np.uint16,
- '32BUI': np.uint32,
- '8BSI': np.int8,
- '16BSI': np.int16,
- '32BSI': np.int32,
- '32BF': np.float32,
- '64BF': np.float64
- }
-
- dtype = dtype_map.get(result.pixel_type, np.float32)
-
- # 计算预期的数据大小
- expected_size = result.width * result.height * result.num_bands * np.dtype(dtype).itemsize
-
- # 检查二进制数据大小
- if len(result.raster_data) < expected_size:
- shutil.rmtree(temp_dir)
- raise Exception(f"数据大小不足: 预期至少 {expected_size} 字节,实际 {len(result.raster_data)} 字节")
-
- # 跳过头部信息,只取实际数据部分
- # PostGIS 二进制格式的头部通常是66字节
- header_size = 66
- actual_data = result.raster_data[header_size:header_size + expected_size]
-
- # 将二进制数据转换为numpy数组
- raster_data = np.frombuffer(actual_data, dtype=dtype)
- # 重塑数组为正确的形状 (bands, height, width)
- raster_data = raster_data.reshape((result.num_bands, result.height, result.width))
-
- # 创建内存文件
- with MemoryFile() as memfile:
- # 创建新的栅格数据集
- with memfile.open(
- driver='GTiff',
- height=result.height,
- width=result.width,
- count=result.num_bands,
- dtype=dtype,
- crs=f'EPSG:{result.srid}',
- transform=rasterio.transform.from_origin(
- result.upper_left_x,
- result.upper_left_y,
- result.scale_x,
- result.scale_y
- )
- ) as dataset:
- # 写入数据
- dataset.write(raster_data)
-
- # 将内存文件写入磁盘
- with open(temp_file_path, 'wb') as f:
- f.write(memfile.read())
-
- # 检查文件是否成功创建
- if not os.path.exists(temp_file_path):
- shutil.rmtree(temp_dir)
- raise Exception("文件创建失败")
-
- return {
- "message": "栅格数据导出成功",
- "file_path": temp_file_path,
- "temp_dir": temp_dir # 返回临时目录路径,以便后续清理
- }
-
- except Exception as e:
- if os.path.exists(temp_dir):
- shutil.rmtree(temp_dir)
- raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
- def export_raster_data_batch(db: Session, raster_ids: List[int]):
- """批量导出栅格数据为TIFF文件"""
- try:
- # 创建临时目录
- temp_dir = tempfile.mkdtemp()
- zip_file_path = os.path.join(temp_dir, "raster_batch.zip")
-
- # 创建ZIP文件
- with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
- for raster_id in raster_ids:
- try:
- # 导出单个栅格数据
- result = export_raster_data(db, raster_id)
- if result and os.path.exists(result["file_path"]):
- # 将文件添加到ZIP中
- zipf.write(
- result["file_path"],
- arcname=f"raster_{raster_id}.tif"
- )
- except Exception as e:
- print(f"导出栅格数据 {raster_id} 失败: {str(e)}")
- continue
-
- # 检查ZIP文件是否成功创建
- if not os.path.exists(zip_file_path):
- shutil.rmtree(temp_dir)
- raise Exception("批量导出文件创建失败")
-
- return {
- "message": "栅格数据批量导出成功",
- "file_path": zip_file_path,
- "temp_dir": temp_dir
- }
-
- except Exception as e:
- if os.path.exists(temp_dir):
- shutil.rmtree(temp_dir)
- raise HTTPException(status_code=500, detail=f"批量导出失败: {str(e)}")
|