Browse Source

添加省市模型及边界服务,支持通过名称获取边界GeoJSON,更新配置以允许环境变量覆盖边界文件路径。

drggboy 1 week ago
parent
commit
465cc0f1fc

+ 2 - 1
Cd_Prediction_Integrated_System/config.py

@@ -59,7 +59,8 @@ DATA_PATHS = {
 # 分析模块配置
 # 分析模块配置
 ANALYSIS_CONFIG = {
 ANALYSIS_CONFIG = {
     "template_tif": os.path.join(PROJECT_ROOT, "output", "raster", "meanTemp.tif"),
     "template_tif": os.path.join(PROJECT_ROOT, "output", "raster", "meanTemp.tif"),
-    "boundary_shp": os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp"),
+    # 允许通过环境变量 CD_BOUNDARY_FILE 覆盖边界文件(支持 .geojson/.shp)
+    "boundary_shp": os.environ.get('CD_BOUNDARY_FILE', os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp")),
     "output_raster": os.path.join(PROJECT_ROOT, "output", "raster", "output.tif"),
     "output_raster": os.path.join(PROJECT_ROOT, "output", "raster", "output.tif"),
     "temp_shapefile": os.path.join(PROJECT_ROOT, "output", "raster", "points666.shp")
     "temp_shapefile": os.path.join(PROJECT_ROOT, "output", "raster", "points666.shp")
 }
 }

+ 10 - 1
app/api/vector.py

@@ -1,9 +1,10 @@
 # 矢量数据API
 # 矢量数据API
-from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
 from fastapi.responses import FileResponse
 from fastapi.responses import FileResponse
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 from ..database import get_db
 from ..database import get_db
 from ..services import vector_service
 from ..services import vector_service
+from ..services.admin_boundary_service import get_boundary_geojson_by_name
 import os
 import os
 import shutil
 import shutil
 from typing import List
 from typing import List
@@ -19,6 +20,14 @@ async def import_vector_data(file: UploadFile = File(...), db: Session = Depends
     
     
     return await vector_service.import_vector_data(file, db)
     return await vector_service.import_vector_data(file, db)
 
 
+@router.get("/boundary", summary="按名称获取边界", description="按县/市/省名称获取边界GeoJSON")
+def get_boundary(name: str = Query(...), level: str = Query("auto"), db: Session = Depends(get_db)):
+    """按名称和层级获取边界。level: county/city/province/auto"""
+    try:
+        return get_boundary_geojson_by_name(db, name, level)
+    except Exception as e:
+        raise HTTPException(status_code=404, detail=str(e))
+
 @router.get("/{vector_id}", summary="获取矢量数据", description="根据ID获取一条矢量数据记录")
 @router.get("/{vector_id}", summary="获取矢量数据", description="根据ID获取一条矢量数据记录")
 def get_vector_data(vector_id: int, db: Session = Depends(get_db)):
 def get_vector_data(vector_id: int, db: Session = Depends(get_db)):
     """获取指定ID的矢量数据"""
     """获取指定ID的矢量数据"""

+ 2 - 0
app/models/__init__.py

@@ -19,3 +19,5 @@ from app.models.atmo_company import *
 from app.models.water_sample import *
 from app.models.water_sample import *
 from app.models.agricultural import *
 from app.models.agricultural import *
 from app.models.cross_section import *
 from app.models.cross_section import *
+from app.models.province import *
+from app.models.city import *

+ 92 - 0
app/models/city.py

@@ -0,0 +1,92 @@
+from sqlalchemy import Column, Integer, String, Text, Index, JSON
+from sqlalchemy.dialects.postgresql import JSONB
+from geoalchemy2 import Geometry
+from app.database import Base
+from shapely.geometry import shape, MultiPolygon, Polygon
+import json
+
+class City(Base):
+    """市级行政区划地理数据表"""
+    __tablename__ = "cities"
+
+    # 字段映射字典
+    FIELD_MAPPING = {
+        '市名': 'name',
+        '市代码': 'code',
+        '省名': 'province_name',
+        '省代码': 'province_code'
+    }
+
+    # 反向映射字典
+    REVERSE_FIELD_MAPPING = {v: k for k, v in FIELD_MAPPING.items()}
+
+    id = Column(Integer, primary_key=True, index=True)
+    name = Column(String(100), nullable=False, comment="市名")
+    code = Column(Integer, nullable=False, unique=True, comment="市代码")
+    province_name = Column(String(100), nullable=False, comment="省名") 
+    province_code = Column(Integer, nullable=False, comment="省代码")
+    
+    # 使用PostGIS的几何类型来存储多边形数据
+    geometry = Column(Geometry('MULTIPOLYGON', srid=4326), nullable=False, comment="市级行政区划的几何数据")
+    
+    # 存储完整的GeoJSON数据
+    geojson = Column(JSON, nullable=False, comment="完整的GeoJSON数据") 
+
+    # 显式定义索引
+    __table_args__ = (
+        Index('idx_cities_name', 'name'),
+        Index('idx_cities_code', 'code'),
+        Index('idx_cities_province_code', 'province_code'),
+        #Index('idx_cities_geometry', 'geometry', postgresql_using='gist'),
+    ) 
+
+    @classmethod
+    def from_geojson_feature(cls, feature):
+        """从GeoJSON Feature创建City实例
+        
+        Args:
+            feature: GeoJSON Feature对象
+            
+        Returns:
+            City: 新创建的City实例
+        """
+        properties = feature['properties']
+        # 转换中文字段名为英文
+        mapped_properties = {
+            cls.FIELD_MAPPING.get(k, k): v 
+            for k, v in properties.items()
+        }
+        
+        # 将GeoJSON geometry转换为EWKT格式(SRID=4326)
+        geometry = feature['geometry']
+        geom = shape(geometry)
+        # 统一为 MultiPolygon 以匹配列类型
+        if isinstance(geom, Polygon):
+            geom = MultiPolygon([geom])
+        wkt = f"SRID=4326;{geom.wkt}"
+        
+        # 创建实例
+        city = cls(
+            **mapped_properties,
+            geometry=wkt,
+            geojson=feature
+        )
+        return city
+
+    def to_geojson_feature(self):
+        """将City实例转换为GeoJSON Feature
+        
+        Returns:
+            dict: GeoJSON Feature对象
+        """
+        properties = {}
+        # 转换英文字段名为中文
+        for eng_field, cn_field in self.REVERSE_FIELD_MAPPING.items():
+            if hasattr(self, eng_field):
+                properties[cn_field] = getattr(self, eng_field)
+
+        return {
+            'type': 'Feature',
+            'properties': properties,
+            'geometry': self.geometry
+        }

+ 7 - 2
app/models/county.py

@@ -2,6 +2,7 @@ from sqlalchemy import Column, Integer, String, Text, Index, JSON
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import JSONB
 from geoalchemy2 import Geometry
 from geoalchemy2 import Geometry
 from app.database import Base
 from app.database import Base
+from shapely.geometry import shape, MultiPolygon, Polygon
 import json
 import json
 
 
 class County(Base):
 class County(Base):
@@ -59,9 +60,13 @@ class County(Base):
             for k, v in properties.items()
             for k, v in properties.items()
         }
         }
         
         
-        # 将GeoJSON geometry转换为WKT格式
+        # 将GeoJSON geometry转换为EWKT格式(SRID=4326)
         geometry = feature['geometry']
         geometry = feature['geometry']
-        wkt = f"SRID=4326;{json.dumps(geometry)}"
+        geom = shape(geometry)
+        # 统一为 MultiPolygon 以匹配列类型
+        if isinstance(geom, Polygon):
+            geom = MultiPolygon([geom])
+        wkt = f"SRID=4326;{geom.wkt}"
         
         
         # 创建实例
         # 创建实例
         county = cls(
         county = cls(

+ 87 - 0
app/models/province.py

@@ -0,0 +1,87 @@
+from sqlalchemy import Column, Integer, String, Text, Index, JSON
+from sqlalchemy.dialects.postgresql import JSONB
+from geoalchemy2 import Geometry
+from app.database import Base
+from shapely.geometry import shape, MultiPolygon, Polygon
+import json
+
+class Province(Base):
+    """省级行政区划地理数据表"""
+    __tablename__ = "provinces"
+
+    # 字段映射字典
+    FIELD_MAPPING = {
+        '省名': 'name',
+        '省代码': 'code'
+    }
+
+    # 反向映射字典
+    REVERSE_FIELD_MAPPING = {v: k for k, v in FIELD_MAPPING.items()}
+
+    id = Column(Integer, primary_key=True, index=True)
+    name = Column(String(100), nullable=False, comment="省名")
+    code = Column(Integer, nullable=False, unique=True, comment="省代码")
+    
+    # 使用PostGIS的几何类型来存储多边形数据
+    geometry = Column(Geometry('MULTIPOLYGON', srid=4326), nullable=False, comment="省级行政区划的几何数据")
+    
+    # 存储完整的GeoJSON数据
+    geojson = Column(JSON, nullable=False, comment="完整的GeoJSON数据") 
+
+    # 显式定义索引
+    __table_args__ = (
+        Index('idx_provinces_name', 'name'),
+        Index('idx_provinces_code', 'code'),
+        #Index('idx_provinces_geometry', 'geometry', postgresql_using='gist'),
+    ) 
+
+    @classmethod
+    def from_geojson_feature(cls, feature):
+        """从GeoJSON Feature创建Province实例
+        
+        Args:
+            feature: GeoJSON Feature对象
+            
+        Returns:
+            Province: 新创建的Province实例
+        """
+        properties = feature['properties']
+        # 转换中文字段名为英文
+        mapped_properties = {
+            cls.FIELD_MAPPING.get(k, k): v 
+            for k, v in properties.items()
+        }
+        
+        # 将GeoJSON geometry转换为EWKT格式(SRID=4326)
+        geometry = feature['geometry']
+        geom = shape(geometry)
+        # 统一为 MultiPolygon 以匹配列类型
+        if isinstance(geom, Polygon):
+            geom = MultiPolygon([geom])
+        wkt = f"SRID=4326;{geom.wkt}"
+        
+        # 创建实例
+        province = cls(
+            **mapped_properties,
+            geometry=wkt,
+            geojson=feature
+        )
+        return province
+
+    def to_geojson_feature(self):
+        """将Province实例转换为GeoJSON Feature
+        
+        Returns:
+            dict: GeoJSON Feature对象
+        """
+        properties = {}
+        # 转换英文字段名为中文
+        for eng_field, cn_field in self.REVERSE_FIELD_MAPPING.items():
+            if hasattr(self, eng_field):
+                properties[cn_field] = getattr(self, eng_field)
+
+        return {
+            'type': 'Feature',
+            'properties': properties,
+            'geometry': self.geometry
+        }

+ 114 - 0
app/services/admin_boundary_service.py

@@ -0,0 +1,114 @@
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import text
+import json
+
+
+def get_boundary_geojson_by_name(db: Session, name: str, level: str = "auto") -> dict:
+    """根据名称获取边界GeoJSON Feature
+
+    Args:
+        db (Session): 数据库会话
+        name (str): 名称(县/市/省)
+        level (str): 层级,可选 "county"|"city"|"province"|"auto"
+
+    Returns:
+        dict: GeoJSON Feature 对象
+    """
+    # county 精确匹配
+    if level in ("county", "auto"):
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(geometry) AS g, name, city_name, province_name
+            FROM counties WHERE name=:n LIMIT 1
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "county",
+                    "name": r.name,
+                    "city": r.city_name,
+                    "province": r.province_name,
+                },
+                "geometry": json.loads(r.g)
+            }
+
+    # city 直接查询或从counties聚合
+    if level in ("city", "auto"):
+        # 优先从cities表直接查询
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(geometry) AS g, name, province_name
+            FROM cities WHERE name=:n LIMIT 1
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "city",
+                    "name": r.name,
+                    "province": r.province_name
+                },
+                "geometry": json.loads(r.g)
+            }
+        
+        # 如果cities表没有,从counties表聚合
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(ST_UnaryUnion(ST_Union(geometry))) AS g,
+                   MIN(province_name) AS province_name
+            FROM counties WHERE city_name=:n
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "city",
+                    "name": name,
+                    "province": r.province_name
+                },
+                "geometry": json.loads(r.g)
+            }
+
+    # province 直接查询或从counties聚合
+    if level in ("province", "auto"):
+        # 优先从provinces表直接查询
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(geometry) AS g, name
+            FROM provinces WHERE name=:n LIMIT 1
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "province",
+                    "name": r.name
+                },
+                "geometry": json.loads(r.g)
+            }
+        
+        # 如果provinces表没有,从counties表聚合
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(ST_UnaryUnion(ST_Union(geometry))) AS g
+            FROM counties WHERE province_name=:n
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "province",
+                    "name": name
+                },
+                "geometry": json.loads(r.g)
+            }
+
+    raise ValueError(f"未找到名称: {name}")
+
+

