#!/usr/bin/env python # -*- coding: utf-8 -*- """ 单元分类功能演示脚本 该脚本帮助用户快速启动和测试单元分类功能 """ import os import sys import subprocess import time import threading from pathlib import Path def print_banner(): """打印横幅""" print("=" * 70) print("单元分类功能演示") print("=" * 70) print() def check_dependencies(): """检查依赖项""" print("🔍 检查依赖项...") required_packages = [ 'fastapi', 'uvicorn', 'sqlalchemy', 'shapely', 'psycopg2', 'geoalchemy2' ] missing_packages = [] for package in required_packages: try: __import__(package) print(f"✓ {package}") except ImportError: missing_packages.append(package) print(f"✗ {package} (缺失)") if missing_packages: print("\n❌ 缺少依赖项,请先安装:") print("conda env create -f environment.yml") print("或手动安装: pip install " + " ".join(missing_packages)) return False print("✓ 所有依赖项已安装") return True def check_database_config(): """检查数据库配置""" print("\n🔍 检查数据库配置...") # 回到项目根目录查找配置文件 project_root = Path(__file__).parent.parent.parent config_file = project_root / "config.env" if not config_file.exists(): print("❌ 配置文件 config.env 不存在") return False # 检查基本配置项 required_vars = ['DB_HOST', 'DB_PORT', 'DB_NAME', 'DB_USER', 'DB_PASSWORD'] try: from dotenv import load_dotenv # type: ignore load_dotenv(str(config_file)) for var in required_vars: if not os.getenv(var): print(f"❌ 环境变量 {var} 未设置") return False print("✓ 数据库配置正确") return True except Exception as e: print(f"❌ 检查数据库配置时出错: {e}") return False def start_server(): """启动服务器""" print("\n🚀 启动FastAPI服务器...") print("服务器将在 http://localhost:8000 启动") print("API文档地址: http://localhost:8000/docs") print("按 Ctrl+C 停止服务器") print("-" * 50) try: # 切换到项目根目录 project_root = Path(__file__).parent.parent.parent os.chdir(project_root) # 使用subprocess启动uvicorn process = subprocess.Popen([ sys.executable, '-m', 'uvicorn', 'main:app', '--reload', '--host', '0.0.0.0', '--port', '8000' ]) # 等待服务器启动 time.sleep(3) return process except Exception as e: print(f"❌ 启动服务器失败: {e}") return None def run_tests(): """运行测试""" print("\n🧪 运行集成测试...") print("-" * 50) try: # 等待服务器完全启动 time.sleep(2) # 运行测试脚本 project_root = Path(__file__).parent.parent.parent test_file = project_root / "tests" / "integration" / "test_unit_grouping.py" if test_file.exists(): subprocess.run([sys.executable, str(test_file)]) else: print(f"❌ 测试文件不存在: {test_file}") except Exception as e: print(f"❌ 运行测试失败: {e}") def show_usage_info(): """显示使用信息""" print("\n📖 使用说明:") print("-" * 50) print("1. 主要API端点:") print(" GET /api/unit-grouping/h_xtfx - 获取所有单元的h_xtfx分类") print(" GET /api/unit-grouping/statistics - 获取统计信息") print(" GET /api/unit-grouping/unit/{unit_id} - 获取特定单元的h_xtfx值") print() print("2. 测试示例:") print(" curl http://localhost:8000/api/unit-grouping/h_xtfx") print(" curl http://localhost:8000/api/unit-grouping/statistics") print() print("3. 查看完整API文档:") print(" http://localhost:8000/docs") print() print("4. 重新运行测试:") print(" python tests/integration/test_unit_grouping.py") print() print("5. 查看功能文档:") print(" docs/features/unit-grouping/README.md") def main(): """主函数""" print_banner() # 检查依赖项 if not check_dependencies(): return # 检查数据库配置 if not check_database_config(): return # 启动服务器 server_process = start_server() if not server_process: return try: # 运行测试 test_thread = threading.Thread(target=run_tests) test_thread.daemon = True test_thread.start() # 等待测试完成 test_thread.join(timeout=30) # 显示使用信息 show_usage_info() print("\n🎉 服务器正在运行中...") print("按 Ctrl+C 退出") # 等待用户中断 server_process.wait() except KeyboardInterrupt: print("\n\n👋 正在关闭服务器...") server_process.terminate() server_process.wait() print("服务器已关闭") except Exception as e: print(f"\n❌ 运行时出错: {e}") if server_process: server_process.terminate() if __name__ == "__main__": main()