Merge feat/l1-workspace into integration branch
# Conflicts: # frontend/src/router.tsx
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
"""create_internal_tickets
|
||||
|
||||
Revision ID: a1e6a018af02
|
||||
Revises: ff6fe5895ea2
|
||||
Create Date: 2026-05-28 16:29:32.624317
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a1e6a018af02'
|
||||
down_revision: Union[str, None] = 'ff6fe5895ea2'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
_CURRENT_ACCOUNT = (
|
||||
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
f"'{_NULL_UUID}')::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'internal_tickets',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('account_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_by_user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('customer_name', sa.String(120), nullable=True),
|
||||
sa.Column('customer_contact', sa.String(200), nullable=True),
|
||||
sa.Column('problem_statement', sa.Text(), nullable=False),
|
||||
sa.Column('status', sa.String(30), nullable=False, server_default='open'),
|
||||
sa.Column('flow_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('flow_proposal_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('ai_session_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('assigned_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('resolution_notes', sa.Text(), nullable=True),
|
||||
sa.Column('psa_promoted_ticket_id', sa.String(64), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ondelete='RESTRICT'),
|
||||
sa.ForeignKeyConstraint(['flow_id'], ['trees.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['flow_proposal_id'], ['flow_proposals.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['ai_session_id'], ['ai_sessions.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['assigned_user_id'], ['users.id'], ondelete='SET NULL'),
|
||||
sa.CheckConstraint(
|
||||
"status IN ('open', 'walking', 'resolved', 'escalated')",
|
||||
name='ck_internal_tickets_status',
|
||||
),
|
||||
)
|
||||
op.create_index('ix_internal_tickets_account_id', 'internal_tickets', ['account_id'])
|
||||
op.create_index('ix_internal_tickets_status', 'internal_tickets', ['status'])
|
||||
op.create_index('ix_internal_tickets_assigned_user_id', 'internal_tickets', ['assigned_user_id'])
|
||||
|
||||
op.execute("ALTER TABLE internal_tickets ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE internal_tickets FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON internal_tickets
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON internal_tickets")
|
||||
op.execute("ALTER TABLE internal_tickets DISABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE internal_tickets NO FORCE ROW LEVEL SECURITY")
|
||||
op.drop_index('ix_internal_tickets_assigned_user_id', 'internal_tickets')
|
||||
op.drop_index('ix_internal_tickets_status', 'internal_tickets')
|
||||
op.drop_index('ix_internal_tickets_account_id', 'internal_tickets')
|
||||
op.drop_table('internal_tickets')
|
||||
59
backend/alembic/versions/a8186f22506d_add_l1_columns.py
Normal file
59
backend/alembic/versions/a8186f22506d_add_l1_columns.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""add_l1_columns
|
||||
|
||||
Revision ID: a8186f22506d
|
||||
Revises: b269a1add160
|
||||
Create Date: 2026-05-28 16:15:40.900535
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a8186f22506d'
|
||||
down_revision: Union[str, None] = 'b269a1add160'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'users',
|
||||
sa.Column('can_cover_l1', sa.Boolean(), nullable=False, server_default='false'),
|
||||
)
|
||||
op.add_column(
|
||||
'accounts',
|
||||
sa.Column('l1_seats_purchased', sa.Integer(), nullable=False, server_default='0'),
|
||||
)
|
||||
op.add_column(
|
||||
'subscriptions',
|
||||
sa.Column('l1_seat_limit', sa.Integer(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'audit_logs',
|
||||
sa.Column('acting_as', sa.String(30), nullable=True),
|
||||
)
|
||||
|
||||
# Rotate account_role CHECK constraint to include 'l1_tech'
|
||||
op.drop_constraint('ck_users_account_role_enum', 'users', type_='check')
|
||||
op.create_check_constraint(
|
||||
'ck_users_account_role_enum',
|
||||
'users',
|
||||
"account_role IN ('owner', 'admin', 'engineer', 'l1_tech', 'viewer')",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverse the constraint rotation first
|
||||
op.drop_constraint('ck_users_account_role_enum', 'users', type_='check')
|
||||
op.create_check_constraint(
|
||||
'ck_users_account_role_enum',
|
||||
'users',
|
||||
"account_role IN ('owner', 'admin', 'engineer', 'viewer')",
|
||||
)
|
||||
op.drop_column('audit_logs', 'acting_as')
|
||||
op.drop_column('subscriptions', 'l1_seat_limit')
|
||||
op.drop_column('accounts', 'l1_seats_purchased')
|
||||
op.drop_column('users', 'can_cover_l1')
|
||||
@@ -0,0 +1,97 @@
|
||||
"""create_l1_walk_sessions
|
||||
|
||||
Revision ID: b3358ba0e48c
|
||||
Revises: a1e6a018af02
|
||||
Create Date: 2026-05-28 16:33:52.120027
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b3358ba0e48c'
|
||||
down_revision: Union[str, None] = 'a1e6a018af02'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
_CURRENT_ACCOUNT = (
|
||||
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
f"'{_NULL_UUID}')::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'l1_walk_sessions',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('account_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_by_user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('acting_as', sa.String(30), nullable=True),
|
||||
sa.Column('ticket_id', sa.String(64), nullable=False),
|
||||
sa.Column('ticket_kind', sa.String(10), nullable=False),
|
||||
sa.Column('session_kind', sa.String(20), nullable=False),
|
||||
sa.Column('flow_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('flow_proposal_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('current_node_id', sa.String(100), nullable=True),
|
||||
sa.Column('walked_path', postgresql.JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column('walk_notes', postgresql.JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column('status', sa.String(20), nullable=False, server_default='active'),
|
||||
sa.Column('resolution_notes', sa.Text(), nullable=True),
|
||||
sa.Column('helpful', sa.Boolean(), nullable=True),
|
||||
sa.Column('escalation_reason', sa.Text(), nullable=True),
|
||||
sa.Column('escalation_reason_category', sa.String(30), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('last_step_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ondelete='RESTRICT'),
|
||||
sa.ForeignKeyConstraint(['flow_id'], ['trees.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['flow_proposal_id'], ['flow_proposals.id'], ondelete='SET NULL'),
|
||||
sa.CheckConstraint(
|
||||
"ticket_kind IN ('psa', 'internal')",
|
||||
name='ck_l1_walk_sessions_ticket_kind',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"session_kind IN ('flow', 'proposal', 'adhoc')",
|
||||
name='ck_l1_walk_sessions_session_kind',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"status IN ('active', 'resolved', 'escalated', 'abandoned')",
|
||||
name='ck_l1_walk_sessions_status',
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"(session_kind = 'flow' AND flow_id IS NOT NULL AND flow_proposal_id IS NULL) "
|
||||
"OR (session_kind = 'proposal' AND flow_proposal_id IS NOT NULL AND flow_id IS NULL) "
|
||||
"OR (session_kind = 'adhoc' AND flow_id IS NULL AND flow_proposal_id IS NULL)",
|
||||
name='ck_l1_walk_sessions_target_consistency',
|
||||
),
|
||||
)
|
||||
op.create_index('ix_l1_walk_sessions_account_id', 'l1_walk_sessions', ['account_id'])
|
||||
op.create_index('ix_l1_walk_sessions_created_by_user_id', 'l1_walk_sessions', ['created_by_user_id'])
|
||||
op.create_index('ix_l1_walk_sessions_status', 'l1_walk_sessions', ['status'])
|
||||
op.create_index('ix_l1_walk_sessions_last_step_at', 'l1_walk_sessions', ['last_step_at'])
|
||||
|
||||
op.execute("ALTER TABLE l1_walk_sessions ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE l1_walk_sessions FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON l1_walk_sessions
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON l1_walk_sessions")
|
||||
op.execute("ALTER TABLE l1_walk_sessions DISABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE l1_walk_sessions NO FORCE ROW LEVEL SECURITY")
|
||||
op.drop_index('ix_l1_walk_sessions_last_step_at', 'l1_walk_sessions')
|
||||
op.drop_index('ix_l1_walk_sessions_status', 'l1_walk_sessions')
|
||||
op.drop_index('ix_l1_walk_sessions_created_by_user_id', 'l1_walk_sessions')
|
||||
op.drop_index('ix_l1_walk_sessions_account_id', 'l1_walk_sessions')
|
||||
op.drop_table('l1_walk_sessions')
|
||||
@@ -0,0 +1,52 @@
|
||||
"""extend_flow_proposals_l1
|
||||
|
||||
Revision ID: ff6fe5895ea2
|
||||
Revises: a8186f22506d
|
||||
Create Date: 2026-05-28 16:26:06.932886
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'ff6fe5895ea2'
|
||||
down_revision: Union[str, None] = 'a8186f22506d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('flow_proposals', sa.Column('source', sa.String(30), nullable=True))
|
||||
op.add_column('flow_proposals', sa.Column('linked_ticket_id', sa.String(64), nullable=True))
|
||||
op.add_column('flow_proposals', sa.Column('linked_ticket_kind', sa.String(10), nullable=True))
|
||||
op.add_column(
|
||||
'flow_proposals',
|
||||
sa.Column('validated_by_outcome', sa.Boolean(), nullable=False, server_default='false'),
|
||||
)
|
||||
|
||||
# Backfill existing rows then enforce NOT NULL on source
|
||||
op.execute("UPDATE flow_proposals SET source = 'manual_draft' WHERE source IS NULL")
|
||||
op.alter_column('flow_proposals', 'source', nullable=False)
|
||||
|
||||
op.create_check_constraint(
|
||||
'ck_flow_proposals_source',
|
||||
'flow_proposals',
|
||||
"source IN ('ai_realtime_l1', 'kb_accelerator', 'manual_draft', 'ai_promoted')",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
'ck_flow_proposals_linked_ticket_kind',
|
||||
'flow_proposals',
|
||||
"linked_ticket_kind IS NULL OR linked_ticket_kind IN ('psa', 'internal')",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint('ck_flow_proposals_linked_ticket_kind', 'flow_proposals', type_='check')
|
||||
op.drop_constraint('ck_flow_proposals_source', 'flow_proposals', type_='check')
|
||||
op.drop_column('flow_proposals', 'validated_by_outcome')
|
||||
op.drop_column('flow_proposals', 'linked_ticket_kind')
|
||||
op.drop_column('flow_proposals', 'linked_ticket_id')
|
||||
op.drop_column('flow_proposals', 'source')
|
||||
@@ -199,6 +199,53 @@ async def require_engineer_or_admin(
|
||||
)
|
||||
|
||||
|
||||
async def require_l1(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> User:
|
||||
"""L1 tech exact-match (with super_admin bypass for support)."""
|
||||
if current_user.is_super_admin:
|
||||
return current_user
|
||||
if current_user.account_role != "l1_tech":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="L1 tech role required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def require_l1_or_coverage(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> User:
|
||||
"""L1 endpoints: l1_tech, owners, super_admin, or engineers with can_cover_l1=True."""
|
||||
if current_user.is_super_admin:
|
||||
return current_user
|
||||
role = current_user.account_role
|
||||
if role == "l1_tech":
|
||||
return current_user
|
||||
if role == "owner":
|
||||
return current_user
|
||||
if role == "engineer" and current_user.can_cover_l1:
|
||||
return current_user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="L1 access requires l1_tech role or engineer coverage flag",
|
||||
)
|
||||
|
||||
|
||||
async def require_l1_or_above(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> User:
|
||||
"""Any tier from l1_tech upward (l1_tech, engineer, owner, super_admin)."""
|
||||
if current_user.is_super_admin:
|
||||
return current_user
|
||||
if current_user.account_role in ("l1_tech", "engineer", "owner"):
|
||||
return current_user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="L1 or above required",
|
||||
)
|
||||
|
||||
|
||||
async def require_team_admin(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> User:
|
||||
|
||||
@@ -21,13 +21,54 @@ from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, AccountInviteBulkCreate, AccountInviteBulkResponse, TransferOwnershipRequest
|
||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate, CoverageUpdate
|
||||
from app.core.security import verify_password
|
||||
from app.api.deps import get_current_active_user, require_account_owner
|
||||
from app.api.deps import get_current_active_user, require_account_owner, require_engineer_or_admin
|
||||
from app.services.seat_enforcement import check_seat_available, get_seat_usage
|
||||
from app.schemas.seat_enforcement import SeatUsage
|
||||
|
||||
_SEAT_CHECKED_ROLES = frozenset({"engineer", "l1_tech"})
|
||||
|
||||
router = APIRouter(prefix="/accounts", tags=["accounts"])
|
||||
|
||||
|
||||
async def _load_account(db: AsyncSession, account_id: UUID) -> Account:
|
||||
"""Load an Account by id; raises 404 if missing."""
|
||||
result = await db.execute(select(Account).where(Account.id == account_id))
|
||||
account = result.scalar_one_or_none()
|
||||
if account is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
return account
|
||||
|
||||
|
||||
async def _enforce_seat_limit(db: AsyncSession, account_id: UUID, role: str) -> None:
|
||||
"""Raise HTTP 402 if the account has no capacity for the given role.
|
||||
|
||||
Only fires for seat-counted roles (engineer, l1_tech).
|
||||
Accounts without a subscription (free / pre-billing) are not blocked.
|
||||
Grandfathering: if current > limit, existing users keep access; this
|
||||
helper only blocks new additions.
|
||||
"""
|
||||
if role not in _SEAT_CHECKED_ROLES:
|
||||
return
|
||||
sub = await get_account_subscription(account_id, db)
|
||||
if sub is None:
|
||||
return # no subscription → no enforcement
|
||||
account = await _load_account(db, account_id)
|
||||
seat_result = await check_seat_available(account, sub, role, db)
|
||||
if not seat_result.available:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"code": "seat_limit_exceeded",
|
||||
"role": seat_result.role,
|
||||
"current": seat_result.current,
|
||||
"limit": seat_result.limit,
|
||||
"upgrade_url": "/account/billing",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=AccountResponse)
|
||||
async def get_my_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
@@ -88,6 +129,41 @@ async def get_my_members(
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/me/seats", response_model=SeatUsage)
|
||||
async def get_my_account_seat_usage(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)],
|
||||
):
|
||||
"""Returns engineer + l1_tech seat-usage counts. Accessible to engineer+.
|
||||
|
||||
Powers the SeatCounterWidget on admin/users and account/users surfaces.
|
||||
"""
|
||||
account = await _load_account(db, current_user.account_id)
|
||||
sub = await get_account_subscription(current_user.account_id, db)
|
||||
if sub is None:
|
||||
# No subscription → treat as unlimited; return live counts with no limit
|
||||
from sqlalchemy import func
|
||||
engineer_count = (await db.execute(
|
||||
select(func.count(User.id))
|
||||
.where(User.account_id == account.id)
|
||||
.where(User.account_role == "engineer")
|
||||
.where(User.is_active.is_(True))
|
||||
)).scalar_one()
|
||||
l1_count = (await db.execute(
|
||||
select(func.count(User.id))
|
||||
.where(User.account_id == account.id)
|
||||
.where(User.account_role == "l1_tech")
|
||||
.where(User.is_active.is_(True))
|
||||
)).scalar_one()
|
||||
from app.schemas.seat_enforcement import SeatCheckResult
|
||||
return SeatUsage(
|
||||
engineer=SeatCheckResult(available=True, current=engineer_count, limit=None, role="engineer"),
|
||||
l1_tech=SeatCheckResult(available=True, current=l1_count, limit=None, role="l1_tech"),
|
||||
)
|
||||
engineer, l1_tech = await get_seat_usage(account, sub, db)
|
||||
return SeatUsage(engineer=engineer, l1_tech=l1_tech)
|
||||
|
||||
|
||||
@router.patch("/me", response_model=AccountResponse)
|
||||
async def update_my_account(
|
||||
data: AccountUpdate,
|
||||
@@ -141,12 +217,54 @@ async def update_member_role(
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
|
||||
# Seat enforcement: check capacity before promoting to a seat-counted role.
|
||||
# Demotions (engineer/l1_tech → viewer) and lateral moves skip the check.
|
||||
if data.account_role != user.account_role:
|
||||
await _enforce_seat_limit(db, current_user.account_id, data.account_role)
|
||||
|
||||
user.account_role = data.account_role
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/me/members/{user_id}/coverage", response_model=UserResponse)
|
||||
async def update_member_coverage(
|
||||
user_id: UUID,
|
||||
data: CoverageUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)],
|
||||
):
|
||||
"""Toggle the `can_cover_l1` flag on an engineer in your account.
|
||||
|
||||
Owner-only. Returns 404 if target user not in your account. Returns 422
|
||||
if target user's role is not 'engineer' (coverage flag only applies to
|
||||
engineers — owners/super_admins already see L1 surface; viewers/l1_techs
|
||||
don't need this flag).
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target = result.scalar_one_or_none()
|
||||
if target is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account",
|
||||
)
|
||||
if target.account_role != "engineer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="can_cover_l1 only applies to engineers",
|
||||
)
|
||||
target.can_cover_l1 = data.can_cover_l1
|
||||
await db.commit()
|
||||
await db.refresh(target)
|
||||
return target
|
||||
|
||||
|
||||
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
||||
async def transfer_ownership(
|
||||
data: TransferOwnershipRequest,
|
||||
@@ -261,6 +379,9 @@ async def create_invite(
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Create an invite to join this account (owner only). Sends invite email."""
|
||||
# Seat enforcement: block invite if the target role is at capacity.
|
||||
await _enforce_seat_limit(db, current_user.account_id, data.role)
|
||||
|
||||
code = secrets.token_urlsafe(16)
|
||||
|
||||
expires_at = None
|
||||
@@ -317,6 +438,10 @@ async def create_invites_bulk(
|
||||
failed: list[dict] = []
|
||||
for invite_data in payload.invites:
|
||||
try:
|
||||
# Seat enforcement per invite row — 402 bubbles as an HTTPException
|
||||
# which is caught below and recorded in `failed`.
|
||||
await _enforce_seat_limit(db, current_user.account_id, invite_data.role)
|
||||
|
||||
code = secrets.token_urlsafe(16)
|
||||
expires_at = None
|
||||
if invite_data.expires_in_days:
|
||||
@@ -343,6 +468,8 @@ async def create_invites_bulk(
|
||||
invite.email_sent_at = datetime.now(timezone.utc)
|
||||
|
||||
created.append(invite)
|
||||
except HTTPException as exc:
|
||||
failed.append({"email": invite_data.email, "error": exc.detail})
|
||||
except Exception as e:
|
||||
failed.append({"email": invite_data.email, "error": str(e)})
|
||||
|
||||
|
||||
@@ -289,6 +289,33 @@ async def register(
|
||||
detail="Invite code has expired"
|
||||
)
|
||||
|
||||
# Seat enforcement: re-check at accept time (race-condition guard).
|
||||
# Fires only when an account invite is being accepted and the target role
|
||||
# is seat-counted (engineer, l1_tech). Accounts without a subscription
|
||||
# (free / pre-billing) are not blocked.
|
||||
if account_invite_record and account_invite_record.role in ("engineer", "l1_tech"):
|
||||
from app.core.subscriptions import get_account_subscription
|
||||
from app.services.seat_enforcement import check_seat_available
|
||||
from app.models.account import Account as _Account
|
||||
sub = await get_account_subscription(account_invite_record.account_id, db)
|
||||
if sub is not None:
|
||||
acct_result = await db.execute(
|
||||
select(_Account).where(_Account.id == account_invite_record.account_id)
|
||||
)
|
||||
acct = acct_result.scalar_one()
|
||||
seat_result = await check_seat_available(acct, sub, account_invite_record.role, db)
|
||||
if not seat_result.available:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"code": "seat_limit_exceeded",
|
||||
"role": seat_result.role,
|
||||
"current": seat_result.current,
|
||||
"limit": seat_result.limit,
|
||||
"upgrade_url": "/account/billing",
|
||||
},
|
||||
)
|
||||
|
||||
# Check if email already exists
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
existing_user = result.scalar_one_or_none()
|
||||
|
||||
277
backend/app/api/endpoints/l1.py
Normal file
277
backend/app/api/endpoints/l1.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""L1 Workspace endpoints (Phase 1).
|
||||
|
||||
PSA-merge queue support + AI build path are deferred to Phase 2.
|
||||
"""
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status as http_status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_db, require_l1_or_coverage
|
||||
from app.models.l1_walk_session import L1WalkSession
|
||||
from app.models.user import User
|
||||
from app.schemas.l1 import (
|
||||
EscalateRequest,
|
||||
EscalateWithoutWalkRequest,
|
||||
IntakeRequest,
|
||||
IntakeResponse,
|
||||
NotesRequest,
|
||||
QueueRow,
|
||||
ResolveRequest,
|
||||
StepRequest,
|
||||
WalkSessionResponse,
|
||||
)
|
||||
from app.services import internal_ticket_service, l1_session_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/l1", tags=["l1"])
|
||||
|
||||
|
||||
def _to_response(session: L1WalkSession) -> WalkSessionResponse:
|
||||
return WalkSessionResponse(
|
||||
id=session.id,
|
||||
session_kind=session.session_kind,
|
||||
flow_id=session.flow_id,
|
||||
flow_proposal_id=session.flow_proposal_id,
|
||||
current_node_id=session.current_node_id,
|
||||
walked_path=session.walked_path or [],
|
||||
walk_notes=session.walk_notes or [],
|
||||
status=session.status,
|
||||
started_at=session.started_at,
|
||||
last_step_at=session.last_step_at,
|
||||
resolved_at=session.resolved_at,
|
||||
)
|
||||
|
||||
|
||||
async def _get_session_or_404(
|
||||
db: AsyncSession, session_id: UUID, user: User
|
||||
) -> L1WalkSession:
|
||||
"""Fetch a session by id, scoped to the caller's account.
|
||||
|
||||
Phase 1 policy (per spec §7.9): sessions are account-scoped, not
|
||||
user-scoped. Any L1 or coverage engineer in the same account can
|
||||
step/note/resolve/escalate any session — supports team coverage
|
||||
(e.g., L1 hands off mid-shift; coverage engineer takes over a call).
|
||||
For a stricter "creator-only" policy, add
|
||||
``created_by_user_id == user.id`` here.
|
||||
"""
|
||||
session = await db.get(L1WalkSession, session_id)
|
||||
if session is None or session.account_id != user.account_id:
|
||||
raise HTTPException(
|
||||
status_code=http_status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found",
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/intake", response_model=IntakeResponse)
|
||||
async def intake(
|
||||
payload: IntakeRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
"""L1 intake: creates an internal ticket and starts a walk session.
|
||||
|
||||
Phase 1: internal-ticket only (PSA support follows in Phase 2 escalation polish).
|
||||
If `flow_id` is provided, starts a flow session; otherwise an adhoc session.
|
||||
"""
|
||||
ticket = await internal_ticket_service.create_ticket(
|
||||
db,
|
||||
account_id=user.account_id,
|
||||
created_by_user_id=user.id,
|
||||
problem_statement=payload.problem_statement,
|
||||
customer_name=payload.customer_name,
|
||||
customer_contact=payload.customer_contact,
|
||||
)
|
||||
if payload.flow_id is not None:
|
||||
session = await l1_session_service.start_flow_session(
|
||||
db,
|
||||
account_id=user.account_id,
|
||||
user=user,
|
||||
flow_id=payload.flow_id,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
else:
|
||||
session = await l1_session_service.start_adhoc_session(
|
||||
db,
|
||||
account_id=user.account_id,
|
||||
user=user,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
await db.commit()
|
||||
return IntakeResponse(
|
||||
session_id=session.id,
|
||||
session_kind=session.session_kind,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/queue", response_model=list[QueueRow])
|
||||
async def queue(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
status_filter: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
):
|
||||
"""Phase 1 queue: internal tickets only. PSA-fed rows in Phase 2."""
|
||||
tickets = await internal_ticket_service.list_tickets_for_account(
|
||||
db,
|
||||
account_id=user.account_id,
|
||||
status=status_filter,
|
||||
limit=limit,
|
||||
)
|
||||
return [
|
||||
QueueRow(
|
||||
ticket_id=str(t.id),
|
||||
ticket_kind="internal",
|
||||
problem_statement=t.problem_statement,
|
||||
customer_name=t.customer_name,
|
||||
status=t.status,
|
||||
created_at=t.created_at,
|
||||
)
|
||||
for t in tickets
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sessions/active", response_model=list[WalkSessionResponse])
|
||||
async def list_active_sessions(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
"""The caller's currently-active sessions (for the dashboard 'Resume in progress' widget)."""
|
||||
stmt = (
|
||||
select(L1WalkSession)
|
||||
.where(L1WalkSession.created_by_user_id == user.id)
|
||||
.where(L1WalkSession.status == "active")
|
||||
.order_by(L1WalkSession.last_step_at.desc())
|
||||
.limit(20)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return [_to_response(s) for s in result.scalars()]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=WalkSessionResponse)
|
||||
async def get_session(
|
||||
session_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
session = await _get_session_or_404(db, session_id, user)
|
||||
return _to_response(session)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/step", response_model=WalkSessionResponse)
|
||||
async def post_step(
|
||||
session_id: UUID,
|
||||
payload: StepRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
await _get_session_or_404(db, session_id, user)
|
||||
try:
|
||||
updated = await l1_session_service.record_step(
|
||||
db,
|
||||
session_id=session_id,
|
||||
node_id=payload.node_id,
|
||||
question=payload.question,
|
||||
answer=payload.answer,
|
||||
note=payload.note,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail=str(exc))
|
||||
await db.commit()
|
||||
return _to_response(updated)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/notes", response_model=WalkSessionResponse)
|
||||
async def post_notes(
|
||||
session_id: UUID,
|
||||
payload: NotesRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
await _get_session_or_404(db, session_id, user)
|
||||
try:
|
||||
updated = await l1_session_service.update_notes(
|
||||
db,
|
||||
session_id=session_id,
|
||||
notes=payload.notes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail=str(exc))
|
||||
await db.commit()
|
||||
return _to_response(updated)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/resolve", response_model=WalkSessionResponse)
|
||||
async def post_resolve(
|
||||
session_id: UUID,
|
||||
payload: ResolveRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
await _get_session_or_404(db, session_id, user)
|
||||
try:
|
||||
updated = await l1_session_service.resolve(
|
||||
db,
|
||||
session_id=session_id,
|
||||
helpful=payload.helpful,
|
||||
resolution_notes=payload.resolution_notes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail=str(exc))
|
||||
await db.commit()
|
||||
return _to_response(updated)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/escalate", response_model=WalkSessionResponse)
|
||||
async def post_escalate(
|
||||
session_id: UUID,
|
||||
payload: EscalateRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
await _get_session_or_404(db, session_id, user)
|
||||
try:
|
||||
updated = await l1_session_service.escalate(
|
||||
db,
|
||||
session_id=session_id,
|
||||
reason=payload.reason or "",
|
||||
reason_category=payload.reason_category,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail=str(exc))
|
||||
await db.commit()
|
||||
return _to_response(updated)
|
||||
|
||||
|
||||
@router.post("/escalate-without-walk", response_model=WalkSessionResponse)
|
||||
async def post_escalate_without_walk(
|
||||
payload: EscalateWithoutWalkRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
user: Annotated[User, Depends(require_l1_or_coverage)],
|
||||
):
|
||||
ticket = await internal_ticket_service.create_ticket(
|
||||
db,
|
||||
account_id=user.account_id,
|
||||
created_by_user_id=user.id,
|
||||
problem_statement=payload.problem_statement,
|
||||
customer_name=payload.customer_name,
|
||||
customer_contact=payload.customer_contact,
|
||||
)
|
||||
session = await l1_session_service.escalate_without_walk(
|
||||
db,
|
||||
account_id=user.account_id,
|
||||
user=user,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
reason_category=payload.reason_category,
|
||||
reason=payload.reason,
|
||||
)
|
||||
await db.commit()
|
||||
return _to_response(session)
|
||||
@@ -3,7 +3,7 @@ import string
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -118,6 +118,29 @@ async def _sign_in_or_register(
|
||||
|
||||
if is_new_user:
|
||||
if invite_record is not None:
|
||||
# Seat enforcement: re-check at OAuth accept time (race-condition guard).
|
||||
if invite_record.role in ("engineer", "l1_tech"):
|
||||
from app.core.subscriptions import get_account_subscription
|
||||
from app.services.seat_enforcement import check_seat_available
|
||||
sub = await get_account_subscription(invite_record.account_id, db)
|
||||
if sub is not None:
|
||||
acct_result = await db.execute(
|
||||
select(Account).where(Account.id == invite_record.account_id)
|
||||
)
|
||||
acct = acct_result.scalar_one()
|
||||
seat_result = await check_seat_available(acct, sub, invite_record.role, db)
|
||||
if not seat_result.available:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"code": "seat_limit_exceeded",
|
||||
"role": seat_result.role,
|
||||
"current": seat_result.current,
|
||||
"limit": seat_result.limit,
|
||||
"upgrade_url": "/account/billing",
|
||||
},
|
||||
)
|
||||
|
||||
# Join the invited account directly — no personal account, no
|
||||
# trial creation.
|
||||
user = User(
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.api.deps import (
|
||||
from app.api.endpoints import (
|
||||
admin,
|
||||
admin_audit,
|
||||
l1,
|
||||
admin_categories,
|
||||
admin_dashboard,
|
||||
admin_feature_flags,
|
||||
@@ -185,3 +186,6 @@ api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_branches.router, dependencies=_pro_deps)
|
||||
api_router.include_router(session_handoffs.router, dependencies=_pro_deps)
|
||||
api_router.include_router(device_types.router, dependencies=_tenant_deps)
|
||||
# L1 is a separate seat-counted SKU; subscription gating is enforced by
|
||||
# seat_enforcement (engineer + l1_seat_limit), not require_active_subscription.
|
||||
api_router.include_router(l1.router, dependencies=_tenant_deps)
|
||||
|
||||
@@ -13,13 +13,20 @@ async def log_audit(
|
||||
resource_id: Optional[UUID] = None,
|
||||
details: Optional[dict] = None,
|
||||
account_id: Optional[UUID] = None,
|
||||
acting_as: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
|
||||
"""Record an audit log entry. Does not commit — caller's commit picks it up.
|
||||
|
||||
acting_as: optional tag from the session (e.g. 'l1_coverage' for engineers
|
||||
on the L1 surface, None for native l1_tech users).
|
||||
"""
|
||||
if account_id is None:
|
||||
# Derive from the acting user's account as a fallback (one extra query).
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
result = await db.execute(select(User.account_id).where(User.id == user_id))
|
||||
result = await db.execute(
|
||||
select(User.account_id).where(User.id == user_id)
|
||||
)
|
||||
account_id = result.scalar_one()
|
||||
|
||||
entry = AuditLog(
|
||||
@@ -29,5 +36,6 @@ async def log_audit(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
details=details,
|
||||
acting_as=acting_as,
|
||||
)
|
||||
db.add(entry)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
Centralized permission checks for ResolutionFlow.
|
||||
|
||||
Role hierarchy: super_admin > owner > engineer > viewer
|
||||
Role hierarchy: super_admin > owner > engineer > l1_tech > viewer
|
||||
|
||||
- super_admin: is_super_admin=True, full system access
|
||||
- owner: account_role='owner', manage account resources
|
||||
- engineer: account_role='engineer' (default), CRUD own trees/steps
|
||||
- l1_tech: account_role='l1_tech', use /l1/* surface only — walk flows, resolve/escalate
|
||||
- viewer: account_role='viewer', read-only (can browse, run sessions, rate steps)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@@ -23,7 +24,8 @@ ROLE_HIERARCHY = {
|
||||
"super_admin": 4,
|
||||
"owner": 3,
|
||||
"engineer": 2,
|
||||
"viewer": 1,
|
||||
"l1_tech": 1,
|
||||
"viewer": 0,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -221,6 +221,18 @@ async def lifespan(app: FastAPI):
|
||||
max_instances=1,
|
||||
)
|
||||
|
||||
# L1 walk session cleanup: flip stale active sessions to 'abandoned' (hourly)
|
||||
from app.services.l1_session_cleanup import run_cleanup_job as l1_cleanup_run
|
||||
scheduler.add_job(
|
||||
l1_cleanup_run,
|
||||
trigger="interval",
|
||||
hours=1,
|
||||
id="l1_session_cleanup",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
args=[async_session_maker],
|
||||
)
|
||||
|
||||
# Auto-seed trees in background on PR environments
|
||||
seed_task = None
|
||||
if settings.SEED_ON_DEPLOY:
|
||||
|
||||
@@ -66,6 +66,8 @@ from .oauth_identity import OAuthIdentity # noqa: F401
|
||||
from .plan_billing import PlanBilling # noqa: F401
|
||||
from .sales_lead import SalesLead # noqa: F401
|
||||
from .stripe_event import StripeEvent # noqa: F401
|
||||
from .internal_ticket import InternalTicket # noqa: F401
|
||||
from .l1_walk_session import L1WalkSession # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -146,4 +148,6 @@ __all__ = [
|
||||
"PlanBilling",
|
||||
"SalesLead",
|
||||
"StripeEvent",
|
||||
"InternalTicket",
|
||||
"L1WalkSession",
|
||||
]
|
||||
|
||||
@@ -57,6 +57,11 @@ class Account(Base):
|
||||
team_size_bucket: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
|
||||
primary_psa: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
|
||||
|
||||
# L1 workspace seats
|
||||
l1_seats_purchased: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, server_default="0"
|
||||
)
|
||||
|
||||
# SSO / SAML groundwork (Task 11)
|
||||
sso_enabled: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
|
||||
sso_provider: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) # "saml" | "oidc"
|
||||
|
||||
@@ -35,6 +35,7 @@ class AuditLog(Base):
|
||||
)
|
||||
details: Mapped[Optional[dict]] = mapped_column(JSONB, nullable=True)
|
||||
ip_address: Mapped[Optional[str]] = mapped_column(String(45), nullable=True)
|
||||
acting_as: Mapped[Optional[str]] = mapped_column(String(30), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
|
||||
@@ -7,7 +7,7 @@ import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey, Integer, Float, CheckConstraint
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey, Integer, Float, Boolean, CheckConstraint, text as sa_text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
@@ -48,6 +48,14 @@ class FlowProposal(Base):
|
||||
"status IN ('pending', 'approved', 'modified', 'rejected', 'dismissed', 'auto_reinforced')",
|
||||
name="ck_flow_proposals_status",
|
||||
),
|
||||
CheckConstraint(
|
||||
"source IN ('ai_realtime_l1', 'kb_accelerator', 'manual_draft', 'ai_promoted')",
|
||||
name="ck_flow_proposals_source",
|
||||
),
|
||||
CheckConstraint(
|
||||
"linked_ticket_kind IS NULL OR linked_ticket_kind IN ('psa', 'internal')",
|
||||
name="ck_flow_proposals_linked_ticket_kind",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
@@ -135,6 +143,16 @@ class FlowProposal(Base):
|
||||
comment="The flow that was created/updated when this proposal was approved",
|
||||
)
|
||||
|
||||
# ── L1 workspace ──
|
||||
source: Mapped[str] = mapped_column(
|
||||
String(30), nullable=False, server_default=sa_text("'manual_draft'"),
|
||||
)
|
||||
linked_ticket_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
linked_ticket_kind: Mapped[Optional[str]] = mapped_column(String(10), nullable=True)
|
||||
validated_by_outcome: Mapped[bool] = mapped_column(
|
||||
Boolean(), nullable=False, server_default=sa_text('false'),
|
||||
)
|
||||
|
||||
# ── Timestamps ──
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
|
||||
117
backend/app/models/internal_ticket.py
Normal file
117
backend/app/models/internal_ticket.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Internal ticket model.
|
||||
|
||||
Fallback ticket table for L1 intake when the account has no PSA integration.
|
||||
Tracks the customer-facing problem, resolution lifecycle, and optional links
|
||||
to a flow, flow proposal, AI session, and assigned engineer.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey, CheckConstraint
|
||||
from sqlalchemy import text as sa_text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.models.tree import Tree
|
||||
from app.models.flow_proposal import FlowProposal
|
||||
from app.models.ai_session import AISession
|
||||
|
||||
|
||||
class InternalTicket(Base):
|
||||
"""A fallback support ticket for accounts without a PSA integration.
|
||||
|
||||
status lifecycle:
|
||||
- open: Submitted, not yet picked up.
|
||||
- walking: L1 technician is actively walking the flow.
|
||||
- resolved: Issue resolved; resolution_notes captured.
|
||||
- escalated: Could not resolve; requires higher-tier intervention.
|
||||
"""
|
||||
__tablename__ = "internal_tickets"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"status IN ('open', 'walking', 'resolved', 'escalated')",
|
||||
name="ck_internal_tickets_status",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by_user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# ── Customer info ──
|
||||
customer_name: Mapped[Optional[str]] = mapped_column(String(120), nullable=True)
|
||||
customer_contact: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)
|
||||
problem_statement: Mapped[str] = mapped_column(Text(), nullable=False)
|
||||
|
||||
# ── Lifecycle ──
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(30), nullable=False, server_default=sa_text("'open'"), index=True,
|
||||
)
|
||||
|
||||
# ── Optional links ──
|
||||
flow_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
flow_proposal_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("flow_proposals.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
ai_session_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("ai_sessions.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
assigned_user_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# ── Resolution ──
|
||||
resolution_notes: Mapped[Optional[str]] = mapped_column(Text(), nullable=True)
|
||||
psa_promoted_ticket_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(64), nullable=True,
|
||||
comment="External PSA ticket ID when this ticket is promoted to a PSA system",
|
||||
)
|
||||
|
||||
# ── Timestamps ──
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
resolved_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True,
|
||||
)
|
||||
|
||||
# ── Relationships ──
|
||||
account: Mapped["Account"] = relationship("Account")
|
||||
created_by: Mapped["User"] = relationship("User", foreign_keys=[created_by_user_id])
|
||||
assigned_user: Mapped[Optional["User"]] = relationship("User", foreign_keys=[assigned_user_id])
|
||||
flow: Mapped[Optional["Tree"]] = relationship("Tree")
|
||||
flow_proposal: Mapped[Optional["FlowProposal"]] = relationship("FlowProposal")
|
||||
ai_session: Mapped[Optional["AISession"]] = relationship("AISession")
|
||||
141
backend/app/models/l1_walk_session.py
Normal file
141
backend/app/models/l1_walk_session.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""L1 walk session model.
|
||||
|
||||
Per-session state for an L1 technician walking a ticket through a flow,
|
||||
flow proposal, or ad-hoc investigation. Tracks the walked path, notes
|
||||
captured at each step, and terminal resolution / escalation metadata.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import String, Text, DateTime, Boolean, ForeignKey, CheckConstraint
|
||||
from sqlalchemy import text as sa_text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.models.tree import Tree
|
||||
from app.models.flow_proposal import FlowProposal
|
||||
|
||||
|
||||
class L1WalkSession(Base):
|
||||
"""A single L1 technician session walking a ticket.
|
||||
|
||||
session_kind values:
|
||||
- flow: Walking a published flow (flow_id required, flow_proposal_id null).
|
||||
- proposal: Walking a draft flow proposal (flow_proposal_id required, flow_id null).
|
||||
- adhoc: Free-form investigation (both flow_id and flow_proposal_id null).
|
||||
|
||||
status lifecycle:
|
||||
- active: Session is in progress.
|
||||
- resolved: Issue resolved; resolution_notes captured.
|
||||
- escalated: Could not resolve; escalation_reason captured.
|
||||
- abandoned: Session exited without resolution or explicit escalation.
|
||||
"""
|
||||
|
||||
__tablename__ = "l1_walk_sessions"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"ticket_kind IN ('psa', 'internal')",
|
||||
name="ck_l1_walk_sessions_ticket_kind",
|
||||
),
|
||||
CheckConstraint(
|
||||
"session_kind IN ('flow', 'proposal', 'adhoc')",
|
||||
name="ck_l1_walk_sessions_session_kind",
|
||||
),
|
||||
CheckConstraint(
|
||||
"status IN ('active', 'resolved', 'escalated', 'abandoned')",
|
||||
name="ck_l1_walk_sessions_status",
|
||||
),
|
||||
CheckConstraint(
|
||||
"(session_kind = 'flow' AND flow_id IS NOT NULL AND flow_proposal_id IS NULL) "
|
||||
"OR (session_kind = 'proposal' AND flow_proposal_id IS NOT NULL AND flow_id IS NULL) "
|
||||
"OR (session_kind = 'adhoc' AND flow_id IS NULL AND flow_proposal_id IS NULL)",
|
||||
name="ck_l1_walk_sessions_target_consistency",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by_user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# ── Actor context ──
|
||||
acting_as: Mapped[Optional[str]] = mapped_column(String(30), nullable=True)
|
||||
|
||||
# ── Ticket reference ──
|
||||
ticket_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
ticket_kind: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
|
||||
# ── Session kind + target ──
|
||||
session_kind: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
flow_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
flow_proposal_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("flow_proposals.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# ── Navigation state ──
|
||||
current_node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
walked_path: Mapped[list[dict[str, Any]]] = mapped_column(
|
||||
JSONB(), nullable=False, server_default=sa_text("'[]'::jsonb"),
|
||||
)
|
||||
walk_notes: Mapped[list[dict[str, Any]]] = mapped_column(
|
||||
JSONB(), nullable=False, server_default=sa_text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
# ── Lifecycle ──
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, server_default=sa_text("'active'"), index=True,
|
||||
)
|
||||
|
||||
# ── Resolution ──
|
||||
resolution_notes: Mapped[Optional[str]] = mapped_column(Text(), nullable=True)
|
||||
helpful: Mapped[Optional[bool]] = mapped_column(Boolean(), nullable=True)
|
||||
|
||||
# ── Escalation ──
|
||||
escalation_reason: Mapped[Optional[str]] = mapped_column(Text(), nullable=True)
|
||||
escalation_reason_category: Mapped[Optional[str]] = mapped_column(
|
||||
String(30), nullable=True,
|
||||
)
|
||||
|
||||
# ── Timestamps ──
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
last_step_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
index=True,
|
||||
)
|
||||
resolved_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True,
|
||||
)
|
||||
|
||||
# ── Relationships ──
|
||||
account: Mapped["Account"] = relationship("Account")
|
||||
created_by: Mapped["User"] = relationship("User", foreign_keys=[created_by_user_id])
|
||||
flow: Mapped[Optional["Tree"]] = relationship("Tree")
|
||||
flow_proposal: Mapped[Optional["FlowProposal"]] = relationship("FlowProposal")
|
||||
@@ -21,6 +21,7 @@ class Subscription(Base):
|
||||
billing_interval: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="active")
|
||||
seat_limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
l1_seat_limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
current_period_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
current_period_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean, CheckConstraint, Text, Integer
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean, CheckConstraint, Text, Integer, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
@@ -22,7 +22,7 @@ class User(Base):
|
||||
name='ck_users_role_enum'
|
||||
),
|
||||
CheckConstraint(
|
||||
"account_role IN ('owner', 'admin', 'engineer', 'viewer')",
|
||||
"account_role IN ('owner', 'admin', 'engineer', 'l1_tech', 'viewer')",
|
||||
name='ck_users_account_role_enum'
|
||||
),
|
||||
)
|
||||
@@ -50,6 +50,9 @@ class User(Base):
|
||||
index=True
|
||||
)
|
||||
account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
|
||||
can_cover_l1: Mapped[bool] = mapped_column(
|
||||
Boolean(), nullable=False, server_default=text('false')
|
||||
)
|
||||
|
||||
# Legacy team columns (kept for PR A coexistence)
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
|
||||
@@ -27,7 +27,7 @@ class TransferOwnershipRequest(BaseModel):
|
||||
|
||||
class AccountInviteCreate(BaseModel):
|
||||
email: str = Field(..., max_length=255)
|
||||
role: str = Field("engineer", pattern="^(engineer|viewer)$")
|
||||
role: str = Field("engineer", pattern="^(engineer|viewer|l1_tech)$")
|
||||
expires_in_days: Optional[int] = Field(None, ge=1, le=30)
|
||||
|
||||
|
||||
|
||||
72
backend/app/schemas/l1.py
Normal file
72
backend/app/schemas/l1.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Pydantic schemas for the /l1/* endpoint surface."""
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class IntakeRequest(BaseModel):
|
||||
problem_statement: str = Field(..., min_length=1)
|
||||
customer_name: Optional[str] = None
|
||||
customer_contact: Optional[str] = None
|
||||
flow_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class IntakeResponse(BaseModel):
|
||||
session_id: UUID
|
||||
session_kind: Literal["flow", "proposal", "adhoc"]
|
||||
ticket_id: str
|
||||
ticket_kind: Literal["psa", "internal"]
|
||||
|
||||
|
||||
class StepRequest(BaseModel):
|
||||
node_id: str
|
||||
question: str
|
||||
answer: str
|
||||
note: Optional[str] = None
|
||||
|
||||
|
||||
class NotesRequest(BaseModel):
|
||||
notes: list[dict[str, Any]]
|
||||
|
||||
|
||||
class ResolveRequest(BaseModel):
|
||||
helpful: bool
|
||||
resolution_notes: str
|
||||
|
||||
|
||||
class EscalateRequest(BaseModel):
|
||||
reason: Optional[str] = None
|
||||
reason_category: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class EscalateWithoutWalkRequest(BaseModel):
|
||||
problem_statement: str = Field(..., min_length=1)
|
||||
customer_name: Optional[str] = None
|
||||
customer_contact: Optional[str] = None
|
||||
reason_category: str = Field(..., min_length=1)
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class WalkSessionResponse(BaseModel):
|
||||
id: UUID
|
||||
session_kind: str
|
||||
flow_id: Optional[UUID]
|
||||
flow_proposal_id: Optional[UUID]
|
||||
current_node_id: Optional[str]
|
||||
walked_path: list[dict[str, Any]]
|
||||
walk_notes: list[dict[str, Any]]
|
||||
status: str
|
||||
started_at: datetime
|
||||
last_step_at: datetime
|
||||
resolved_at: Optional[datetime]
|
||||
|
||||
|
||||
class QueueRow(BaseModel):
|
||||
ticket_id: str
|
||||
ticket_kind: Literal["psa", "internal"]
|
||||
problem_statement: Optional[str] = None
|
||||
customer_name: Optional[str] = None
|
||||
status: str
|
||||
created_at: Optional[datetime] = None
|
||||
18
backend/app/schemas/seat_enforcement.py
Normal file
18
backend/app/schemas/seat_enforcement.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
Role = Literal['engineer', 'l1_tech']
|
||||
|
||||
|
||||
class SeatCheckResult(BaseModel):
|
||||
available: bool
|
||||
current: int
|
||||
limit: Optional[int] # None = unlimited
|
||||
role: Role
|
||||
|
||||
|
||||
class SeatUsage(BaseModel):
|
||||
engineer: SeatCheckResult
|
||||
l1_tech: SeatCheckResult
|
||||
@@ -60,6 +60,7 @@ class UserResponse(UserBase):
|
||||
email_verified_at: Optional[datetime] = None
|
||||
onboarding_step_completed: Optional[int] = None
|
||||
onboarding_dismissed: bool = False
|
||||
can_cover_l1: bool = False
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -72,4 +73,8 @@ class RoleUpdate(BaseModel):
|
||||
class AccountRoleUpdate(BaseModel):
|
||||
# Ownership changes must go through the explicit transfer-ownership flow so
|
||||
# account.owner_id stays consistent with user.account_role.
|
||||
account_role: str = Field(..., pattern="^(admin|engineer|viewer)$")
|
||||
account_role: str = Field(..., pattern="^(admin|engineer|viewer|l1_tech)$")
|
||||
|
||||
|
||||
class CoverageUpdate(BaseModel):
|
||||
can_cover_l1: bool
|
||||
|
||||
90
backend/app/services/internal_ticket_service.py
Normal file
90
backend/app/services/internal_ticket_service.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""CRUD + status transitions for internal_tickets (the no-PSA fallback ticket model)."""
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.internal_ticket import InternalTicket
|
||||
|
||||
|
||||
async def create_ticket(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: UUID,
|
||||
created_by_user_id: UUID,
|
||||
problem_statement: str,
|
||||
customer_name: Optional[str] = None,
|
||||
customer_contact: Optional[str] = None,
|
||||
) -> InternalTicket:
|
||||
"""Create a new internal ticket in 'open' status."""
|
||||
ticket = InternalTicket(
|
||||
account_id=account_id,
|
||||
created_by_user_id=created_by_user_id,
|
||||
problem_statement=problem_statement,
|
||||
customer_name=customer_name,
|
||||
customer_contact=customer_contact,
|
||||
)
|
||||
db.add(ticket)
|
||||
await db.flush()
|
||||
return ticket
|
||||
|
||||
|
||||
async def update_status(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
ticket_id: UUID,
|
||||
status: str,
|
||||
resolution_notes: Optional[str] = None,
|
||||
assigned_user_id: Optional[UUID] = None,
|
||||
) -> InternalTicket:
|
||||
"""Transition a ticket to a new status. Sets resolved_at when status='resolved'."""
|
||||
ticket = await db.get(InternalTicket, ticket_id)
|
||||
if not ticket:
|
||||
raise ValueError(f"InternalTicket {ticket_id} not found")
|
||||
ticket.status = status
|
||||
if status == 'resolved':
|
||||
ticket.resolved_at = datetime.now(timezone.utc)
|
||||
if resolution_notes is not None:
|
||||
ticket.resolution_notes = resolution_notes
|
||||
if assigned_user_id is not None:
|
||||
ticket.assigned_user_id = assigned_user_id
|
||||
await db.flush()
|
||||
return ticket
|
||||
|
||||
|
||||
async def get_ticket(db: AsyncSession, *, ticket_id: UUID) -> Optional[InternalTicket]:
|
||||
"""Fetch a ticket by ID. Returns None if not found."""
|
||||
return await db.get(InternalTicket, ticket_id)
|
||||
|
||||
|
||||
async def list_tickets_for_account(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: UUID,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> list[InternalTicket]:
|
||||
"""List tickets for an account, optionally filtered by status, newest first."""
|
||||
stmt = select(InternalTicket).where(InternalTicket.account_id == account_id)
|
||||
if status:
|
||||
stmt = stmt.where(InternalTicket.status == status)
|
||||
stmt = stmt.order_by(InternalTicket.created_at.desc()).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars())
|
||||
|
||||
|
||||
async def promote_to_psa(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
ticket_id: UUID,
|
||||
psa_ticket_id: str,
|
||||
) -> InternalTicket:
|
||||
"""Mark an internal ticket as promoted to PSA."""
|
||||
ticket = await db.get(InternalTicket, ticket_id)
|
||||
if not ticket:
|
||||
raise ValueError(f"InternalTicket {ticket_id} not found")
|
||||
ticket.psa_promoted_ticket_id = psa_ticket_id
|
||||
await db.flush()
|
||||
return ticket
|
||||
49
backend/app/services/l1_session_cleanup.py
Normal file
49
backend/app/services/l1_session_cleanup.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Hourly cleanup job: flip stale active L1WalkSessions to 'abandoned'.
|
||||
|
||||
Sessions with status='active' and last_step_at older than 24h are considered
|
||||
abandoned (L1 closed the browser, customer hung up, etc.). Flipping them
|
||||
removes them from the "Resume in progress" widget while preserving the row
|
||||
for audit/reporting.
|
||||
|
||||
Run via APScheduler interval job, max_instances=1 (Lesson 1).
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.l1_walk_session import L1WalkSession
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def flip_stale_sessions(db: AsyncSession) -> int:
|
||||
"""Flip active sessions to 'abandoned' if last_step_at < now - 24h.
|
||||
|
||||
Returns the number of sessions flipped.
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=24)
|
||||
stmt = (
|
||||
update(L1WalkSession)
|
||||
.where(L1WalkSession.status == "active")
|
||||
.where(L1WalkSession.last_step_at < cutoff)
|
||||
.values(status="abandoned")
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
return result.rowcount or 0
|
||||
|
||||
|
||||
async def run_cleanup_job(session_factory) -> None:
|
||||
"""APScheduler entry point. Uses the admin session factory (no RLS context)."""
|
||||
async with session_factory() as db:
|
||||
try:
|
||||
count = await flip_stale_sessions(db)
|
||||
if count > 0:
|
||||
logger.info(
|
||||
"l1_session_cleanup: flipped %d sessions to abandoned", count
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("l1_session_cleanup: error during run")
|
||||
321
backend/app/services/l1_session_service.py
Normal file
321
backend/app/services/l1_session_service.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""L1 session lifecycle: start (flow/proposal/adhoc), step, notes, resolve, escalate.
|
||||
|
||||
start_* functions live in T12; step/notes are T13; resolve/escalate are T14.
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.audit import log_audit
|
||||
from app.models.flow_proposal import FlowProposal
|
||||
from app.models.l1_walk_session import L1WalkSession
|
||||
from app.models.user import User
|
||||
from app.services import internal_ticket_service
|
||||
|
||||
|
||||
def _resolve_acting_as(user: User) -> Optional[str]:
|
||||
"""An engineer (whether covering or not) gets tagged for audit when using L1 surface.
|
||||
|
||||
Returns 'l1_coverage' for engineers (only engineers WITH the coverage flag should
|
||||
reach this code path — the require_l1_or_coverage dep gates that). For native
|
||||
l1_tech users, returns None (no special tag — they ARE l1).
|
||||
"""
|
||||
if user.account_role == "engineer":
|
||||
return "l1_coverage"
|
||||
return None
|
||||
|
||||
|
||||
async def start_flow_session(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: UUID,
|
||||
user: User,
|
||||
flow_id: UUID,
|
||||
ticket_id: str,
|
||||
ticket_kind: str, # 'psa' | 'internal'
|
||||
) -> L1WalkSession:
|
||||
"""Start a session walking an authored flow."""
|
||||
session = L1WalkSession(
|
||||
account_id=account_id,
|
||||
created_by_user_id=user.id,
|
||||
acting_as=_resolve_acting_as(user),
|
||||
ticket_id=ticket_id,
|
||||
ticket_kind=ticket_kind,
|
||||
session_kind="flow",
|
||||
flow_id=flow_id,
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def start_proposal_session(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: UUID,
|
||||
user: User,
|
||||
flow_proposal_id: UUID,
|
||||
ticket_id: str,
|
||||
ticket_kind: str,
|
||||
) -> L1WalkSession:
|
||||
"""Start a session walking an AI-built FlowProposal."""
|
||||
session = L1WalkSession(
|
||||
account_id=account_id,
|
||||
created_by_user_id=user.id,
|
||||
acting_as=_resolve_acting_as(user),
|
||||
ticket_id=ticket_id,
|
||||
ticket_kind=ticket_kind,
|
||||
session_kind="proposal",
|
||||
flow_proposal_id=flow_proposal_id,
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def start_adhoc_session(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: UUID,
|
||||
user: User,
|
||||
ticket_id: str,
|
||||
ticket_kind: str,
|
||||
) -> L1WalkSession:
|
||||
"""Start an ad-hoc session with no tree (free-form note-taking only)."""
|
||||
session = L1WalkSession(
|
||||
account_id=account_id,
|
||||
created_by_user_id=user.id,
|
||||
acting_as=_resolve_acting_as(user),
|
||||
ticket_id=ticket_id,
|
||||
ticket_kind=ticket_kind,
|
||||
session_kind="adhoc",
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def record_step(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session_id: UUID,
|
||||
node_id: str,
|
||||
question: str,
|
||||
answer: str,
|
||||
note: Optional[str] = None,
|
||||
) -> L1WalkSession:
|
||||
"""Record an answered step in a tree walk. Appends to walked_path JSONB and
|
||||
advances current_node_id. Raises ValueError on adhoc sessions or inactive
|
||||
sessions. Updates last_step_at."""
|
||||
session = await db.get(L1WalkSession, session_id)
|
||||
if not session:
|
||||
raise ValueError(f"L1WalkSession {session_id} not found")
|
||||
if session.session_kind == "adhoc":
|
||||
raise ValueError("Cannot record step on adhoc session — use update_notes")
|
||||
if session.status != "active":
|
||||
raise ValueError(f"Session {session_id} is not active (status={session.status})")
|
||||
entry = {
|
||||
"node_id": node_id,
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"l1_note": note,
|
||||
}
|
||||
# JSONB requires assigning a new list — in-place mutation isn't tracked
|
||||
session.walked_path = [*session.walked_path, entry]
|
||||
session.current_node_id = node_id
|
||||
session.last_step_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def update_notes(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session_id: UUID,
|
||||
notes: list[dict],
|
||||
) -> L1WalkSession:
|
||||
"""Replace walk_notes on an active session. Used by adhoc walks for
|
||||
debounced autosave. Raises ValueError if missing or inactive. Caps notes
|
||||
payload at 256KB to prevent unbounded growth."""
|
||||
session = await db.get(L1WalkSession, session_id)
|
||||
if not session:
|
||||
raise ValueError(f"L1WalkSession {session_id} not found")
|
||||
if session.status != "active":
|
||||
raise ValueError(f"Session {session_id} is not active (status={session.status})")
|
||||
encoded_size = len(json.dumps(notes).encode("utf-8"))
|
||||
if encoded_size > 256 * 1024:
|
||||
raise ValueError("walk_notes exceeds 256KB cap — consider escalating")
|
||||
session.walk_notes = notes
|
||||
session.last_step_at = datetime.now(timezone.utc)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def resolve(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session_id: UUID,
|
||||
helpful: bool,
|
||||
resolution_notes: str,
|
||||
) -> L1WalkSession:
|
||||
"""Close a session as resolved.
|
||||
|
||||
- Sets status='resolved', helpful, resolution_notes, resolved_at.
|
||||
- On helpful=True AND session_kind='proposal': flips
|
||||
flow_proposal.validated_by_outcome=True (one-bit aggregate signal).
|
||||
- Closes the linked internal ticket (PSA close stubbed for Phase 2).
|
||||
- Raises ValueError on missing or non-active session.
|
||||
"""
|
||||
session = await db.get(L1WalkSession, session_id)
|
||||
if not session:
|
||||
raise ValueError(f"L1WalkSession {session_id} not found")
|
||||
if session.status != "active":
|
||||
raise ValueError(f"Session not active (status={session.status})")
|
||||
now = datetime.now(timezone.utc)
|
||||
session.status = "resolved"
|
||||
session.helpful = helpful
|
||||
session.resolution_notes = resolution_notes
|
||||
session.resolved_at = now
|
||||
session.last_step_at = now
|
||||
|
||||
if helpful and session.session_kind == "proposal" and session.flow_proposal_id:
|
||||
proposal = await db.get(FlowProposal, session.flow_proposal_id)
|
||||
if proposal:
|
||||
proposal.validated_by_outcome = True
|
||||
|
||||
if session.ticket_kind == "internal":
|
||||
await internal_ticket_service.update_status(
|
||||
db,
|
||||
ticket_id=UUID(session.ticket_id),
|
||||
status="resolved",
|
||||
resolution_notes=resolution_notes,
|
||||
)
|
||||
# PSA close deferred to Phase 2 — no-op for now
|
||||
|
||||
await log_audit(
|
||||
db,
|
||||
user_id=session.created_by_user_id,
|
||||
action="l1.session.resolve",
|
||||
resource_type="l1_walk_session",
|
||||
resource_id=session.id,
|
||||
details={
|
||||
"session_kind": session.session_kind,
|
||||
"helpful": helpful,
|
||||
"ticket_id": session.ticket_id,
|
||||
"ticket_kind": session.ticket_kind,
|
||||
},
|
||||
account_id=session.account_id,
|
||||
acting_as=session.acting_as,
|
||||
)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def escalate(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session_id: UUID,
|
||||
reason: str,
|
||||
reason_category: str,
|
||||
) -> L1WalkSession:
|
||||
"""Escalate an active session to engineering.
|
||||
|
||||
- Sets status='escalated', escalation_reason, escalation_reason_category, resolved_at.
|
||||
- Marks the linked internal ticket as escalated (PSA reassign deferred to Phase 2).
|
||||
- Raises ValueError on missing or non-active session.
|
||||
"""
|
||||
session = await db.get(L1WalkSession, session_id)
|
||||
if not session:
|
||||
raise ValueError(f"L1WalkSession {session_id} not found")
|
||||
if session.status != "active":
|
||||
raise ValueError(f"Session not active (status={session.status})")
|
||||
now = datetime.now(timezone.utc)
|
||||
session.status = "escalated"
|
||||
session.escalation_reason = reason
|
||||
session.escalation_reason_category = reason_category
|
||||
session.resolved_at = now
|
||||
session.last_step_at = now
|
||||
|
||||
if session.ticket_kind == "internal":
|
||||
await internal_ticket_service.update_status(
|
||||
db,
|
||||
ticket_id=UUID(session.ticket_id),
|
||||
status="escalated",
|
||||
)
|
||||
# PSA reassign deferred to Phase 2
|
||||
|
||||
await log_audit(
|
||||
db,
|
||||
user_id=session.created_by_user_id,
|
||||
action="l1.session.escalate",
|
||||
resource_type="l1_walk_session",
|
||||
resource_id=session.id,
|
||||
details={
|
||||
"session_kind": session.session_kind,
|
||||
"escalation_reason_category": reason_category,
|
||||
"ticket_id": session.ticket_id,
|
||||
"ticket_kind": session.ticket_kind,
|
||||
},
|
||||
account_id=session.account_id,
|
||||
acting_as=session.acting_as,
|
||||
)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
async def escalate_without_walk(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: UUID,
|
||||
user: User,
|
||||
ticket_id: str,
|
||||
ticket_kind: str,
|
||||
reason_category: str,
|
||||
reason: Optional[str] = None,
|
||||
) -> L1WalkSession:
|
||||
"""Create an immediately-escalated session with no walked_path.
|
||||
|
||||
Used from the BuildAbortedNoKB screen (no KB content available to walk a
|
||||
tree). Captures the call as an audit record + escalates the ticket without
|
||||
requiring a walker session in between.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
session = L1WalkSession(
|
||||
account_id=account_id,
|
||||
created_by_user_id=user.id,
|
||||
acting_as=_resolve_acting_as(user),
|
||||
ticket_id=ticket_id,
|
||||
ticket_kind=ticket_kind,
|
||||
session_kind="adhoc",
|
||||
status="escalated",
|
||||
escalation_reason=reason,
|
||||
escalation_reason_category=reason_category,
|
||||
resolved_at=now,
|
||||
last_step_at=now,
|
||||
)
|
||||
db.add(session)
|
||||
if ticket_kind == "internal":
|
||||
await internal_ticket_service.update_status(
|
||||
db,
|
||||
ticket_id=UUID(ticket_id),
|
||||
status="escalated",
|
||||
)
|
||||
await db.flush() # flush first so session.id is populated
|
||||
await log_audit(
|
||||
db,
|
||||
user_id=session.created_by_user_id,
|
||||
action="l1.session.escalate_no_walk",
|
||||
resource_type="l1_walk_session",
|
||||
resource_id=session.id,
|
||||
details={
|
||||
"escalation_reason_category": reason_category,
|
||||
"ticket_id": ticket_id,
|
||||
"ticket_kind": ticket_kind,
|
||||
},
|
||||
account_id=session.account_id,
|
||||
acting_as=session.acting_as,
|
||||
)
|
||||
return session
|
||||
63
backend/app/services/seat_enforcement.py
Normal file
63
backend/app/services/seat_enforcement.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.seat_enforcement import SeatCheckResult
|
||||
|
||||
|
||||
Role = Literal['engineer', 'l1_tech']
|
||||
|
||||
|
||||
def _limit_for_role(subscription: Subscription, role: Role) -> int | None:
|
||||
if role == 'engineer':
|
||||
return subscription.seat_limit
|
||||
if role == 'l1_tech':
|
||||
return subscription.l1_seat_limit
|
||||
raise ValueError(f"Unknown role: {role}")
|
||||
|
||||
|
||||
async def check_seat_available(
|
||||
account: Account,
|
||||
subscription: Subscription,
|
||||
role: Role,
|
||||
db: AsyncSession,
|
||||
) -> SeatCheckResult:
|
||||
"""
|
||||
Count active users with the given role in the account, compare against
|
||||
the role-specific seat limit on the subscription. Returns availability.
|
||||
|
||||
None limit = unlimited (returns available=True).
|
||||
"""
|
||||
limit = _limit_for_role(subscription, role)
|
||||
|
||||
stmt = (
|
||||
select(func.count(User.id))
|
||||
.where(User.account_id == account.id)
|
||||
.where(User.account_role == role)
|
||||
.where(User.is_active.is_(True))
|
||||
)
|
||||
current = (await db.execute(stmt)).scalar_one()
|
||||
|
||||
if limit is None:
|
||||
return SeatCheckResult(available=True, current=current, limit=None, role=role)
|
||||
return SeatCheckResult(
|
||||
available=current < limit,
|
||||
current=current,
|
||||
limit=limit,
|
||||
role=role,
|
||||
)
|
||||
|
||||
|
||||
async def get_seat_usage(
|
||||
account: Account,
|
||||
subscription: Subscription,
|
||||
db: AsyncSession,
|
||||
) -> tuple[SeatCheckResult, SeatCheckResult]:
|
||||
"""Return (engineer, l1_tech) seat-usage tuple for the seat-counter widget."""
|
||||
eng = await check_seat_available(account, subscription, 'engineer', db)
|
||||
l1 = await check_seat_available(account, subscription, 'l1_tech', db)
|
||||
return eng, l1
|
||||
@@ -2,11 +2,13 @@
|
||||
"""
|
||||
Create test user accounts for local development.
|
||||
|
||||
Creates 4 accounts:
|
||||
1. Super Admin – platform-wide admin (manages everything)
|
||||
2. Pro Solo User – single user on a "pro" plan
|
||||
3. Team Admin – admin of a team account ("team" plan)
|
||||
4. Team Engineer – regular engineer on the same team account
|
||||
Creates 6 accounts:
|
||||
1. Super Admin – platform-wide admin (manages everything)
|
||||
2. Pro Solo User – single user on a "pro" plan
|
||||
3. Team Admin – admin of a team account ("team" plan)
|
||||
4. Team Engineer – regular engineer on the same team account
|
||||
5. L1 Tech – l1_tech role on the Acme MSP team (E2E: L1 happy path)
|
||||
6. Coverage Engineer – engineer with can_cover_l1=True (E2E: coverage banner)
|
||||
|
||||
Usage:
|
||||
cd backend
|
||||
@@ -71,6 +73,29 @@ USERS = [
|
||||
"account_name": "Acme MSP", # same shared account
|
||||
"account_role": "engineer",
|
||||
"plan": None, # uses the team_admin's account & subscription
|
||||
"can_cover_l1": False,
|
||||
},
|
||||
{
|
||||
"key": "l1_tech",
|
||||
"name": "Lee L1Tech",
|
||||
"email": "l1@resolutionflow.example.com",
|
||||
"is_super_admin": False,
|
||||
"is_team_admin": False,
|
||||
"account_name": "Acme MSP", # same shared account as team_admin
|
||||
"account_role": "l1_tech",
|
||||
"plan": None, # uses the team_admin's account & subscription
|
||||
"can_cover_l1": False,
|
||||
},
|
||||
{
|
||||
"key": "coverage_engineer",
|
||||
"name": "Casey Coverage",
|
||||
"email": "engineer-coverage@resolutionflow.example.com",
|
||||
"is_super_admin": False,
|
||||
"is_team_admin": False,
|
||||
"account_name": "Acme MSP", # same shared account as team_admin
|
||||
"account_role": "engineer",
|
||||
"plan": None, # uses the team_admin's account & subscription
|
||||
"can_cover_l1": True,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -114,7 +139,9 @@ async def main() -> None:
|
||||
continue
|
||||
|
||||
# ---- Create or reuse Account ----
|
||||
if cfg["key"] == "team_engineer":
|
||||
# Users that share the Acme MSP account (no own account to create)
|
||||
_acme_members = {"team_engineer", "l1_tech", "coverage_engineer"}
|
||||
if cfg["key"] in _acme_members:
|
||||
if team_account_id is None:
|
||||
result = await conn.execute(
|
||||
text("SELECT id FROM accounts WHERE name = :name"),
|
||||
@@ -145,13 +172,14 @@ async def main() -> None:
|
||||
# 7-day verification grace immediately. Without this, fixtures hit
|
||||
# require_verified_email_after_grace once their created_at ages past
|
||||
# 7 days and get walled out of protected routes.
|
||||
can_cover_l1 = cfg.get("can_cover_l1", False)
|
||||
await conn.execute(
|
||||
text("""
|
||||
INSERT INTO users (id, email, password_hash, name, role, is_super_admin,
|
||||
is_team_admin, is_active, account_id, account_role,
|
||||
created_at, email_verified_at)
|
||||
can_cover_l1, created_at, email_verified_at)
|
||||
VALUES (:id, :email, :pw, :name, 'engineer', :is_sa, :is_ta, true,
|
||||
:account_id, :account_role, :now, :now)
|
||||
:account_id, :account_role, :can_cover_l1, :now, :now)
|
||||
"""),
|
||||
{
|
||||
"id": user_id,
|
||||
@@ -162,12 +190,13 @@ async def main() -> None:
|
||||
"is_ta": cfg["is_team_admin"],
|
||||
"account_id": account_id,
|
||||
"account_role": cfg["account_role"],
|
||||
"can_cover_l1": can_cover_l1,
|
||||
"now": now,
|
||||
},
|
||||
)
|
||||
|
||||
# Set account owner (skip for team_engineer — they don't own the account)
|
||||
if cfg["key"] != "team_engineer":
|
||||
# Set account owner (skip for shared-account members — they don't own the account)
|
||||
if cfg["key"] not in _acme_members:
|
||||
await conn.execute(
|
||||
text("UPDATE accounts SET owner_id = :uid WHERE id = :aid"),
|
||||
{"uid": user_id, "aid": account_id},
|
||||
@@ -183,7 +212,8 @@ async def main() -> None:
|
||||
{"id": uuid.uuid4(), "aid": account_id, "plan": cfg["plan"], "now": now},
|
||||
)
|
||||
|
||||
print(f" [OK] {cfg['email']:40s} account_role={cfg['account_role']:<10s} plan={cfg['plan'] or '(shared)'}")
|
||||
cover_flag = " [can_cover_l1]" if can_cover_l1 else ""
|
||||
print(f" [OK] {cfg['email']:40s} account_role={cfg['account_role']:<12s} plan={cfg['plan'] or '(shared)'}{cover_flag}")
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@@ -194,10 +224,12 @@ async def main() -> None:
|
||||
print("=" * 60)
|
||||
print()
|
||||
print(" Accounts:")
|
||||
print(f" Super Admin : admin@resolutionflow.example.com")
|
||||
print(f" Pro Solo : pro@resolutionflow.example.com")
|
||||
print(f" Team Admin : teamadmin@resolutionflow.example.com")
|
||||
print(f" Team Engineer: engineer@resolutionflow.example.com")
|
||||
print(f" Super Admin : admin@resolutionflow.example.com")
|
||||
print(f" Pro Solo : pro@resolutionflow.example.com")
|
||||
print(f" Team Admin : teamadmin@resolutionflow.example.com")
|
||||
print(f" Team Engineer : engineer@resolutionflow.example.com")
|
||||
print(f" L1 Tech : l1@resolutionflow.example.com")
|
||||
print(f" Coverage Engineer : engineer-coverage@resolutionflow.example.com")
|
||||
print()
|
||||
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ assert "test" in _test_db_name, (
|
||||
)
|
||||
|
||||
_RUN_RLS_TESTS = os.environ.get("RUN_RLS_TESTS") == "1"
|
||||
_RLS_ISOLATION_FILE = "test_rls_isolation.py"
|
||||
_RLS_TEST_FILES = {"test_rls_isolation.py", "test_l1_rls.py"}
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
@@ -117,7 +117,9 @@ def pytest_collection_modifyitems(config, items):
|
||||
deselected = []
|
||||
for item in items:
|
||||
item_path = getattr(item, "path", None) or getattr(item, "fspath", None)
|
||||
if item_path and str(item_path).endswith(_RLS_ISOLATION_FILE):
|
||||
if item_path and any(
|
||||
str(item_path).endswith(f) for f in _RLS_TEST_FILES
|
||||
):
|
||||
deselected.append(item)
|
||||
else:
|
||||
selected.append(item)
|
||||
|
||||
99
backend/tests/test_deps_l1.py
Normal file
99
backend/tests/test_deps_l1.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Unit tests for L1-related dependency guards.
|
||||
|
||||
Uses MagicMock user objects — no database required.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.deps import require_l1, require_l1_or_coverage, require_l1_or_above
|
||||
|
||||
|
||||
def _make_user(account_role="engineer", is_super_admin=False, can_cover_l1=False):
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.account_role = account_role
|
||||
user.is_super_admin = is_super_admin
|
||||
user.can_cover_l1 = can_cover_l1
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_l1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_require_l1_passes_for_l1_tech():
|
||||
user = _make_user(account_role="l1_tech")
|
||||
result = await require_l1(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
async def test_require_l1_passes_for_super_admin():
|
||||
user = _make_user(account_role="owner", is_super_admin=True)
|
||||
result = await require_l1(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
async def test_require_l1_blocks_engineer():
|
||||
user = _make_user(account_role="engineer")
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await require_l1(current_user=user)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_l1_or_coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_require_l1_or_coverage_passes_l1_tech():
|
||||
user = _make_user(account_role="l1_tech")
|
||||
result = await require_l1_or_coverage(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
async def test_require_l1_or_coverage_passes_engineer_with_flag():
|
||||
user = _make_user(account_role="engineer", can_cover_l1=True)
|
||||
result = await require_l1_or_coverage(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
async def test_require_l1_or_coverage_blocks_engineer_without_flag():
|
||||
user = _make_user(account_role="engineer", can_cover_l1=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await require_l1_or_coverage(current_user=user)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
async def test_require_l1_or_coverage_passes_owner_always():
|
||||
user = _make_user(account_role="owner")
|
||||
result = await require_l1_or_coverage(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_l1_or_above
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_require_l1_or_above_passes_engineer():
|
||||
user = _make_user(account_role="engineer")
|
||||
result = await require_l1_or_above(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
async def test_require_l1_or_above_passes_l1_tech():
|
||||
user = _make_user(account_role="l1_tech")
|
||||
result = await require_l1_or_above(current_user=user)
|
||||
assert result is user
|
||||
|
||||
|
||||
async def test_require_l1_or_above_blocks_viewer():
|
||||
user = _make_user(account_role="viewer")
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await require_l1_or_above(current_user=user)
|
||||
assert exc.value.status_code == 403
|
||||
182
backend/tests/test_internal_ticket_service.py
Normal file
182
backend/tests/test_internal_ticket_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Unit + integration tests for internal_ticket_service."""
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.services.internal_ticket_service import (
|
||||
create_ticket, update_status, get_ticket,
|
||||
list_tickets_for_account, promote_to_psa,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _make_account(db: AsyncSession) -> Account:
|
||||
s = str(uuid.uuid4())[:8]
|
||||
account = Account(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Test Account {s}",
|
||||
display_code=s[:8],
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
return account
|
||||
|
||||
|
||||
async def _make_user(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: uuid.UUID,
|
||||
role: str = "l1_tech",
|
||||
) -> User:
|
||||
s = str(uuid.uuid4())[:8]
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"user-{s}@example.com",
|
||||
name=f"User {s}",
|
||||
account_id=account_id,
|
||||
account_role=role,
|
||||
role="engineer",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_ticket_sets_status_open(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await create_ticket(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
created_by_user_id=l1.id,
|
||||
problem_statement="Outlook can't connect",
|
||||
customer_name="Alice",
|
||||
)
|
||||
assert ticket.status == 'open'
|
||||
assert ticket.account_id == account.id
|
||||
assert ticket.customer_name == "Alice"
|
||||
assert ticket.created_by_user_id == l1.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_to_resolved_sets_resolved_at(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await create_ticket(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
created_by_user_id=l1.id,
|
||||
problem_statement="Test",
|
||||
)
|
||||
assert ticket.resolved_at is None
|
||||
updated = await update_status(
|
||||
test_db,
|
||||
ticket_id=ticket.id,
|
||||
status='resolved',
|
||||
resolution_notes="Fixed via reboot",
|
||||
)
|
||||
assert updated.status == 'resolved'
|
||||
assert updated.resolved_at is not None
|
||||
assert updated.resolution_notes == "Fixed via reboot"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_to_escalated_does_not_set_resolved_at(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await create_ticket(
|
||||
test_db, account_id=account.id, created_by_user_id=l1.id,
|
||||
problem_statement="x",
|
||||
)
|
||||
updated = await update_status(test_db, ticket_id=ticket.id, status='escalated')
|
||||
assert updated.status == 'escalated'
|
||||
assert updated.resolved_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_assigns_user(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
engineer = await _make_user(test_db, account_id=account.id, role="engineer")
|
||||
ticket = await create_ticket(
|
||||
test_db, account_id=account.id, created_by_user_id=l1.id,
|
||||
problem_statement="x",
|
||||
)
|
||||
updated = await update_status(
|
||||
test_db, ticket_id=ticket.id, status='escalated',
|
||||
assigned_user_id=engineer.id,
|
||||
)
|
||||
assert updated.assigned_user_id == engineer.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ticket_returns_none_for_missing_id(test_db: AsyncSession):
|
||||
result = await get_ticket(test_db, ticket_id=uuid.uuid4())
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tickets_filters_by_account(test_db: AsyncSession):
|
||||
account_a = await _make_account(test_db)
|
||||
account_b = await _make_account(test_db)
|
||||
l1_a = await _make_user(test_db, account_id=account_a.id)
|
||||
l1_b = await _make_user(test_db, account_id=account_b.id)
|
||||
ticket_a = await create_ticket(
|
||||
test_db, account_id=account_a.id, created_by_user_id=l1_a.id,
|
||||
problem_statement="A",
|
||||
)
|
||||
ticket_b = await create_ticket(
|
||||
test_db, account_id=account_b.id, created_by_user_id=l1_b.id,
|
||||
problem_statement="B",
|
||||
)
|
||||
rows = await list_tickets_for_account(test_db, account_id=account_a.id)
|
||||
ids = [r.id for r in rows]
|
||||
assert ticket_a.id in ids
|
||||
assert ticket_b.id not in ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tickets_filters_by_status(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
open_t = await create_ticket(
|
||||
test_db, account_id=account.id, created_by_user_id=l1.id,
|
||||
problem_statement="open",
|
||||
)
|
||||
resolved_t = await create_ticket(
|
||||
test_db, account_id=account.id, created_by_user_id=l1.id,
|
||||
problem_statement="r",
|
||||
)
|
||||
await update_status(test_db, ticket_id=resolved_t.id, status='resolved')
|
||||
open_rows = await list_tickets_for_account(test_db, account_id=account.id, status='open')
|
||||
assert open_t.id in [r.id for r in open_rows]
|
||||
assert resolved_t.id not in [r.id for r in open_rows]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_promote_to_psa_sets_external_id(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await create_ticket(
|
||||
test_db, account_id=account.id, created_by_user_id=l1.id,
|
||||
problem_statement="x",
|
||||
)
|
||||
updated = await promote_to_psa(test_db, ticket_id=ticket.id, psa_ticket_id="CW-12345")
|
||||
assert updated.psa_promoted_ticket_id == "CW-12345"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_raises_for_missing_ticket(test_db: AsyncSession):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await update_status(test_db, ticket_id=uuid.uuid4(), status='resolved')
|
||||
564
backend/tests/test_invite_seat_enforcement.py
Normal file
564
backend/tests/test_invite_seat_enforcement.py
Normal file
@@ -0,0 +1,564 @@
|
||||
"""Integration tests for seat enforcement at invite create, accept-invite, and
|
||||
role-change endpoints.
|
||||
|
||||
All tests use the `client` + `test_db` fixtures from conftest, which spin up
|
||||
a fresh schema per test and wire the ASGI app to the test DB.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test-local helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _register(client: AsyncClient, *, email: str, password: str = "TestPassword123!", name: str = "Test User") -> dict:
|
||||
resp = await client.post("/api/v1/auth/register", json={"email": email, "password": password, "name": name})
|
||||
assert resp.status_code in (200, 201), resp.text
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _login(client: AsyncClient, *, email: str, password: str = "TestPassword123!") -> dict:
|
||||
resp = await client.post("/api/v1/auth/login/json", json={"email": email, "password": password})
|
||||
assert resp.status_code == 200, resp.text
|
||||
return {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
|
||||
async def _set_sub(db: AsyncSession, account_id: uuid.UUID, *, seat_limit: int | None, l1_seat_limit: int | None = None) -> None:
|
||||
"""Replace the account's subscription with specified limits."""
|
||||
await db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||
db.add(Subscription(
|
||||
account_id=account_id,
|
||||
plan="pro",
|
||||
status="active",
|
||||
seat_limit=seat_limit,
|
||||
l1_seat_limit=l1_seat_limit,
|
||||
))
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _add_member(db: AsyncSession, account_id: uuid.UUID, *, role: str, suffix: str | None = None) -> User:
|
||||
"""Directly insert an active user with the given role into the account."""
|
||||
s = suffix or str(uuid.uuid4())[:8]
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"member-{s}@example.com",
|
||||
name=f"Member {s}",
|
||||
account_id=account_id,
|
||||
account_role=role,
|
||||
role="engineer",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Invite create — single invite endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_engineer_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /me/invites → 402 when engineer seat limit is exhausted."""
|
||||
owner = await _register(client, email="owner1@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner1@example.com")
|
||||
|
||||
# seat_limit=1, already 1 engineer → full
|
||||
await _set_sub(test_db, account_id, seat_limit=1)
|
||||
# The owner registers as engineer, but is actually 'owner' role — add a separate engineer
|
||||
await _add_member(test_db, account_id, role="engineer")
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "new-eng@example.com", "role": "engineer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 402, resp.text
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "seat_limit_exceeded"
|
||||
assert body["detail"]["role"] == "engineer"
|
||||
assert body["detail"]["current"] == 1
|
||||
assert body["detail"]["limit"] == 1
|
||||
assert "upgrade_url" in body["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_l1_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /me/invites → 402 when l1_tech seat limit is exhausted."""
|
||||
owner = await _register(client, email="owner2@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner2@example.com")
|
||||
|
||||
await _set_sub(test_db, account_id, seat_limit=10, l1_seat_limit=1)
|
||||
await _add_member(test_db, account_id, role="l1_tech")
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "new-l1@example.com", "role": "l1_tech"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 402, resp.text
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "seat_limit_exceeded"
|
||||
assert body["detail"]["role"] == "l1_tech"
|
||||
assert body["detail"]["current"] == 1
|
||||
assert body["detail"]["limit"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_succeeds_when_seats_available(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /me/invites → 201 when engineer seats have room."""
|
||||
owner = await _register(client, email="owner3@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner3@example.com")
|
||||
|
||||
# seat_limit=5, 0 engineers → plenty of room
|
||||
await _set_sub(test_db, account_id, seat_limit=5)
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "new-eng2@example.com", "role": "engineer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_viewer_bypasses_seat_check(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /me/invites → 201 for viewer role even when engineer seats full."""
|
||||
owner = await _register(client, email="owner4@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner4@example.com")
|
||||
|
||||
# engineer seats exhausted — should not affect viewer invites
|
||||
await _set_sub(test_db, account_id, seat_limit=1)
|
||||
await _add_member(test_db, account_id, role="engineer")
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "viewer@example.com", "role": "viewer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_unlimited_seat_limit_always_succeeds(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /me/invites → 201 when seat_limit is None (unlimited)."""
|
||||
owner = await _register(client, email="owner5@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner5@example.com")
|
||||
|
||||
# seat_limit=None = unlimited
|
||||
await _set_sub(test_db, account_id, seat_limit=None)
|
||||
# add many engineers
|
||||
for i in range(5):
|
||||
await _add_member(test_db, account_id, role="engineer", suffix=f"bulk{i}")
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "new-unlimited@example.com", "role": "engineer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_invite_per_row_402_preserves_structured_detail(client: AsyncClient, test_db: AsyncSession):
|
||||
"""Bulk invite returns 200 overall; rows that hit the seat limit appear in the
|
||||
`failed` list with structured detail (not a stringified repr)."""
|
||||
owner = await _register(client, email="owner_bulk@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner_bulk@example.com")
|
||||
|
||||
# seat_limit=1, already 1 engineer → next engineer invite fails
|
||||
await _set_sub(test_db, account_id, seat_limit=1)
|
||||
await _add_member(test_db, account_id, role="engineer")
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites/bulk",
|
||||
json={"invites": [
|
||||
{"email": "viewer-ok@example.com", "role": "viewer"},
|
||||
{"email": "eng-blocked@example.com", "role": "engineer"},
|
||||
]},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code in (200, 201), resp.text
|
||||
body = resp.json()
|
||||
assert len(body["created"]) == 1
|
||||
assert body["created"][0]["email"] == "viewer-ok@example.com"
|
||||
assert len(body["failed"]) == 1
|
||||
failed_row = body["failed"][0]
|
||||
assert failed_row["email"] == "eng-blocked@example.com"
|
||||
# Structured detail preserved (dict, not repr string)
|
||||
assert isinstance(failed_row["error"], dict)
|
||||
assert failed_row["error"]["code"] == "seat_limit_exceeded"
|
||||
assert failed_row["error"]["role"] == "engineer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_grandfathered_account_blocks_new_invites(client: AsyncClient, test_db: AsyncSession):
|
||||
"""Grandfathering: existing over-seated account keeps existing users but
|
||||
new engineer invites are still blocked (current > limit → blocked)."""
|
||||
owner = await _register(client, email="owner6@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner6@example.com")
|
||||
|
||||
# current=3 engineers > seat_limit=2 (over-seated / grandfathered)
|
||||
await _set_sub(test_db, account_id, seat_limit=2)
|
||||
for i in range(3):
|
||||
await _add_member(test_db, account_id, role="engineer", suffix=f"gf{i}")
|
||||
|
||||
# New invite must be blocked
|
||||
resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "one-more@example.com", "role": "engineer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 402, resp.text
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "seat_limit_exceeded"
|
||||
# current (3) > limit (2) — forward enforcement fires, existing users unaffected
|
||||
assert body["detail"]["current"] == 3
|
||||
assert body["detail"]["limit"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accept-invite race condition — auth.py register path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invite_blocked_when_seats_full_at_accept_time(client: AsyncClient, test_db: AsyncSession):
|
||||
"""Race-condition guard: invite created when seats available, but by
|
||||
accept time someone else consumed the last seat → 402."""
|
||||
# Step 1: create an owner and send an invite
|
||||
owner = await _register(client, email="owner7@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
owner_headers = await _login(client, email="owner7@example.com")
|
||||
|
||||
await _set_sub(test_db, account_id, seat_limit=2)
|
||||
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "race@example.com", "role": "engineer"},
|
||||
headers=owner_headers,
|
||||
)
|
||||
assert invite_resp.status_code == 201, invite_resp.text
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Step 2: fill the seats after the invite was created (race condition)
|
||||
await _add_member(test_db, account_id, role="engineer", suffix="race1")
|
||||
await _add_member(test_db, account_id, role="engineer", suffix="race2")
|
||||
|
||||
# Step 3: invitee tries to register — should get 402
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": "race@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"name": "Race User",
|
||||
"account_invite_code": invite_code,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 402, resp.text
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "seat_limit_exceeded"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invite_succeeds_when_seats_available(client: AsyncClient, test_db: AsyncSession):
|
||||
"""Normal accept-invite path works when seats have room."""
|
||||
owner = await _register(client, email="owner8@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
owner_headers = await _login(client, email="owner8@example.com")
|
||||
|
||||
await _set_sub(test_db, account_id, seat_limit=5)
|
||||
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "acceptme@example.com", "role": "engineer"},
|
||||
headers=owner_headers,
|
||||
)
|
||||
assert invite_resp.status_code == 201, invite_resp.text
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": "acceptme@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"name": "Accept User",
|
||||
"account_invite_code": invite_code,
|
||||
},
|
||||
)
|
||||
assert resp.status_code in (200, 201), resp.text
|
||||
assert resp.json()["account_id"] == str(account_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Role-change endpoint — PATCH /me/members/{user_id}/role
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_change_viewer_to_engineer_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession):
|
||||
"""PATCH /me/members/{id}/role → 402 when promoting viewer → engineer and seats full."""
|
||||
owner = await _register(client, email="owner9@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner9@example.com")
|
||||
|
||||
await _set_sub(test_db, account_id, seat_limit=1)
|
||||
# Fill the engineer seat
|
||||
await _add_member(test_db, account_id, role="engineer")
|
||||
# Add a viewer to promote
|
||||
viewer = await _add_member(test_db, account_id, role="viewer")
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{viewer.id}/role",
|
||||
json={"account_role": "engineer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 402, resp.text
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "seat_limit_exceeded"
|
||||
assert body["detail"]["role"] == "engineer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_change_viewer_to_l1_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession):
|
||||
"""PATCH /me/members/{id}/role → 402 when promoting viewer → l1_tech and l1 seats full."""
|
||||
owner = await _register(client, email="owner10@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner10@example.com")
|
||||
|
||||
await _set_sub(test_db, account_id, seat_limit=10, l1_seat_limit=1)
|
||||
await _add_member(test_db, account_id, role="l1_tech")
|
||||
viewer = await _add_member(test_db, account_id, role="viewer")
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{viewer.id}/role",
|
||||
json={"account_role": "l1_tech"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 402, resp.text
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "seat_limit_exceeded"
|
||||
assert body["detail"]["role"] == "l1_tech"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_change_promotion_succeeds_when_seats_available(client: AsyncClient, test_db: AsyncSession):
|
||||
"""PATCH /me/members/{id}/role → 200 when seats are available."""
|
||||
owner = await _register(client, email="owner11@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner11@example.com")
|
||||
|
||||
await _set_sub(test_db, account_id, seat_limit=5)
|
||||
viewer = await _add_member(test_db, account_id, role="viewer")
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{viewer.id}/role",
|
||||
json={"account_role": "engineer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert resp.json()["account_role"] == "engineer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_change_demotion_bypasses_seat_check(client: AsyncClient, test_db: AsyncSession):
|
||||
"""PATCH /me/members/{id}/role → 200 for demotions even when seats full."""
|
||||
owner = await _register(client, email="owner12@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner12@example.com")
|
||||
|
||||
# Seats full — but demotion should still succeed
|
||||
await _set_sub(test_db, account_id, seat_limit=1)
|
||||
engineer = await _add_member(test_db, account_id, role="engineer")
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{engineer.id}/role",
|
||||
json={"account_role": "viewer"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert resp.json()["account_role"] == "viewer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /me/seats — seat counter widget endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_seats_returns_both_role_counts(client: AsyncClient, test_db: AsyncSession):
|
||||
"""GET /accounts/me/seats returns engineer + l1_tech seat usage."""
|
||||
owner = await _register(client, email="owner_seats@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner_seats@example.com")
|
||||
await _set_sub(test_db, account_id, seat_limit=5, l1_seat_limit=3)
|
||||
# Add 2 engineers and 1 l1_tech as members
|
||||
for i in range(2):
|
||||
await _add_member(test_db, account_id, role="engineer", suffix=f"e{i}")
|
||||
await _add_member(test_db, account_id, role="l1_tech", suffix="l1")
|
||||
|
||||
resp = await client.get("/api/v1/accounts/me/seats", headers=headers)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["engineer"]["role"] == "engineer"
|
||||
assert body["engineer"]["current"] == 2
|
||||
assert body["engineer"]["limit"] == 5
|
||||
assert body["engineer"]["available"] is True
|
||||
assert body["l1_tech"]["role"] == "l1_tech"
|
||||
assert body["l1_tech"]["current"] == 1
|
||||
assert body["l1_tech"]["limit"] == 3
|
||||
assert body["l1_tech"]["available"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_seats_blocked_for_viewer(client: AsyncClient, test_db: AsyncSession):
|
||||
"""GET /accounts/me/seats → 403 for viewer role (engineer+ required)."""
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
# Register an owner for the account
|
||||
owner = await _register(client, email="owner_seats2@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
await _set_sub(test_db, account_id, seat_limit=5, l1_seat_limit=3)
|
||||
|
||||
# Create a viewer user with a known password directly in the DB
|
||||
viewer_password = "ViewerPass123!"
|
||||
viewer = User(
|
||||
id=uuid.uuid4(),
|
||||
email="viewer_seats@example.com",
|
||||
name="Viewer Seats",
|
||||
account_id=account_id,
|
||||
account_role="viewer",
|
||||
role="engineer", # system role field (default)
|
||||
is_active=True,
|
||||
password_hash=get_password_hash(viewer_password),
|
||||
)
|
||||
test_db.add(viewer)
|
||||
await test_db.commit()
|
||||
|
||||
# Log in as the viewer
|
||||
viewer_headers = await _login(client, email="viewer_seats@example.com", password=viewer_password)
|
||||
|
||||
resp = await client.get("/api/v1/accounts/me/seats", headers=viewer_headers)
|
||||
assert resp.status_code == 403, resp.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /me/members/{user_id}/coverage — engineer L1-coverage flag
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coverage_owner_can_toggle_engineer(client: AsyncClient, test_db: AsyncSession):
|
||||
"""Owner can set can_cover_l1=True on an engineer; response reflects new value."""
|
||||
owner = await _register(client, email="owner_cov1@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner_cov1@example.com")
|
||||
|
||||
engineer = await _add_member(test_db, account_id, role="engineer", suffix="cov1")
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{engineer.id}/coverage",
|
||||
json={"can_cover_l1": True},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["can_cover_l1"] is True
|
||||
|
||||
# Toggle back to False
|
||||
resp2 = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{engineer.id}/coverage",
|
||||
json={"can_cover_l1": False},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp2.status_code == 200, resp2.text
|
||||
assert resp2.json()["can_cover_l1"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coverage_non_owner_is_forbidden(client: AsyncClient, test_db: AsyncSession):
|
||||
"""A non-owner engineer cannot toggle coverage on themselves or others."""
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
owner = await _register(client, email="owner_cov2@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
|
||||
# Create an engineer with a known password
|
||||
eng_password = "EngPass123!"
|
||||
engineer = User(
|
||||
id=uuid.uuid4(),
|
||||
email="eng_cov2@example.com",
|
||||
name="Eng Cov2",
|
||||
account_id=account_id,
|
||||
account_role="engineer",
|
||||
role="engineer",
|
||||
is_active=True,
|
||||
password_hash=get_password_hash(eng_password),
|
||||
)
|
||||
test_db.add(engineer)
|
||||
await test_db.commit()
|
||||
|
||||
eng_headers = await _login(client, email="eng_cov2@example.com", password=eng_password)
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{engineer.id}/coverage",
|
||||
json={"can_cover_l1": True},
|
||||
headers=eng_headers,
|
||||
)
|
||||
assert resp.status_code == 403, resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coverage_viewer_role_returns_422(client: AsyncClient, test_db: AsyncSession):
|
||||
"""PATCH coverage on a viewer → 422 (coverage flag only applies to engineers)."""
|
||||
owner = await _register(client, email="owner_cov3@example.com")
|
||||
account_id = uuid.UUID(owner["account_id"])
|
||||
headers = await _login(client, email="owner_cov3@example.com")
|
||||
|
||||
viewer = await _add_member(test_db, account_id, role="viewer", suffix="cov3")
|
||||
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{viewer.id}/coverage",
|
||||
json={"can_cover_l1": True},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 422, resp.text
|
||||
assert "engineer" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coverage_cross_account_returns_404(client: AsyncClient, test_db: AsyncSession):
|
||||
"""PATCH coverage on a user from a different account → 404 (tenancy isolation)."""
|
||||
# Account A
|
||||
owner_a = await _register(client, email="owner_cov_a@example.com")
|
||||
account_a_id = uuid.UUID(owner_a["account_id"])
|
||||
headers_a = await _login(client, email="owner_cov_a@example.com")
|
||||
|
||||
# Account B — a separate registration creates a new account
|
||||
owner_b = await _register(client, email="owner_cov_b@example.com")
|
||||
account_b_id = uuid.UUID(owner_b["account_id"])
|
||||
|
||||
# Add an engineer to account B
|
||||
engineer_b = await _add_member(test_db, account_b_id, role="engineer", suffix="covb")
|
||||
|
||||
# Owner of account A tries to patch account B's engineer — must 404
|
||||
resp = await client.patch(
|
||||
f"/api/v1/accounts/me/members/{engineer_b.id}/coverage",
|
||||
json={"can_cover_l1": True},
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 404, resp.text
|
||||
362
backend/tests/test_l1_endpoints.py
Normal file
362
backend/tests/test_l1_endpoints.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""Integration tests for the /l1/* endpoint surface (Task 15).
|
||||
|
||||
All tests use the `client` + `test_db` fixtures from conftest.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.models.l1_walk_session import L1WalkSession
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test-local helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _register(client: AsyncClient, *, email: str, password: str = "TestPassword123!", name: str = "Test User") -> dict:
|
||||
resp = await client.post("/api/v1/auth/register", json={"email": email, "password": password, "name": name})
|
||||
assert resp.status_code in (200, 201), resp.text
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _login(client: AsyncClient, *, email: str, password: str = "TestPassword123!") -> dict:
|
||||
resp = await client.post("/api/v1/auth/login/json", json={"email": email, "password": password})
|
||||
assert resp.status_code == 200, resp.text
|
||||
return {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
|
||||
async def _ensure_subscription(db: AsyncSession, account_id: uuid.UUID) -> None:
|
||||
"""Ensure account has an active Pro subscription."""
|
||||
await db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||
db.add(Subscription(account_id=account_id, plan="pro", status="active"))
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _make_l1_user(
|
||||
client: AsyncClient,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
account_id: uuid.UUID | None = None,
|
||||
) -> dict:
|
||||
"""Register a user, set role=l1_tech, ensure subscription.
|
||||
|
||||
If account_id is given, inserts a second user directly into that account.
|
||||
Otherwise registers a fresh user via the API (new account) and returns
|
||||
both user data and login headers.
|
||||
"""
|
||||
if account_id is None:
|
||||
user_data = await _register(client, email=email)
|
||||
uid = uuid.UUID(user_data["id"])
|
||||
acct_id = uuid.UUID(user_data["account_id"])
|
||||
# Promote to l1_tech
|
||||
from sqlalchemy import select as sa_select
|
||||
result = await db.execute(sa_select(User).where(User.id == uid))
|
||||
user = result.scalar_one()
|
||||
user.account_role = "l1_tech"
|
||||
await db.commit()
|
||||
await _ensure_subscription(db, acct_id)
|
||||
headers = await _login(client, email=email)
|
||||
return {"user_data": user_data, "headers": headers, "account_id": acct_id}
|
||||
else:
|
||||
# Insert directly into an existing account
|
||||
s = str(uuid.uuid4())[:8]
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=email,
|
||||
name=f"L1 Tech {s}",
|
||||
account_id=account_id,
|
||||
account_role="l1_tech",
|
||||
role="engineer",
|
||||
is_active=True,
|
||||
hashed_password="$2b$12$placeholder.placeholder.placeholder.placeholder.plac",
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
return {"user_data": {"id": str(user.id), "account_id": str(account_id)}, "headers": None}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Intake without flow_id → 200 + session_kind='adhoc'
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intake_adhoc(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/intake without flow_id creates adhoc session."""
|
||||
info = await _make_l1_user(client, test_db, email="l1intake@example.com")
|
||||
headers = info["headers"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Printer won't turn on", "customer_name": "Alice"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["session_kind"] == "adhoc"
|
||||
assert body["ticket_kind"] == "internal"
|
||||
assert "session_id" in body
|
||||
assert "ticket_id" in body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Intake without auth → 401
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intake_no_auth(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/intake without token → 401."""
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Test"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Intake as viewer → 403
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intake_viewer_forbidden(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/intake as viewer role → 403."""
|
||||
user_data = await _register(client, email="viewer_l1@example.com")
|
||||
uid = uuid.UUID(user_data["id"])
|
||||
acct_id = uuid.UUID(user_data["account_id"])
|
||||
|
||||
from sqlalchemy import select as sa_select
|
||||
result = await test_db.execute(sa_select(User).where(User.id == uid))
|
||||
user = result.scalar_one()
|
||||
user.account_role = "viewer"
|
||||
await test_db.commit()
|
||||
await _ensure_subscription(test_db, acct_id)
|
||||
|
||||
headers = await _login(client, email="viewer_l1@example.com")
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Test"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Step on adhoc session → 400 (cannot step an adhoc)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_on_adhoc_returns_400(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/sessions/{id}/step on adhoc session → 400."""
|
||||
info = await _make_l1_user(client, test_db, email="l1step@example.com")
|
||||
headers = info["headers"]
|
||||
|
||||
# Create adhoc session via intake
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Adhoc issue"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
session_id = resp.json()["session_id"]
|
||||
|
||||
resp = await client.post(
|
||||
f"/api/v1/l1/sessions/{session_id}/step",
|
||||
json={"node_id": "node1", "question": "Q?", "answer": "A"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "adhoc" in resp.json()["detail"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Notes on adhoc session → 200, walk_notes updated
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notes_on_adhoc_session(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/sessions/{id}/notes → 200 and walk_notes is updated."""
|
||||
info = await _make_l1_user(client, test_db, email="l1notes@example.com")
|
||||
headers = info["headers"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Notes test"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
session_id = resp.json()["session_id"]
|
||||
|
||||
notes_payload = [{"text": "Customer called about printer", "ts": "2026-05-28T10:00:00Z"}]
|
||||
resp = await client.post(
|
||||
f"/api/v1/l1/sessions/{session_id}/notes",
|
||||
json={"notes": notes_payload},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["walk_notes"] == notes_payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Resolve with helpful=True → 200; GET shows status=resolved
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_session(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/sessions/{id}/resolve → 200; subsequent GET shows resolved."""
|
||||
info = await _make_l1_user(client, test_db, email="l1resolve@example.com")
|
||||
headers = info["headers"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Resolve test"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
session_id = resp.json()["session_id"]
|
||||
|
||||
resp = await client.post(
|
||||
f"/api/v1/l1/sessions/{session_id}/resolve",
|
||||
json={"helpful": True, "resolution_notes": "Restarted the printer."},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert resp.json()["status"] == "resolved"
|
||||
|
||||
# GET should also show resolved
|
||||
resp = await client.get(f"/api/v1/l1/sessions/{session_id}", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "resolved"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Escalate session → 200; status=escalated
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_session(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/sessions/{id}/escalate → 200; status becomes escalated."""
|
||||
info = await _make_l1_user(client, test_db, email="l1escalate@example.com")
|
||||
headers = info["headers"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Escalation test"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
session_id = resp.json()["session_id"]
|
||||
|
||||
resp = await client.post(
|
||||
f"/api/v1/l1/sessions/{session_id}/escalate",
|
||||
json={"reason_category": "needs_l2", "reason": "Beyond L1 scope"},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "escalated"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. escalate-without-walk → 200 + session in escalated status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_without_walk(client: AsyncClient, test_db: AsyncSession):
|
||||
"""POST /l1/escalate-without-walk → 200 + session.status=escalated."""
|
||||
info = await _make_l1_user(client, test_db, email="l1eww@example.com")
|
||||
headers = info["headers"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/escalate-without-walk",
|
||||
json={
|
||||
"problem_statement": "No KB available",
|
||||
"reason_category": "no_kb",
|
||||
"reason": "No knowledge base content matched",
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "escalated"
|
||||
assert body["session_kind"] == "adhoc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. List active sessions returns L1's active sessions ordered by last_step_at DESC
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_active_sessions_ordered(client: AsyncClient, test_db: AsyncSession):
|
||||
"""GET /l1/sessions/active returns active sessions ordered by last_step_at DESC."""
|
||||
info = await _make_l1_user(client, test_db, email="l1active@example.com")
|
||||
headers = info["headers"]
|
||||
user_id = uuid.UUID(info["user_data"]["id"])
|
||||
account_id = info["account_id"]
|
||||
|
||||
# Create two sessions with controlled timestamps directly in DB
|
||||
now = datetime.now(timezone.utc)
|
||||
s1 = L1WalkSession(
|
||||
id=uuid.uuid4(),
|
||||
account_id=account_id,
|
||||
created_by_user_id=user_id,
|
||||
ticket_id=str(uuid.uuid4()),
|
||||
ticket_kind="internal",
|
||||
session_kind="adhoc",
|
||||
status="active",
|
||||
started_at=now - timedelta(minutes=10),
|
||||
last_step_at=now - timedelta(minutes=5),
|
||||
)
|
||||
s2 = L1WalkSession(
|
||||
id=uuid.uuid4(),
|
||||
account_id=account_id,
|
||||
created_by_user_id=user_id,
|
||||
ticket_id=str(uuid.uuid4()),
|
||||
ticket_kind="internal",
|
||||
session_kind="adhoc",
|
||||
status="active",
|
||||
started_at=now - timedelta(minutes=20),
|
||||
last_step_at=now - timedelta(minutes=1),
|
||||
)
|
||||
test_db.add_all([s1, s2])
|
||||
await test_db.commit()
|
||||
|
||||
resp = await client.get("/api/v1/l1/sessions/active", headers=headers)
|
||||
assert resp.status_code == 200, resp.text
|
||||
bodies = resp.json()
|
||||
ids = [b["id"] for b in bodies]
|
||||
# s2 has the more recent last_step_at → should come first
|
||||
assert ids.index(str(s2.id)) < ids.index(str(s1.id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. GET session from different account → 404 (tenancy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_cross_account_returns_404(client: AsyncClient, test_db: AsyncSession):
|
||||
"""GET /l1/sessions/{id} from a different account → 404."""
|
||||
# Account A: creates a session
|
||||
info_a = await _make_l1_user(client, test_db, email="l1tenanta@example.com")
|
||||
headers_a = info_a["headers"]
|
||||
|
||||
resp = await client.post(
|
||||
"/api/v1/l1/intake",
|
||||
json={"problem_statement": "Account A issue"},
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
session_id = resp.json()["session_id"]
|
||||
|
||||
# Account B: different user in a different account
|
||||
info_b = await _make_l1_user(client, test_db, email="l1tenantb@example.com")
|
||||
headers_b = info_b["headers"]
|
||||
|
||||
resp = await client.get(f"/api/v1/l1/sessions/{session_id}", headers=headers_b)
|
||||
assert resp.status_code == 404
|
||||
450
backend/tests/test_l1_rls.py
Normal file
450
backend/tests/test_l1_rls.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# backend/tests/test_l1_rls.py
|
||||
"""
|
||||
RLS regression tests for L1 Phase 1 tables.
|
||||
|
||||
Verifies that `internal_tickets` and `l1_walk_sessions` — both with
|
||||
FORCE ROW LEVEL SECURITY + `tenant_isolation` policy on `account_id` —
|
||||
block cross-tenant reads AND reject WITH CHECK violations on INSERT.
|
||||
|
||||
Uses synchronous psycopg2 (not asyncpg) to avoid the conftest
|
||||
teardown hook that closes the asyncio event loop after every test,
|
||||
which is incompatible with module-scoped asyncpg fixtures.
|
||||
|
||||
Run with:
|
||||
RUN_RLS_TESTS=1 DB_APP_ROLE_PASSWORD=app_secret_change_me \
|
||||
pytest tests/test_l1_rls.py -v --override-ini="addopts="
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from urllib.parse import unquote, urlsplit
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.rls
|
||||
|
||||
_DATABASE_TEST_URL = os.getenv(
|
||||
"DATABASE_TEST_URL",
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/resolutionflow_test",
|
||||
)
|
||||
_DATABASE_TEST_URL_SYNC = _DATABASE_TEST_URL.replace(
|
||||
"postgresql+asyncpg://",
|
||||
"postgresql://",
|
||||
1,
|
||||
)
|
||||
_TEST_DB_PARTS = urlsplit(_DATABASE_TEST_URL_SYNC)
|
||||
|
||||
_DB_HOST = os.getenv(
|
||||
"TEST_DB_HOST", _TEST_DB_PARTS.hostname or "localhost"
|
||||
)
|
||||
_DB_PORT = int(os.getenv(
|
||||
"TEST_DB_PORT", str(_TEST_DB_PARTS.port or 5432)
|
||||
))
|
||||
_DB_NAME = os.getenv(
|
||||
"TEST_DB_NAME",
|
||||
unquote(_TEST_DB_PARTS.path.lstrip("/") or "resolutionflow_test"),
|
||||
)
|
||||
_ADMIN_USER = os.getenv(
|
||||
"TEST_DB_ADMIN_USER",
|
||||
unquote(_TEST_DB_PARTS.username or "postgres"),
|
||||
)
|
||||
_ADMIN_PASSWORD = os.getenv(
|
||||
"TEST_DB_ADMIN_PASSWORD",
|
||||
unquote(_TEST_DB_PARTS.password or "postgres"),
|
||||
)
|
||||
_APP_PASSWORD = os.getenv("DB_APP_ROLE_PASSWORD", "app_secret_change_me")
|
||||
|
||||
ACCOUNT_A_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
|
||||
|
||||
def _admin_dsn() -> dict:
|
||||
return dict(
|
||||
host=_DB_HOST, port=_DB_PORT, dbname=_DB_NAME,
|
||||
user=_ADMIN_USER, password=_ADMIN_PASSWORD,
|
||||
)
|
||||
|
||||
|
||||
def _app_dsn() -> dict:
|
||||
return dict(
|
||||
host=_DB_HOST, port=_DB_PORT, dbname=_DB_NAME,
|
||||
user="resolutionflow_app", password=_APP_PASSWORD,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def _ensure_rls_schema():
|
||||
"""Re-apply Alembic migrations so that RLS policies are present.
|
||||
|
||||
The standard test_db fixture uses Base.metadata.create_all which skips
|
||||
RLS setup. Running 'alembic upgrade head' against the test DB ensures
|
||||
the FORCE ROW LEVEL SECURITY + tenant_isolation policies created in the
|
||||
L1 migrations (T5/T6) are active.
|
||||
|
||||
We drop and recreate the public schema first so that any tables left behind
|
||||
by a prior create_all-based test_db run don't conflict with alembic's
|
||||
migration tracking (alembic would see existing tables without alembic_version
|
||||
and fail with DuplicateTable errors).
|
||||
"""
|
||||
# Drop and recreate the schema to ensure a clean slate for alembic.
|
||||
with psycopg2.connect(**_admin_dsn()) as conn:
|
||||
conn.autocommit = True
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("DROP SCHEMA public CASCADE")
|
||||
cur.execute("CREATE SCHEMA public")
|
||||
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
env = os.environ.copy()
|
||||
env["DATABASE_URL"] = _DATABASE_TEST_URL
|
||||
env["DATABASE_URL_SYNC"] = _DATABASE_TEST_URL_SYNC
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "alembic", "upgrade", "head"],
|
||||
cwd=backend_dir,
|
||||
env=env,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed fixture (module-scoped, synchronous psycopg2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def l1_rls_seed(_ensure_rls_schema):
|
||||
"""Insert two accounts, two users, one internal_ticket and one
|
||||
l1_walk_session per account using a superuser (BYPASSRLS) connection.
|
||||
|
||||
Returns a dict with the seeded IDs so tests can reference them.
|
||||
Cleans up on module teardown.
|
||||
"""
|
||||
conn = psycopg2.connect(**_admin_dsn())
|
||||
conn.autocommit = True
|
||||
cur = conn.cursor()
|
||||
|
||||
# Accounts (idempotent — shared with test_rls_isolation.py)
|
||||
cur.execute(
|
||||
"INSERT INTO accounts (id, name, display_code, created_at, updated_at)"
|
||||
" VALUES (%s, %s, %s, NOW(), NOW()),"
|
||||
" (%s, %s, %s, NOW(), NOW())"
|
||||
" ON CONFLICT (id) DO NOTHING",
|
||||
(
|
||||
ACCOUNT_A_ID, "L1 RLS Tenant A", "RLSA0001",
|
||||
ACCOUNT_B_ID, "L1 RLS Tenant B", "RLSB0001",
|
||||
),
|
||||
)
|
||||
|
||||
user_a_tmp = str(uuid.uuid4())
|
||||
user_b_tmp = str(uuid.uuid4())
|
||||
cur.execute(
|
||||
"INSERT INTO users"
|
||||
" (id, email, password_hash, name, role,"
|
||||
" is_super_admin, is_team_admin, is_service_account, must_change_password,"
|
||||
" is_active, account_id, account_role, timezone, created_at)"
|
||||
" VALUES"
|
||||
" (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()),"
|
||||
" (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())"
|
||||
" ON CONFLICT (email) DO NOTHING",
|
||||
(
|
||||
user_a_tmp, "l1-rls-a@example.com", "placeholder",
|
||||
"L1 RLS User A", "engineer",
|
||||
False, False, False, False,
|
||||
True, ACCOUNT_A_ID, "engineer", "UTC",
|
||||
user_b_tmp, "l1-rls-b@example.com", "placeholder",
|
||||
"L1 RLS User B", "engineer",
|
||||
False, False, False, False,
|
||||
True, ACCOUNT_B_ID, "engineer", "UTC",
|
||||
),
|
||||
)
|
||||
|
||||
cur.execute(
|
||||
"SELECT id FROM users WHERE email = 'l1-rls-a@example.com'"
|
||||
)
|
||||
user_a_id = str(cur.fetchone()[0])
|
||||
cur.execute(
|
||||
"SELECT id FROM users WHERE email = 'l1-rls-b@example.com'"
|
||||
)
|
||||
user_b_id = str(cur.fetchone()[0])
|
||||
|
||||
ticket_a_id = str(uuid.uuid4())
|
||||
ticket_b_id = str(uuid.uuid4())
|
||||
walk_a_id = str(uuid.uuid4())
|
||||
walk_b_id = str(uuid.uuid4())
|
||||
|
||||
cur.execute(
|
||||
"INSERT INTO internal_tickets"
|
||||
" (id, account_id, created_by_user_id, problem_statement,"
|
||||
" status, created_at, updated_at)"
|
||||
" VALUES"
|
||||
" (%s, %s, %s, %s, %s, NOW(), NOW()),"
|
||||
" (%s, %s, %s, %s, %s, NOW(), NOW())",
|
||||
(
|
||||
ticket_a_id, ACCOUNT_A_ID, user_a_id,
|
||||
"L1 RLS test ticket A", "open",
|
||||
ticket_b_id, ACCOUNT_B_ID, user_b_id,
|
||||
"L1 RLS test ticket B", "open",
|
||||
),
|
||||
)
|
||||
|
||||
cur.execute(
|
||||
"INSERT INTO l1_walk_sessions"
|
||||
" (id, account_id, created_by_user_id, ticket_id, ticket_kind,"
|
||||
" session_kind, status, started_at, last_step_at)"
|
||||
" VALUES"
|
||||
" (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW()),"
|
||||
" (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())",
|
||||
(
|
||||
walk_a_id, ACCOUNT_A_ID, user_a_id,
|
||||
"INT-A", "internal", "adhoc", "active",
|
||||
walk_b_id, ACCOUNT_B_ID, user_b_id,
|
||||
"INT-B", "internal", "adhoc", "active",
|
||||
),
|
||||
)
|
||||
|
||||
seed = {
|
||||
"ticket_a": ticket_a_id,
|
||||
"ticket_b": ticket_b_id,
|
||||
"walk_a": walk_a_id,
|
||||
"walk_b": walk_b_id,
|
||||
"user_a": user_a_id,
|
||||
"user_b": user_b_id,
|
||||
}
|
||||
|
||||
yield seed
|
||||
|
||||
# Cleanup in reverse FK order.
|
||||
# Delete all child rows for both test accounts before removing users —
|
||||
# other test modules (test_rls_isolation.py) may have seeded rows for
|
||||
# these same accounts, so we clean by account_id rather than by row ID.
|
||||
cur.execute(
|
||||
"DELETE FROM l1_walk_sessions WHERE account_id IN (%s, %s)",
|
||||
(ACCOUNT_A_ID, ACCOUNT_B_ID),
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM internal_tickets WHERE account_id IN (%s, %s)",
|
||||
(ACCOUNT_A_ID, ACCOUNT_B_ID),
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM users WHERE email IN (%s, %s)",
|
||||
("l1-rls-a@example.com", "l1-rls-b@example.com"),
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM accounts WHERE id IN (%s, %s)"
|
||||
" AND display_code IN ('RLSA0001', 'RLSB0001')",
|
||||
(ACCOUNT_A_ID, ACCOUNT_B_ID),
|
||||
)
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-test helper: open an app-role connection with a given tenant context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _app_conn(account_id: str | None = None) -> psycopg2.extensions.connection:
|
||||
"""Open a psycopg2 connection as resolutionflow_app.
|
||||
|
||||
If account_id is given, SET LOCAL app.current_account_id so RLS applies
|
||||
to the given tenant. Callers must begin a transaction first.
|
||||
"""
|
||||
conn = psycopg2.connect(**_app_dsn())
|
||||
conn.autocommit = False
|
||||
cur = conn.cursor()
|
||||
if account_id:
|
||||
cur.execute(
|
||||
"SELECT set_config('app.current_account_id', %s, false)",
|
||||
(account_id,),
|
||||
)
|
||||
cur.close()
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# internal_tickets — read isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_l1_user_cannot_read_other_accounts_internal_tickets(l1_rls_seed):
|
||||
"""RLS USING: Account A context must not see Account B's tickets."""
|
||||
conn = _app_conn(ACCOUNT_A_ID)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM internal_tickets WHERE id = %s",
|
||||
(l1_rls_seed["ticket_b"],),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
assert len(rows) == 0, (
|
||||
"Account A must not read Account B's internal_tickets"
|
||||
)
|
||||
|
||||
|
||||
def test_internal_tickets_account_a_can_see_own_rows(l1_rls_seed):
|
||||
"""Positive check: Account A can read its own internal_tickets."""
|
||||
conn = _app_conn(ACCOUNT_A_ID)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM internal_tickets WHERE id = %s",
|
||||
(l1_rls_seed["ticket_a"],),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
assert len(rows) == 1, (
|
||||
"Account A must be able to read its own internal_tickets"
|
||||
)
|
||||
|
||||
|
||||
def test_internal_tickets_no_context_sees_nothing(l1_rls_seed):
|
||||
"""Fail-closed: no tenant context → zero internal_tickets rows visible."""
|
||||
conn = _app_conn() # no account_id
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM internal_tickets WHERE id IN (%s, %s)",
|
||||
(l1_rls_seed["ticket_a"], l1_rls_seed["ticket_b"]),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
assert len(rows) == 0, (
|
||||
"No-context connection must not see any internal_tickets"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# l1_walk_sessions — read isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_l1_user_cannot_read_other_accounts_walk_sessions(l1_rls_seed):
|
||||
"""RLS USING: Account A context must not see Account B's walk sessions."""
|
||||
conn = _app_conn(ACCOUNT_A_ID)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM l1_walk_sessions WHERE id = %s",
|
||||
(l1_rls_seed["walk_b"],),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
assert len(rows) == 0, (
|
||||
"Account A must not read Account B's l1_walk_sessions"
|
||||
)
|
||||
|
||||
|
||||
def test_l1_walk_sessions_account_a_can_see_own_rows(l1_rls_seed):
|
||||
"""Positive check: Account A can read its own l1_walk_sessions."""
|
||||
conn = _app_conn(ACCOUNT_A_ID)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM l1_walk_sessions WHERE id = %s",
|
||||
(l1_rls_seed["walk_a"],),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
assert len(rows) == 1, (
|
||||
"Account A must be able to read its own l1_walk_sessions"
|
||||
)
|
||||
|
||||
|
||||
def test_l1_walk_sessions_no_context_sees_nothing(l1_rls_seed):
|
||||
"""Fail-closed: no tenant context → zero l1_walk_sessions rows visible."""
|
||||
conn = _app_conn() # no account_id
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"SELECT id FROM l1_walk_sessions WHERE id IN (%s, %s)",
|
||||
(l1_rls_seed["walk_a"], l1_rls_seed["walk_b"]),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
assert len(rows) == 0, (
|
||||
"No-context connection must not see any l1_walk_sessions"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# internal_tickets — WITH CHECK (cross-tenant INSERT rejection)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_with_check_blocks_cross_tenant_insert_internal_tickets(l1_rls_seed):
|
||||
"""RLS WITH CHECK: INSERT with account_id = A under context B is rejected.
|
||||
|
||||
psycopg2 raises InsufficientPrivilege (pgcode '42501') when a row
|
||||
violates FORCE ROW LEVEL SECURITY WITH CHECK.
|
||||
"""
|
||||
new_id = str(uuid.uuid4())
|
||||
user_b_id = l1_rls_seed["user_b"]
|
||||
|
||||
conn = _app_conn(ACCOUNT_B_ID)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
with pytest.raises(psycopg2.errors.InsufficientPrivilege):
|
||||
cur.execute(
|
||||
"INSERT INTO internal_tickets"
|
||||
" (id, account_id, created_by_user_id, problem_statement,"
|
||||
" status, created_at, updated_at)"
|
||||
" VALUES (%s, %s, %s, %s, %s, NOW(), NOW())",
|
||||
(
|
||||
new_id, ACCOUNT_A_ID, user_b_id,
|
||||
"Cross-tenant injection attempt", "open",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# l1_walk_sessions — WITH CHECK (cross-tenant INSERT rejection)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_with_check_blocks_cross_tenant_insert_l1_walk_sessions(l1_rls_seed):
|
||||
"""RLS WITH CHECK: INSERT with account_id = A under context B is rejected."""
|
||||
new_id = str(uuid.uuid4())
|
||||
user_b_id = l1_rls_seed["user_b"]
|
||||
|
||||
conn = _app_conn(ACCOUNT_B_ID)
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
with pytest.raises(psycopg2.errors.InsufficientPrivilege):
|
||||
cur.execute(
|
||||
"INSERT INTO l1_walk_sessions"
|
||||
" (id, account_id, created_by_user_id, ticket_id,"
|
||||
" ticket_kind, session_kind, status, started_at, last_step_at)"
|
||||
" VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())",
|
||||
(
|
||||
new_id, ACCOUNT_A_ID, user_b_id,
|
||||
"INT-cross", "internal", "adhoc", "active",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
conn.rollback()
|
||||
conn.close()
|
||||
119
backend/tests/test_l1_session_cleanup.py
Normal file
119
backend/tests/test_l1_session_cleanup.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Tests for the l1_session_cleanup job."""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.l1_walk_session import L1WalkSession
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.services.l1_session_cleanup import flip_stale_sessions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _make_account(db: AsyncSession) -> Account:
|
||||
import secrets
|
||||
import string
|
||||
code = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(8))
|
||||
a = Account(id=uuid.uuid4(), name="Test", display_code=code)
|
||||
db.add(a)
|
||||
await db.flush()
|
||||
return a
|
||||
|
||||
|
||||
async def _make_user(db: AsyncSession, *, account_id: uuid.UUID) -> User:
|
||||
u = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"user-{uuid.uuid4()}@example.com",
|
||||
name="L1",
|
||||
account_id=account_id,
|
||||
account_role="l1_tech",
|
||||
role="engineer",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(u)
|
||||
await db.flush()
|
||||
return u
|
||||
|
||||
|
||||
async def _make_session(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
status: str = "active",
|
||||
last_step_at: datetime | None = None,
|
||||
) -> L1WalkSession:
|
||||
now = datetime.now(timezone.utc)
|
||||
session = L1WalkSession(
|
||||
id=uuid.uuid4(),
|
||||
account_id=account_id,
|
||||
created_by_user_id=user_id,
|
||||
ticket_id="t",
|
||||
ticket_kind="internal",
|
||||
session_kind="adhoc",
|
||||
status=status,
|
||||
started_at=now,
|
||||
last_step_at=last_step_at or now,
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flip_stale_sessions_only_affects_old_active_rows(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
user = await _make_user(test_db, account_id=account.id)
|
||||
|
||||
# 1. Stale active (>24h ago) — should flip
|
||||
stale = await _make_session(
|
||||
test_db, account_id=account.id, user_id=user.id,
|
||||
status="active",
|
||||
last_step_at=datetime.now(timezone.utc) - timedelta(hours=25),
|
||||
)
|
||||
# 2. Fresh active (1h ago) — should stay active
|
||||
fresh = await _make_session(
|
||||
test_db, account_id=account.id, user_id=user.id,
|
||||
status="active",
|
||||
last_step_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
# 3. Already-resolved (old) — should stay resolved, not flip
|
||||
already_resolved = await _make_session(
|
||||
test_db, account_id=account.id, user_id=user.id,
|
||||
status="resolved",
|
||||
last_step_at=datetime.now(timezone.utc) - timedelta(hours=48),
|
||||
)
|
||||
await test_db.commit()
|
||||
|
||||
count = await flip_stale_sessions(test_db)
|
||||
assert count == 1
|
||||
|
||||
await test_db.refresh(stale)
|
||||
await test_db.refresh(fresh)
|
||||
await test_db.refresh(already_resolved)
|
||||
assert stale.status == "abandoned"
|
||||
assert fresh.status == "active"
|
||||
assert already_resolved.status == "resolved"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flip_stale_sessions_returns_zero_when_none_stale(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
user = await _make_user(test_db, account_id=account.id)
|
||||
await _make_session(
|
||||
test_db, account_id=account.id, user_id=user.id,
|
||||
status="active",
|
||||
last_step_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
await test_db.commit()
|
||||
count = await flip_stale_sessions(test_db)
|
||||
assert count == 0
|
||||
917
backend/tests/test_l1_session_service.py
Normal file
917
backend/tests/test_l1_session_service.py
Normal file
@@ -0,0 +1,917 @@
|
||||
"""Tests for l1_session_service start_* functions (T12), record_step/update_notes (T13), resolve/escalate (T14)."""
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.models.user import User
|
||||
from app.models.tree import Tree
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.flow_proposal import FlowProposal
|
||||
from app.services.l1_session_service import (
|
||||
start_flow_session,
|
||||
start_proposal_session,
|
||||
start_adhoc_session,
|
||||
_resolve_acting_as,
|
||||
record_step,
|
||||
update_notes,
|
||||
resolve,
|
||||
escalate,
|
||||
escalate_without_walk,
|
||||
)
|
||||
from app.services import internal_ticket_service
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _make_account(db: AsyncSession) -> Account:
|
||||
s = str(uuid.uuid4())[:8]
|
||||
account = Account(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Test Account {s}",
|
||||
display_code=s[:8].upper(),
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
return account
|
||||
|
||||
|
||||
async def _make_user(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: uuid.UUID,
|
||||
account_role: str = "l1_tech",
|
||||
can_cover_l1: bool = False,
|
||||
) -> User:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"user-{uuid.uuid4()}@example.com",
|
||||
name="Test User",
|
||||
account_id=account_id,
|
||||
account_role=account_role,
|
||||
role="engineer",
|
||||
is_active=True,
|
||||
can_cover_l1=can_cover_l1,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def _make_tree(db: AsyncSession, *, account_id: uuid.UUID, author_id: uuid.UUID) -> Tree:
|
||||
tree = Tree(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Flow",
|
||||
account_id=account_id,
|
||||
author_id=author_id,
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={"nodes": [], "edges": []},
|
||||
visibility="team",
|
||||
status="published",
|
||||
)
|
||||
db.add(tree)
|
||||
await db.flush()
|
||||
return tree
|
||||
|
||||
|
||||
async def _make_ai_session(db: AsyncSession, *, user_id: uuid.UUID, account_id: uuid.UUID) -> AISession:
|
||||
ai_session = AISession(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
session_type="chat",
|
||||
intake_type="free_text",
|
||||
intake_content={"text": "test"},
|
||||
status="active",
|
||||
confidence_tier="discovery",
|
||||
conversation_messages=[],
|
||||
)
|
||||
db.add(ai_session)
|
||||
await db.flush()
|
||||
return ai_session
|
||||
|
||||
|
||||
async def _make_proposal(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account_id: uuid.UUID,
|
||||
source_session_id: uuid.UUID,
|
||||
) -> FlowProposal:
|
||||
proposal = FlowProposal(
|
||||
id=uuid.uuid4(),
|
||||
account_id=account_id,
|
||||
source_session_id=source_session_id,
|
||||
proposal_type="new_flow",
|
||||
title="Test Proposal",
|
||||
proposed_flow_data={"nodes": [], "edges": []},
|
||||
source="manual_draft",
|
||||
status="pending",
|
||||
)
|
||||
db.add(proposal)
|
||||
await db.flush()
|
||||
return proposal
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for _resolve_acting_as (no DB needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_acting_as_for_engineer_returns_coverage_tag():
|
||||
user = User(account_role="engineer")
|
||||
assert _resolve_acting_as(user) == "l1_coverage"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_acting_as_for_l1_tech_returns_none():
|
||||
user = User(account_role="l1_tech")
|
||||
assert _resolve_acting_as(user) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_acting_as_for_owner_returns_none():
|
||||
user = User(account_role="owner")
|
||||
assert _resolve_acting_as(user) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests (real DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_session_creates_active_flow_session(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id="ticket-abc",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
assert session.session_kind == "flow"
|
||||
assert session.flow_id == tree.id
|
||||
assert session.flow_proposal_id is None
|
||||
assert session.status == "active"
|
||||
assert session.walked_path == []
|
||||
assert session.walk_notes == []
|
||||
assert session.acting_as is None # l1_tech native
|
||||
assert session.ticket_id == "ticket-abc"
|
||||
assert session.ticket_kind == "internal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_proposal_session_creates_active_proposal_session(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ai_session = await _make_ai_session(test_db, user_id=l1.id, account_id=account.id)
|
||||
proposal = await _make_proposal(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
source_session_id=ai_session.id,
|
||||
)
|
||||
|
||||
session = await start_proposal_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_proposal_id=proposal.id,
|
||||
ticket_id="ticket-xyz",
|
||||
ticket_kind="psa",
|
||||
)
|
||||
assert session.session_kind == "proposal"
|
||||
assert session.flow_proposal_id == proposal.id
|
||||
assert session.flow_id is None
|
||||
assert session.status == "active"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_adhoc_session_no_flow_no_proposal(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id="ticket-adhoc",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
assert session.session_kind == "adhoc"
|
||||
assert session.flow_id is None
|
||||
assert session.flow_proposal_id is None
|
||||
assert session.walked_path == []
|
||||
assert session.walk_notes == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_with_coverage_gets_acting_as_tag(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
eng = await _make_user(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
can_cover_l1=True,
|
||||
)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=eng,
|
||||
ticket_id="t",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
assert session.acting_as == "l1_coverage"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T13: record_step and update_notes tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_step_appends_to_walked_path(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id="t1",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
updated = await record_step(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
node_id="n1",
|
||||
question="Is the device powered on?",
|
||||
answer="yes",
|
||||
)
|
||||
assert len(updated.walked_path) == 1
|
||||
assert updated.walked_path[0] == {
|
||||
"node_id": "n1",
|
||||
"question": "Is the device powered on?",
|
||||
"answer": "yes",
|
||||
"l1_note": None,
|
||||
}
|
||||
assert updated.current_node_id == "n1"
|
||||
assert updated.last_step_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_step_two_sequential_steps_accumulate(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id="t2",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
await record_step(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
node_id="n1",
|
||||
question="Step 1?",
|
||||
answer="yes",
|
||||
)
|
||||
updated = await record_step(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
node_id="n2",
|
||||
question="Step 2?",
|
||||
answer="no",
|
||||
)
|
||||
assert len(updated.walked_path) == 2
|
||||
assert updated.walked_path[0]["node_id"] == "n1"
|
||||
assert updated.walked_path[1]["node_id"] == "n2"
|
||||
assert updated.current_node_id == "n2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_step_blocks_adhoc(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id="t3",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
with pytest.raises(ValueError, match="adhoc"):
|
||||
await record_step(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
node_id="n1",
|
||||
question="Q?",
|
||||
answer="yes",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_step_blocks_inactive_session(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id="t4",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
# Manually mark resolved to simulate inactive state
|
||||
session.status = "resolved"
|
||||
await test_db.flush()
|
||||
with pytest.raises(ValueError, match="not active"):
|
||||
await record_step(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
node_id="n1",
|
||||
question="Q?",
|
||||
answer="yes",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_step_includes_note_when_provided(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id="t5",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
updated = await record_step(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
node_id="n1",
|
||||
question="Q?",
|
||||
answer="yes",
|
||||
note="Customer mentioned it started yesterday",
|
||||
)
|
||||
assert updated.walked_path[0]["l1_note"] == "Customer mentioned it started yesterday"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_notes_replaces_walk_notes(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id="t6",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
new_notes = [{"timestamp": "2026-05-28T10:00:00Z", "content": "Customer rebooted"}]
|
||||
updated = await update_notes(test_db, session_id=session.id, notes=new_notes)
|
||||
assert updated.walk_notes == new_notes
|
||||
assert updated.last_step_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_notes_raises_if_session_not_found(test_db: AsyncSession):
|
||||
missing_id = uuid.uuid4()
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await update_notes(test_db, session_id=missing_id, notes=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_notes_raises_if_session_not_active(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id="t7",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
session.status = "escalated"
|
||||
await test_db.flush()
|
||||
with pytest.raises(ValueError, match="not active"):
|
||||
await update_notes(test_db, session_id=session.id, notes=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_notes_size_cap(test_db: AsyncSession):
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id="t8",
|
||||
ticket_kind="internal",
|
||||
)
|
||||
# Create a notes payload larger than 256KB
|
||||
big_content = "x" * (256 * 1024 + 100)
|
||||
notes = [{"timestamp": "2026-05-28T10:00:00Z", "content": big_content}]
|
||||
with pytest.raises(ValueError, match="256KB"):
|
||||
await update_notes(test_db, session_id=session.id, notes=notes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T14: resolve, escalate, escalate_without_walk tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _make_internal_ticket(db: AsyncSession, *, account_id: uuid.UUID, user_id: uuid.UUID) -> object:
|
||||
"""Create an internal ticket and return it."""
|
||||
return await internal_ticket_service.create_ticket(
|
||||
db,
|
||||
account_id=account_id,
|
||||
created_by_user_id=user_id,
|
||||
problem_statement="Customer cannot log in",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_proposal_helpful_flips_validated_by_outcome(test_db: AsyncSession):
|
||||
"""resolve(helpful=True) on a proposal session sets validated_by_outcome=True."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
ai_session = await _make_ai_session(test_db, user_id=l1.id, account_id=account.id)
|
||||
proposal = await _make_proposal(test_db, account_id=account.id, source_session_id=ai_session.id)
|
||||
|
||||
session = await start_proposal_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_proposal_id=proposal.id,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
resolved = await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=True,
|
||||
resolution_notes="Issue fixed by proposal walk",
|
||||
)
|
||||
assert resolved.status == "resolved"
|
||||
assert resolved.helpful is True
|
||||
assert resolved.resolution_notes == "Issue fixed by proposal walk"
|
||||
assert resolved.resolved_at is not None
|
||||
|
||||
# Proposal should now be validated
|
||||
await test_db.refresh(proposal)
|
||||
assert proposal.validated_by_outcome is True
|
||||
|
||||
# Internal ticket should be closed
|
||||
await test_db.refresh(ticket)
|
||||
assert ticket.status == "resolved"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_proposal_not_helpful_leaves_validated_by_outcome_false(test_db: AsyncSession):
|
||||
"""resolve(helpful=False) on a proposal session does NOT flip validated_by_outcome."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
ai_session = await _make_ai_session(test_db, user_id=l1.id, account_id=account.id)
|
||||
proposal = await _make_proposal(test_db, account_id=account.id, source_session_id=ai_session.id)
|
||||
|
||||
session = await start_proposal_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_proposal_id=proposal.id,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
resolved = await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=False,
|
||||
resolution_notes="Proposal did not help",
|
||||
)
|
||||
assert resolved.helpful is False
|
||||
await test_db.refresh(proposal)
|
||||
assert proposal.validated_by_outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_flow_session_closes_ticket_no_proposal_update(test_db: AsyncSession):
|
||||
"""resolve on a flow session closes the ticket and does not touch proposals."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
resolved = await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=True,
|
||||
resolution_notes="Flow resolved the issue",
|
||||
)
|
||||
assert resolved.status == "resolved"
|
||||
assert resolved.session_kind == "flow"
|
||||
assert resolved.flow_proposal_id is None
|
||||
|
||||
await test_db.refresh(ticket)
|
||||
assert ticket.status == "resolved"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_adhoc_session_closes_ticket(test_db: AsyncSession):
|
||||
"""resolve on an adhoc session closes the ticket with no proposal interaction."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
resolved = await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=True,
|
||||
resolution_notes="Adhoc resolved",
|
||||
)
|
||||
assert resolved.status == "resolved"
|
||||
await test_db.refresh(ticket)
|
||||
assert ticket.status == "resolved"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_psa_session_no_ticket_update(test_db: AsyncSession):
|
||||
"""resolve on a PSA-backed session does not attempt to update an internal ticket."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id="psa-external-id-123",
|
||||
ticket_kind="psa",
|
||||
)
|
||||
resolved = await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=True,
|
||||
resolution_notes="PSA ticket resolved externally",
|
||||
)
|
||||
assert resolved.status == "resolved"
|
||||
assert resolved.ticket_kind == "psa"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_raises_on_missing_session(test_db: AsyncSession):
|
||||
"""resolve raises ValueError when session does not exist."""
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await resolve(
|
||||
test_db,
|
||||
session_id=uuid.uuid4(),
|
||||
helpful=True,
|
||||
resolution_notes="N/A",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_raises_on_inactive_session(test_db: AsyncSession):
|
||||
"""resolve raises ValueError when session is not active."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
session.status = "escalated"
|
||||
await test_db.flush()
|
||||
|
||||
with pytest.raises(ValueError, match="not active"):
|
||||
await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=False,
|
||||
resolution_notes="N/A",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_marks_session_and_ticket_as_escalated(test_db: AsyncSession):
|
||||
"""escalate sets session status=escalated and closes the internal ticket as escalated."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
tree = await _make_tree(test_db, account_id=account.id, author_id=l1.id)
|
||||
|
||||
session = await start_flow_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
flow_id=tree.id,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
escalated = await escalate(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
reason="Customer reported intermittent failure not covered by flow",
|
||||
reason_category="out_of_scope",
|
||||
)
|
||||
assert escalated.status == "escalated"
|
||||
assert escalated.escalation_reason == "Customer reported intermittent failure not covered by flow"
|
||||
assert escalated.escalation_reason_category == "out_of_scope"
|
||||
assert escalated.resolved_at is not None
|
||||
assert escalated.last_step_at is not None
|
||||
|
||||
await test_db.refresh(ticket)
|
||||
assert ticket.status == "escalated"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_raises_on_missing_session(test_db: AsyncSession):
|
||||
"""escalate raises ValueError when session does not exist."""
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await escalate(
|
||||
test_db,
|
||||
session_id=uuid.uuid4(),
|
||||
reason="Some reason",
|
||||
reason_category="unknown",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_raises_on_inactive_session(test_db: AsyncSession):
|
||||
"""escalate raises ValueError when session is already inactive."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
session.status = "resolved"
|
||||
await test_db.flush()
|
||||
|
||||
with pytest.raises(ValueError, match="not active"):
|
||||
await escalate(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
reason="Too late",
|
||||
reason_category="other",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_without_walk_creates_escalated_adhoc_session(test_db: AsyncSession):
|
||||
"""escalate_without_walk creates an immediately-escalated session with empty walked_path."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
|
||||
session = await escalate_without_walk(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
reason_category="no_kb_content",
|
||||
reason="No KB article matched this issue",
|
||||
)
|
||||
assert session.status == "escalated"
|
||||
assert session.session_kind == "adhoc"
|
||||
assert session.walked_path == []
|
||||
assert session.escalation_reason == "No KB article matched this issue"
|
||||
assert session.escalation_reason_category == "no_kb_content"
|
||||
assert session.resolved_at is not None
|
||||
assert session.last_step_at is not None
|
||||
assert session.account_id == account.id
|
||||
assert session.created_by_user_id == l1.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_without_walk_escalates_internal_ticket(test_db: AsyncSession):
|
||||
"""escalate_without_walk marks the internal ticket as escalated."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
|
||||
await escalate_without_walk(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
reason_category="no_kb_content",
|
||||
)
|
||||
await test_db.refresh(ticket)
|
||||
assert ticket.status == "escalated"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_without_walk_psa_does_not_touch_internal_ticket(test_db: AsyncSession):
|
||||
"""escalate_without_walk with ticket_kind='psa' does not update internal tickets."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
|
||||
session = await escalate_without_walk(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id="psa-ticket-999",
|
||||
ticket_kind="psa",
|
||||
reason_category="no_kb_content",
|
||||
)
|
||||
assert session.status == "escalated"
|
||||
assert session.ticket_kind == "psa"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_without_walk_reason_is_optional(test_db: AsyncSession):
|
||||
"""escalate_without_walk works without a reason string."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(test_db, account_id=account.id, user_id=l1.id)
|
||||
|
||||
session = await escalate_without_walk(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
reason_category="no_kb_content",
|
||||
)
|
||||
assert session.escalation_reason is None
|
||||
assert session.escalation_reason_category == "no_kb_content"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T14 audit log tests (spec §5.6.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_writes_audit_log_with_acting_as(test_db: AsyncSession):
|
||||
"""resolve() writes an audit_logs row with acting_as='l1_coverage' for engineers."""
|
||||
account = await _make_account(test_db)
|
||||
eng = await _make_user(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
can_cover_l1=True,
|
||||
)
|
||||
ticket = await _make_internal_ticket(
|
||||
test_db, account_id=account.id, user_id=eng.id
|
||||
)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=eng,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=True,
|
||||
resolution_notes="Coverage engineer resolved",
|
||||
)
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AuditLog).where(
|
||||
AuditLog.action == "l1.session.resolve",
|
||||
AuditLog.resource_id == session.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.acting_as == "l1_coverage"
|
||||
assert row.user_id == eng.id
|
||||
assert row.account_id == account.id
|
||||
assert row.details["helpful"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_writes_audit_log_native_l1_acting_as_null(
|
||||
test_db: AsyncSession,
|
||||
):
|
||||
"""resolve() writes an audit_logs row with acting_as=None for native l1_tech."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id, account_role="l1_tech")
|
||||
ticket = await _make_internal_ticket(
|
||||
test_db, account_id=account.id, user_id=l1.id
|
||||
)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
await resolve(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
helpful=False,
|
||||
resolution_notes="Native L1 resolved",
|
||||
)
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AuditLog).where(
|
||||
AuditLog.action == "l1.session.resolve",
|
||||
AuditLog.resource_id == session.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.acting_as is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_writes_audit_log(test_db: AsyncSession):
|
||||
"""escalate() writes an audit_logs row with action='l1.session.escalate'."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(
|
||||
test_db, account_id=account.id, user_id=l1.id
|
||||
)
|
||||
session = await start_adhoc_session(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
)
|
||||
await escalate(
|
||||
test_db,
|
||||
session_id=session.id,
|
||||
reason="Beyond scope",
|
||||
reason_category="out_of_scope",
|
||||
)
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AuditLog).where(
|
||||
AuditLog.action == "l1.session.escalate",
|
||||
AuditLog.resource_id == session.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.details["escalation_reason_category"] == "out_of_scope"
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_without_walk_writes_audit_log(test_db: AsyncSession):
|
||||
"""escalate_without_walk() writes an audit_logs row."""
|
||||
account = await _make_account(test_db)
|
||||
l1 = await _make_user(test_db, account_id=account.id)
|
||||
ticket = await _make_internal_ticket(
|
||||
test_db, account_id=account.id, user_id=l1.id
|
||||
)
|
||||
session = await escalate_without_walk(
|
||||
test_db,
|
||||
account_id=account.id,
|
||||
user=l1,
|
||||
ticket_id=str(ticket.id),
|
||||
ticket_kind="internal",
|
||||
reason_category="no_kb_content",
|
||||
)
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AuditLog).where(
|
||||
AuditLog.action == "l1.session.escalate_no_walk",
|
||||
AuditLog.resource_id == session.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
assert row.details["escalation_reason_category"] == "no_kb_content"
|
||||
@@ -23,6 +23,7 @@ from pathlib import Path
|
||||
from urllib.parse import unquote, urlsplit
|
||||
|
||||
import asyncpg
|
||||
import psycopg2
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
@@ -80,7 +81,22 @@ def _ensure_rls_schema():
|
||||
public schema using Base.metadata.create_all, which does not enable RLS
|
||||
or create DB roles. This fixture re-runs 'alembic upgrade head' so that
|
||||
the full migration-managed schema (including RLS policies) is in place.
|
||||
|
||||
We drop and recreate the public schema first so that any tables left behind
|
||||
by a prior create_all-based test_db run don't conflict with alembic's
|
||||
migration tracking.
|
||||
"""
|
||||
# Drop and recreate the schema to ensure a clean slate for alembic.
|
||||
admin_dsn = dict(
|
||||
host=_DB_HOST, port=_DB_PORT, dbname=_DB_NAME,
|
||||
user=_ADMIN_USER, password=_ADMIN_PASSWORD,
|
||||
)
|
||||
with psycopg2.connect(**admin_dsn) as conn:
|
||||
conn.autocommit = True
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("DROP SCHEMA public CASCADE")
|
||||
cur.execute("CREATE SCHEMA public")
|
||||
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
env = os.environ.copy()
|
||||
env["DATABASE_URL"] = _DATABASE_TEST_URL
|
||||
@@ -131,15 +147,18 @@ async def seed_rls_test_data(admin_conn):
|
||||
user_b_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO users (
|
||||
id, email, password_hash, name, role, is_active, account_id,
|
||||
account_role, created_at
|
||||
id, email, password_hash, name, role,
|
||||
is_super_admin, is_team_admin, is_service_account, must_change_password,
|
||||
is_active, account_id, account_role, timezone, created_at
|
||||
) VALUES
|
||||
('{user_a_id}', 'rls-user-a@example.com',
|
||||
'placeholder', 'RLS User A', 'engineer', TRUE,
|
||||
'{ACCOUNT_A_ID}', 'engineer', NOW()),
|
||||
'placeholder', 'RLS User A', 'engineer',
|
||||
FALSE, FALSE, FALSE, FALSE,
|
||||
TRUE, '{ACCOUNT_A_ID}', 'engineer', 'UTC', NOW()),
|
||||
('{user_b_id}', 'rls-user-b@example.com',
|
||||
'placeholder', 'RLS User B', 'engineer', TRUE,
|
||||
'{ACCOUNT_B_ID}', 'engineer', NOW())
|
||||
'placeholder', 'RLS User B', 'engineer',
|
||||
FALSE, FALSE, FALSE, FALSE,
|
||||
TRUE, '{ACCOUNT_B_ID}', 'engineer', 'UTC', NOW())
|
||||
ON CONFLICT (email) DO NOTHING
|
||||
""")
|
||||
|
||||
|
||||
195
backend/tests/test_seat_enforcement.py
Normal file
195
backend/tests/test_seat_enforcement.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Integration tests for the seat_enforcement service.
|
||||
|
||||
Uses the test_db fixture (real async DB, fresh schema per test) to exercise
|
||||
the SQL counting logic in check_seat_available / get_seat_usage.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.services.seat_enforcement import check_seat_available, get_seat_usage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test-local DB helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _make_account(db: AsyncSession, *, suffix: str | None = None) -> Account:
|
||||
"""Create and flush a minimal Account row."""
|
||||
s = suffix or str(uuid.uuid4())[:8]
|
||||
account = Account(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Test Account {s}",
|
||||
display_code=s[:8],
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
return account
|
||||
|
||||
|
||||
async def _make_subscription(
|
||||
db: AsyncSession,
|
||||
account: Account,
|
||||
*,
|
||||
seat_limit: int | None = None,
|
||||
l1_seat_limit: int | None = None,
|
||||
) -> Subscription:
|
||||
"""Create and flush a Subscription for the given account."""
|
||||
sub = Subscription(
|
||||
account_id=account.id,
|
||||
plan="pro",
|
||||
status="active",
|
||||
seat_limit=seat_limit,
|
||||
l1_seat_limit=l1_seat_limit,
|
||||
)
|
||||
db.add(sub)
|
||||
await db.flush()
|
||||
return sub
|
||||
|
||||
|
||||
async def _make_user(
|
||||
db: AsyncSession,
|
||||
account: Account,
|
||||
*,
|
||||
account_role: str = "engineer",
|
||||
is_active: bool = True,
|
||||
suffix: str | None = None,
|
||||
) -> User:
|
||||
"""Create and flush a User row in the given account."""
|
||||
s = suffix or str(uuid.uuid4())[:8]
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"user-{s}@example.com",
|
||||
name=f"User {s}",
|
||||
account_id=account.id,
|
||||
account_role=account_role,
|
||||
role="engineer",
|
||||
is_active=is_active,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_seat_available_when_under_limit(test_db: AsyncSession):
|
||||
"""check_seat_available returns available=True when current < seat_limit."""
|
||||
account = await _make_account(test_db)
|
||||
sub = await _make_subscription(test_db, account, seat_limit=5)
|
||||
|
||||
for _ in range(3):
|
||||
await _make_user(test_db, account, account_role="engineer")
|
||||
|
||||
result = await check_seat_available(account, sub, "engineer", test_db)
|
||||
|
||||
assert result.available is True
|
||||
assert result.current == 3
|
||||
assert result.limit == 5
|
||||
assert result.role == "engineer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_seat_unavailable_when_at_limit(test_db: AsyncSession):
|
||||
"""check_seat_available returns available=False when current == seat_limit."""
|
||||
account = await _make_account(test_db)
|
||||
sub = await _make_subscription(test_db, account, seat_limit=2)
|
||||
|
||||
for _ in range(2):
|
||||
await _make_user(test_db, account, account_role="engineer")
|
||||
|
||||
result = await check_seat_available(account, sub, "engineer", test_db)
|
||||
|
||||
assert result.available is False
|
||||
assert result.current == 2
|
||||
assert result.limit == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_l1_uses_separate_seat_limit(test_db: AsyncSession):
|
||||
"""Engineer limit hit does not affect l1_tech availability."""
|
||||
account = await _make_account(test_db)
|
||||
# seat_limit exhausted, l1_seat_limit still has room
|
||||
sub = await _make_subscription(test_db, account, seat_limit=2, l1_seat_limit=3)
|
||||
|
||||
# Fill engineer seats to the limit
|
||||
for _ in range(2):
|
||||
await _make_user(test_db, account, account_role="engineer")
|
||||
|
||||
# Add one L1 user (below limit)
|
||||
await _make_user(test_db, account, account_role="l1_tech")
|
||||
|
||||
eng_result = await check_seat_available(account, sub, "engineer", test_db)
|
||||
l1_result = await check_seat_available(account, sub, "l1_tech", test_db)
|
||||
|
||||
assert eng_result.available is False, "engineer seats should be full"
|
||||
assert eng_result.current == 2
|
||||
|
||||
assert l1_result.available is True, "l1_tech seats should still be available"
|
||||
assert l1_result.current == 1
|
||||
assert l1_result.limit == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlimited_seat_limit_is_always_available(test_db: AsyncSession):
|
||||
"""seat_limit=None means unlimited; available=True regardless of count."""
|
||||
account = await _make_account(test_db)
|
||||
sub = await _make_subscription(test_db, account, seat_limit=None)
|
||||
|
||||
# Add many engineer users
|
||||
for _ in range(10):
|
||||
await _make_user(test_db, account, account_role="engineer")
|
||||
|
||||
result = await check_seat_available(account, sub, "engineer", test_db)
|
||||
|
||||
assert result.available is True
|
||||
assert result.current == 10
|
||||
assert result.limit is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_seat_usage_returns_engineer_l1_tuple(test_db: AsyncSession):
|
||||
"""get_seat_usage returns a (engineer, l1_tech) tuple in the correct order."""
|
||||
account = await _make_account(test_db)
|
||||
sub = await _make_subscription(test_db, account, seat_limit=5, l1_seat_limit=3)
|
||||
|
||||
await _make_user(test_db, account, account_role="engineer")
|
||||
await _make_user(test_db, account, account_role="l1_tech")
|
||||
await _make_user(test_db, account, account_role="l1_tech")
|
||||
|
||||
eng, l1 = await get_seat_usage(account, sub, test_db)
|
||||
|
||||
assert eng.role == "engineer"
|
||||
assert eng.current == 1
|
||||
assert eng.limit == 5
|
||||
assert eng.available is True
|
||||
|
||||
assert l1.role == "l1_tech"
|
||||
assert l1.current == 2
|
||||
assert l1.limit == 3
|
||||
assert l1.available is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_users_not_counted(test_db: AsyncSession):
|
||||
"""Inactive (is_active=False) users are excluded from the seat count."""
|
||||
account = await _make_account(test_db)
|
||||
sub = await _make_subscription(test_db, account, seat_limit=3)
|
||||
|
||||
# 1 active, 2 inactive
|
||||
await _make_user(test_db, account, account_role="engineer", is_active=True)
|
||||
await _make_user(test_db, account, account_role="engineer", is_active=False)
|
||||
await _make_user(test_db, account, account_role="engineer", is_active=False)
|
||||
|
||||
result = await check_seat_available(account, sub, "engineer", test_db)
|
||||
|
||||
assert result.current == 1
|
||||
assert result.available is True
|
||||
Reference in New Issue
Block a user