+ 63 - 2
app/services/cd_prediction_service.py

@@ -19,6 +19,9 @@ import io
 
 
 from ..config.cd_prediction_config import cd_config
 from ..config.cd_prediction_config import cd_config
 from ..utils.cd_prediction_wrapper import CdPredictionWrapper
 from ..utils.cd_prediction_wrapper import CdPredictionWrapper
+from ..database import SessionLocal
+from .admin_boundary_service import get_boundary_geojson_by_name
+import json
 
 
 class CdPredictionService:
 class CdPredictionService:
     """
     """
@@ -550,6 +553,25 @@ class CdPredictionService:
         @returns {Dict[str, Any]} 预测结果信息
         @returns {Dict[str, Any]} 预测结果信息
         """
         """
         try:
         try:
+            # 用数据库边界覆盖环境变量给集成系统
+            tmp_geojson = None
+            try:
+                db = SessionLocal()
+                feature = get_boundary_geojson_by_name(db, county_name, level="auto")
+                fc = {"type": "FeatureCollection", "features": [feature]}
+                tmp_dir = tempfile.mkdtemp()
+                tmp_geojson = os.path.join(tmp_dir, "boundary.geojson")
+                with open(tmp_geojson, 'w', encoding='utf-8') as f:
+                    json.dump(fc, f, ensure_ascii=False)
+                os.environ['CD_BOUNDARY_FILE'] = tmp_geojson
+            except Exception as _e:
+                self.logger.warning(f"从数据库获取边界失败,回退到默认配置: {str(_e)}")
+            finally:
+                try:
+                    db.close()
+                except Exception:
+                    pass
+
             # 运行作物Cd预测
             # 运行作物Cd预测
             self.logger.info(f"为{county_name}执行作物Cd预测")
             self.logger.info(f"为{county_name}执行作物Cd预测")
             prediction_result = self.wrapper.run_prediction_script("crop", raster_config_override)
             prediction_result = self.wrapper.run_prediction_script("crop", raster_config_override)
