""" Pytest configuration and fixtures for integration tests. Provides test database setup, client fixtures, and authentication helpers. """ import os from typing import AsyncGenerator import pytest import sqlalchemy as sa from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.pool import NullPool from app.main import app from app.core.database import Base, get_db from app.core.admin_database import get_admin_db from app.core.config import settings # Disable invite code requirement for tests settings.REQUIRE_INVITE_CODE = False # Test database URL — NEVER reuse DATABASE_URL. The test_db fixture does # `DROP SCHEMA public CASCADE` on every test; if DATABASE_URL (which normally # points at the dev/prod DB) leaked into this value, running `pytest tests/` # would silently nuke the dev database. Only DATABASE_TEST_URL is honored, # and the safety assertion below refuses to run against a DB whose name # doesn't contain "test". TEST_DATABASE_URL = os.environ.get( "DATABASE_TEST_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/resolutionflow_test", ) # Belt-and-suspenders: refuse to run tests against a DB whose name doesn't # contain "test". Parses the last path segment of the URL (everything after # the final '/', with query string stripped) so credentials / hosts that # happen to contain "test" can't bypass the check. _test_db_name = TEST_DATABASE_URL.rsplit("/", 1)[-1].split("?", 1)[0].lower() assert "test" in _test_db_name, ( f"Refusing to run tests against database {_test_db_name!r} — " f"the DB name must contain 'test'. Set DATABASE_TEST_URL to a dedicated " f"test database (e.g. resolutionflow_test)." ) _RUN_RLS_TESTS = os.environ.get("RUN_RLS_TESTS") == "1" _RLS_ISOLATION_FILE = "test_rls_isolation.py" def pytest_collection_modifyitems(config, items): """Keep migration-managed RLS checks out of the default create_all suite.""" if _RUN_RLS_TESTS: return selected = [] deselected = [] for item in items: item_path = getattr(item, "path", None) or getattr(item, "fspath", None) if item_path and str(item_path).endswith(_RLS_ISOLATION_FILE): deselected.append(item) else: selected.append(item) if deselected: config.hook.pytest_deselected(items=deselected) items[:] = selected @pytest.fixture async def test_db() -> AsyncGenerator[AsyncSession, None]: """ Create a fresh database for each test function. This fixture: 1. Creates a test database engine 2. Drops all existing tables (CASCADE to handle circular FKs) 3. Creates all tables 4. Yields a session for the test 5. Drops all tables after the test """ # Create async engine for tests (with NullPool to avoid connection reuse issues) engine = create_async_engine( TEST_DATABASE_URL, poolclass=NullPool, echo=False ) # Drop and recreate all tables (use raw SQL CASCADE to handle circular FKs # between users <-> invite_codes) async with engine.begin() as conn: await conn.execute(sa.text("DROP SCHEMA public CASCADE")) await conn.execute(sa.text("CREATE SCHEMA public")) await conn.run_sync(Base.metadata.create_all) # Seed plan_limits for subscription checks await conn.execute(sa.text(""" INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats) VALUES ('free', 3, 20, 1, false, false, '["markdown", "text"]'), ('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'), ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') """)) # Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by # global categories, gallery items, and other platform-owned content. await conn.execute(sa.text(""" INSERT INTO accounts (id, name, display_code, created_at, updated_at) VALUES ( '00000000-0000-0000-0000-000000000001', 'ResolutionFlow System', 'RF-SYS-1', NOW(), NOW() ) ON CONFLICT (id) DO NOTHING """)) # Create async session maker async_session_maker = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False ) # Provide session to test async with async_session_maker() as session: yield session # Ensure session is fully closed before teardown await session.close() # Dispose engine first so all pooled connections are released, # then reconnect to perform the schema teardown cleanly. await engine.dispose() # Drop all tables after test (CASCADE for circular FKs) teardown_engine = create_async_engine( TEST_DATABASE_URL, poolclass=NullPool, echo=False, ) try: async with teardown_engine.begin() as conn: await conn.execute(sa.text("DROP SCHEMA public CASCADE")) await conn.execute(sa.text("CREATE SCHEMA public")) finally: await teardown_engine.dispose() @pytest.fixture async def client(test_db: AsyncSession): """ Create an async HTTP client for testing API endpoints. Overrides the database dependency to use the test database. """ async def override_get_db(): yield test_db app.dependency_overrides[get_db] = override_get_db # Endpoints that use get_admin_db (register, admin routes, service accounts) # must also hit the test DB; otherwise they leak into the real admin DB. # RLS is not enabled in the test schema (create_all, not alembic), so sharing # the same session is safe. app.dependency_overrides[get_admin_db] = override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() @pytest.fixture async def test_user(client): """ Create a test user and return their credentials. Returns: dict with email, password, and user_data """ user_data = { "email": "test@example.com", "password": "TestPassword123!", "name": "Test User" } response = await client.post("/api/v1/auth/register", json=user_data) assert response.status_code == 200 or response.status_code == 201 return { "email": user_data["email"], "password": user_data["password"], "user_data": response.json() } @pytest.fixture async def auth_headers(client, test_user): """ Get authentication headers for an authenticated test user. Returns: dict with Authorization header """ login_data = { "email": test_user["email"], "password": test_user["password"] } response = await client.post("/api/v1/auth/login/json", json=login_data) assert response.status_code == 200 token_data = response.json() return {"Authorization": f"Bearer {token_data['access_token']}"} @pytest.fixture async def test_tree(client, auth_headers): """ Create a test decision tree. Returns: dict with tree data """ tree_data = { "name": "Test Troubleshooting Tree", "description": "A test tree for integration tests", "category": "Testing", "tree_structure": { "id": "root", "type": "decision", "question": "Is this a test?", "options": [ {"id": "yes", "label": "Yes", "next_node_id": "solution1"}, {"id": "no", "label": "No", "next_node_id": "solution2"} ], "children": [ { "id": "solution1", "type": "solution", "title": "Test Confirmed", "description": "This is a test tree", "solution": "Test confirmed - this is a test tree" }, { "id": "solution2", "type": "solution", "title": "Not a Test", "description": "This should not happen", "solution": "Not a test - this should not happen" } ] } } response = await client.post( "/api/v1/trees", json=tree_data, headers=auth_headers ) assert response.status_code == 201 return response.json() @pytest.fixture async def test_admin(client, test_db): """ Create a test super-admin user. Registers as engineer (the only role available at registration), then promotes to super_admin directly via the DB session. """ from uuid import UUID as PyUUID from sqlalchemy import select from app.models.user import User admin_data = { "email": "admin@example.com", "password": "AdminPassword123!", "name": "Test Admin" } response = await client.post("/api/v1/auth/register", json=admin_data) assert response.status_code == 200 or response.status_code == 201 user_id = PyUUID(response.json()["id"]) result = await test_db.execute(select(User).where(User.id == user_id)) user = result.scalar_one() user.is_super_admin = True await test_db.commit() return { "email": admin_data["email"], "password": admin_data["password"], "user_data": response.json() } @pytest.fixture async def admin_auth_headers(client, test_admin): """ Get authentication headers for an authenticated admin user. Returns: dict with Authorization header """ login_data = { "email": test_admin["email"], "password": test_admin["password"] } response = await client.post("/api/v1/auth/login/json", json=login_data) assert response.status_code == 200 token_data = response.json() return {"Authorization": f"Bearer {token_data['access_token']}"}