@@ -565,7 +587,7 @@ class CdPredictionService:
             # 清理旧文件
             # 清理旧文件
             self._cleanup_old_files(model_type)
             self._cleanup_old_files(model_type)
             
             
-            return {
+            result_obj = {
                 'map_path': copied_files.get('map_path'),
                 'map_path': copied_files.get('map_path'),
                 'histogram_path': copied_files.get('histogram_path'),
                 'histogram_path': copied_files.get('histogram_path'),
                 'raster_path': copied_files.get('raster_path'),
                 'raster_path': copied_files.get('raster_path'),
@@ -574,6 +596,16 @@ class CdPredictionService:
                 'timestamp': timestamp,
                 'timestamp': timestamp,
                 'stats': self._get_file_stats(copied_files.get('map_path'))
                 'stats': self._get_file_stats(copied_files.get('map_path'))
             }
             }
+            # 清理临时边界
+            try:
+                if tmp_geojson and os.path.exists(tmp_geojson):
+                    import shutil
+                    shutil.rmtree(os.path.dirname(tmp_geojson), ignore_errors=True)
+                if 'CD_BOUNDARY_FILE' in os.environ:
+                    del os.environ['CD_BOUNDARY_FILE']
+            except Exception:
+                pass
+            return result_obj
             
             
         except Exception as e:
         except Exception as e:
             self.logger.error(f"为{county_name}执行作物Cd预测失败: {str(e)}")
             self.logger.error(f"为{county_name}执行作物Cd预测失败: {str(e)}")
@@ -590,6 +622,25 @@ class CdPredictionService:
         @returns {Dict[str, Any]} 预测结果信息
         @returns {Dict[str, Any]} 预测结果信息
         """
         """
         try:
         try:
+            # 用数据库边界覆盖环境变量给集成系统
+            tmp_geojson = None
+            try:
+                db = SessionLocal()
+                feature = get_boundary_geojson_by_name(db, county_name, level="auto")
+                fc = {"type": "FeatureCollection", "features": [feature]}
+                tmp_dir = tempfile.mkdtemp()
+                tmp_geojson = os.path.join(tmp_dir, "boundary.geojson")
+                with open(tmp_geojson, 'w', encoding='utf-8') as f:
+                    json.dump(fc, f, ensure_ascii=False)
+                os.environ['CD_BOUNDARY_FILE'] = tmp_geojson
+            except Exception as _e:
+                self.logger.warning(f"从数据库获取边界失败,回退到默认配置: {str(_e)}")
+            finally:
+                try:
+                    db.close()
+                except Exception:
+                    pass
+
             # 运行有效态Cd预测
             # 运行有效态Cd预测
             self.logger.info(f"为{county_name}执行有效态Cd预测")
             self.logger.info(f"为{county_name}执行有效态Cd预测")
             prediction_result = self.wrapper.run_prediction_script("effective", raster_config_override)
             prediction_result = self.wrapper.run_prediction_script("effective", raster_config_override)
@@ -605,7 +656,7 @@ class CdPredictionService:
             # 清理旧文件
             # 清理旧文件
             self._cleanup_old_files(model_type)
             self._cleanup_old_files(model_type)
             
             
-            return {
+            result_obj = {
                 'map_path': copied_files.get('map_path'),
                 'map_path': copied_files.get('map_path'),
                 'histogram_path': copied_files.get('histogram_path'),
                 'histogram_path': copied_files.get('histogram_path'),
                 'raster_path': copied_files.get('raster_path'),
                 'raster_path': copied_files.get('raster_path'),
@@ -614,6 +665,16 @@ class CdPredictionService:
                 'timestamp': timestamp,
                 'timestamp': timestamp,
                 'stats': self._get_file_stats(copied_files.get('map_path'))
                 'stats': self._get_file_stats(copied_files.get('map_path'))
             }
             }
+            # 清理临时边界
+            try:
+                if tmp_geojson and os.path.exists(tmp_geojson):
+                    import shutil
+                    shutil.rmtree(os.path.dirname(tmp_geojson), ignore_errors=True)
+                if 'CD_BOUNDARY_FILE' in os.environ:
+                    del os.environ['CD_BOUNDARY_FILE']
+            except Exception:
+                pass
+            return result_obj
             
             
         except Exception as e:
         except Exception as e:
             self.logger.error(f"为{county_name}执行有效态Cd预测失败: {str(e)}")
             self.logger.error(f"为{county_name}执行有效态Cd预测失败: {str(e)}")

+ 66 - 0
migrations/versions/d89bb52ae481_add_provinces_and_cities_tables.py

@@ -0,0 +1,66 @@
+"""add_provinces_and_cities_tables
+
+Revision ID: d89bb52ae481
+Revises: e52e802fe61a
+Create Date: 2025-08-09 19:40:59.929438
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import geoalchemy2
+
+# revision identifiers, used by Alembic.
+revision = 'd89bb52ae481'
+down_revision = 'e52e802fe61a'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """升级数据库到当前版本"""
+    # 创建provinces表
+    op.create_table('provinces',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(length=100), nullable=False, comment='省名'),
+    sa.Column('code', sa.Integer(), nullable=False, comment='省代码'),
+    sa.Column('geometry', geoalchemy2.types.Geometry(geometry_type='MULTIPOLYGON', srid=4326, from_text='ST_GeomFromEWKT', name='geometry', nullable=False), nullable=False, comment='省级行政区划的几何数据'),
+    sa.Column('geojson', sa.JSON(), nullable=False, comment='完整的GeoJSON数据'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('code')
+    )
+    op.create_index('idx_provinces_code', 'provinces', ['code'], unique=False)
+    op.create_index('idx_provinces_name', 'provinces', ['name'], unique=False)
+    op.create_index(op.f('ix_provinces_id'), 'provinces', ['id'], unique=False)
+    
+    # 创建cities表
+    op.create_table('cities',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(length=100), nullable=False, comment='市名'),
+    sa.Column('code', sa.Integer(), nullable=False, comment='市代码'),
+    sa.Column('province_name', sa.String(length=100), nullable=False, comment='省名'),
+    sa.Column('province_code', sa.Integer(), nullable=False, comment='省代码'),
+    sa.Column('geometry', geoalchemy2.types.Geometry(geometry_type='MULTIPOLYGON', srid=4326, from_text='ST_GeomFromEWKT', name='geometry', nullable=False), nullable=False, comment='市级行政区划的几何数据'),
+    sa.Column('geojson', sa.JSON(), nullable=False, comment='完整的GeoJSON数据'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('code')
+    )
+    op.create_index('idx_cities_code', 'cities', ['code'], unique=False)
+    op.create_index('idx_cities_name', 'cities', ['name'], unique=False)
+    op.create_index('idx_cities_province_code', 'cities', ['province_code'], unique=False)
+    op.create_index(op.f('ix_cities_id'), 'cities', ['id'], unique=False)
+
+
+def downgrade():
+    """降级数据库到上一个版本"""
+    # 删除cities表
+    op.drop_index(op.f('ix_cities_id'), table_name='cities')
+    op.drop_index('idx_cities_province_code', table_name='cities')
+    op.drop_index('idx_cities_name', table_name='cities')
+    op.drop_index('idx_cities_code', table_name='cities')
+    op.drop_table('cities')
+
+    # 删除provinces表
+    op.drop_index(op.f('ix_provinces_id'), table_name='provinces')
+    op.drop_index('idx_provinces_name', table_name='provinces')
+    op.drop_index('idx_provinces_code', table_name='provinces')
+    op.drop_table('provinces')