diff --git a/backend/alembic/versions/002_add_invite_codes.py b/backend/alembic/versions/002_add_invite_codes.py new file mode 100644 index 00000000..630bd4b6 --- /dev/null +++ b/backend/alembic/versions/002_add_invite_codes.py @@ -0,0 +1,52 @@ +"""Add invite codes + +Revision ID: 002 +Revises: 7e00fa3c75c9 +Create Date: 2026-02-01 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '002' +down_revision: Union[str, None] = '7e00fa3c75c9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create invite_codes table + op.create_table( + 'invite_codes', + sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column('code', sa.String(16), nullable=False, unique=True, index=True), + sa.Column('created_by_id', postgresql.UUID(as_uuid=True), sa.ForeignKey('users.id'), nullable=False), + sa.Column('used_by_id', postgresql.UUID(as_uuid=True), sa.ForeignKey('users.id'), nullable=True), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('note', sa.String(255), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.Column('used_at', sa.DateTime(timezone=True), nullable=True), + ) + + # Add invite_code_id FK to users table + op.add_column('users', sa.Column('invite_code_id', postgresql.UUID(as_uuid=True), nullable=True)) + op.create_foreign_key( + 'fk_users_invite_code_id', + 'users', + 'invite_codes', + ['invite_code_id'], + ['id'] + ) + + +def downgrade() -> None: + # Remove FK and column from users + op.drop_constraint('fk_users_invite_code_id', 'users', type_='foreignkey') + op.drop_column('users', 'invite_code_id') + + # Drop invite_codes table + op.drop_table('invite_codes') diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index cd7ef183..97b79708 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -5,6 +5,7 @@ from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select +from app.core.config import settings from app.core.database import get_db from app.core.security import ( verify_password, @@ -14,6 +15,7 @@ from app.core.security import ( decode_token ) from app.models.user import User +from app.models.invite_code import InviteCode from app.schemas.user import UserCreate, UserResponse, UserLogin from app.schemas.token import Token from app.api.deps import get_current_user @@ -27,6 +29,39 @@ async def register( db: Annotated[AsyncSession, Depends(get_db)] ): """Register a new user.""" + # Validate invite code if required + invite_code_record = None + if settings.REQUIRE_INVITE_CODE: + if not user_data.invite_code: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invite code is required" + ) + + # Look up invite code (case-insensitive) + result = await db.execute( + select(InviteCode).where(InviteCode.code == user_data.invite_code.upper()) + ) + invite_code_record = result.scalar_one_or_none() + + if not invite_code_record: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid invite code" + ) + + if invite_code_record.is_used: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invite code has already been used" + ) + + if invite_code_record.is_expired: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invite code has expired" + ) + # Check if email already exists result = await db.execute(select(User).where(User.email == user_data.email)) existing_user = result.scalar_one_or_none() @@ -41,9 +76,16 @@ async def register( email=user_data.email, password_hash=get_password_hash(user_data.password), name=user_data.name, - role=user_data.role # Use role from request (defaults to "engineer") + role=user_data.role, + invite_code_id=invite_code_record.id if invite_code_record else None ) db.add(new_user) + + # Mark invite code as used + if invite_code_record: + invite_code_record.used_by_id = new_user.id + invite_code_record.used_at = datetime.now(timezone.utc) + await db.commit() await db.refresh(new_user) diff --git a/backend/app/api/endpoints/invite.py b/backend/app/api/endpoints/invite.py new file mode 100644 index 00000000..9784b0df --- /dev/null +++ b/backend/app/api/endpoints/invite.py @@ -0,0 +1,96 @@ +from datetime import datetime, timezone +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.core.database import get_db +from app.models.user import User +from app.models.invite_code import InviteCode +from app.schemas.invite_code import InviteCodeCreate, InviteCodeResponse, InviteCodeValidation +from app.api.deps import require_admin + +router = APIRouter(prefix="/invites", tags=["invites"]) + + +@router.post("", response_model=InviteCodeResponse, status_code=status.HTTP_201_CREATED) +async def create_invite_code( + invite_data: InviteCodeCreate, + current_user: Annotated[User, Depends(require_admin)], + db: Annotated[AsyncSession, Depends(get_db)] +): + """Create a new invite code. Admin only.""" + invite_code = InviteCode( + created_by_id=current_user.id, + expires_at=invite_data.expires_at, + note=invite_data.note + ) + db.add(invite_code) + await db.commit() + await db.refresh(invite_code) + + return invite_code + + +@router.get("", response_model=list[InviteCodeResponse]) +async def list_invite_codes( + current_user: Annotated[User, Depends(require_admin)], + db: Annotated[AsyncSession, Depends(get_db)] +): + """List all invite codes. Admin only.""" + result = await db.execute( + select(InviteCode).order_by(InviteCode.created_at.desc()) + ) + invite_codes = result.scalars().all() + return invite_codes + + +@router.delete("/{code}", status_code=status.HTTP_204_NO_CONTENT) +async def revoke_invite_code( + code: str, + current_user: Annotated[User, Depends(require_admin)], + db: Annotated[AsyncSession, Depends(get_db)] +): + """Revoke (delete) an invite code. Admin only.""" + result = await db.execute( + select(InviteCode).where(InviteCode.code == code) + ) + invite_code = result.scalar_one_or_none() + + if not invite_code: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invite code not found" + ) + + if invite_code.is_used: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot revoke a used invite code" + ) + + await db.delete(invite_code) + await db.commit() + + +@router.get("/validate/{code}", response_model=InviteCodeValidation) +async def validate_invite_code( + code: str, + db: Annotated[AsyncSession, Depends(get_db)] +): + """Check if an invite code is valid. Public endpoint for UX.""" + result = await db.execute( + select(InviteCode).where(InviteCode.code == code.upper()) + ) + invite_code = result.scalar_one_or_none() + + if not invite_code: + return InviteCodeValidation(valid=False, message="Invalid invite code") + + if invite_code.is_used: + return InviteCodeValidation(valid=False, message="Invite code has already been used") + + if invite_code.is_expired: + return InviteCodeValidation(valid=False, message="Invite code has expired") + + return InviteCodeValidation(valid=True, message="Invite code is valid") diff --git a/backend/app/api/router.py b/backend/app/api/router.py index e1c5936c..7779bbbe 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -1,8 +1,9 @@ from fastapi import APIRouter -from app.api.endpoints import auth, trees, sessions +from app.api.endpoints import auth, trees, sessions, invite api_router = APIRouter() api_router.include_router(auth.router) api_router.include_router(trees.router) api_router.include_router(sessions.router) +api_router.include_router(invite.router) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index cd858f87..7be1b8a3 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -38,6 +38,9 @@ class Settings(BaseSettings): # Security BCRYPT_ROUNDS: int = 12 + # Registration + REQUIRE_INVITE_CODE: bool = True # Set to False to allow open registration + # CORS - set FRONTEND_URL in production (e.g., https://patherly.up.railway.app) CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173", "http://localhost:5174"] FRONTEND_URL: Optional[str] = None diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 9c752f54..2b91354c 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -3,5 +3,6 @@ from .team import Team from .tree import Tree from .session import Session from .attachment import Attachment +from .invite_code import InviteCode -__all__ = ["User", "Team", "Tree", "Session", "Attachment"] +__all__ = ["User", "Team", "Tree", "Session", "Attachment", "InviteCode"] diff --git a/backend/app/models/invite_code.py b/backend/app/models/invite_code.py new file mode 100644 index 00000000..4f2c6615 --- /dev/null +++ b/backend/app/models/invite_code.py @@ -0,0 +1,86 @@ +import uuid +import secrets +import string +from datetime import datetime, timezone +from typing import Optional +from sqlalchemy import String, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + + +def generate_invite_code() -> str: + """Generate an 8-character alphanumeric invite code.""" + alphabet = string.ascii_uppercase + string.digits + # Remove confusing characters: 0, O, I, 1 + alphabet = alphabet.replace("0", "").replace("O", "").replace("I", "").replace("1", "") + return "".join(secrets.choice(alphabet) for _ in range(8)) + + +class InviteCode(Base): + __tablename__ = "invite_codes" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4 + ) + code: Mapped[str] = mapped_column( + String(16), + unique=True, + nullable=False, + index=True, + default=generate_invite_code + ) + created_by_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id"), + nullable=False + ) + used_by_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id"), + nullable=True + ) + expires_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True + ) + note: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc) + ) + used_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True + ) + + # Relationships + created_by: Mapped["User"] = relationship( + "User", + foreign_keys=[created_by_id], + backref="created_invite_codes" + ) + used_by: Mapped[Optional["User"]] = relationship( + "User", + foreign_keys=[used_by_id], + backref="used_invite_code" + ) + + @property + def is_used(self) -> bool: + """Check if the invite code has been used.""" + return self.used_by_id is not None + + @property + def is_expired(self) -> bool: + """Check if the invite code has expired.""" + if self.expires_at is None: + return False + return datetime.now(timezone.utc) > self.expires_at + + @property + def is_valid(self) -> bool: + """Check if the invite code is valid (not used and not expired).""" + return not self.is_used and not self.is_expired diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 04e7844a..37b411fa 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -24,6 +24,11 @@ class User(Base): ForeignKey("teams.id"), nullable=True ) + invite_code_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("invite_codes.id"), + nullable=True + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) diff --git a/backend/app/schemas/invite_code.py b/backend/app/schemas/invite_code.py new file mode 100644 index 00000000..1e812c82 --- /dev/null +++ b/backend/app/schemas/invite_code.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, Field + + +class InviteCodeCreate(BaseModel): + """Schema for creating a new invite code.""" + expires_at: Optional[datetime] = Field(None, description="Optional expiration time") + note: Optional[str] = Field(None, max_length=255, description="Note about who this code is for") + + +class InviteCodeResponse(BaseModel): + """Schema for invite code response.""" + id: UUID + code: str + created_by_id: UUID + used_by_id: Optional[UUID] = None + expires_at: Optional[datetime] = None + note: Optional[str] = None + created_at: datetime + used_at: Optional[datetime] = None + is_used: bool + is_expired: bool + is_valid: bool + + class Config: + from_attributes = True + + +class InviteCodeValidation(BaseModel): + """Schema for invite code validation response.""" + valid: bool + message: str diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 2383afd0..3e10f6c0 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -12,6 +12,7 @@ class UserBase(BaseModel): class UserCreate(UserBase): password: str = Field(..., min_length=10, description="Password must be at least 10 characters") role: str = Field(default="engineer", description="User role: admin, engineer, or viewer") + invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)") class UserUpdate(BaseModel): diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 09bfba47..c58253bf 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -2,3 +2,4 @@ export { default as apiClient } from './client' export { default as authApi } from './auth' export { default as treesApi } from './trees' export { default as sessionsApi } from './sessions' +export { default as inviteApi } from './invite' diff --git a/frontend/src/api/invite.ts b/frontend/src/api/invite.ts new file mode 100644 index 00000000..f548321e --- /dev/null +++ b/frontend/src/api/invite.ts @@ -0,0 +1,11 @@ +import apiClient from './client' +import type { InviteCodeValidation } from '@/types' + +export const inviteApi = { + async validateCode(code: string): Promise { + const response = await apiClient.get(`/invites/validate/${code}`) + return response.data + }, +} + +export default inviteApi diff --git a/frontend/src/pages/RegisterPage.tsx b/frontend/src/pages/RegisterPage.tsx index d9c3adaf..03d421e3 100644 --- a/frontend/src/pages/RegisterPage.tsx +++ b/frontend/src/pages/RegisterPage.tsx @@ -1,23 +1,55 @@ import { useState } from 'react' import { Link, useNavigate } from 'react-router-dom' import { useAuthStore } from '@/store/authStore' +import { inviteApi } from '@/api' import { cn } from '@/lib/utils' export function RegisterPage() { const navigate = useNavigate() const { register, isLoading, error, clearError } = useAuthStore() + const [inviteCode, setInviteCode] = useState('') + const [inviteCodeStatus, setInviteCodeStatus] = useState<'idle' | 'checking' | 'valid' | 'invalid'>('idle') + const [inviteCodeMessage, setInviteCodeMessage] = useState('') const [name, setName] = useState('') const [email, setEmail] = useState('') const [password, setPassword] = useState('') const [confirmPassword, setConfirmPassword] = useState('') const [localError, setLocalError] = useState('') + const validateInviteCode = async (code: string) => { + if (!code.trim()) { + setInviteCodeStatus('idle') + setInviteCodeMessage('') + return + } + + setInviteCodeStatus('checking') + try { + const result = await inviteApi.validateCode(code.trim()) + setInviteCodeStatus(result.valid ? 'valid' : 'invalid') + setInviteCodeMessage(result.message) + } catch { + setInviteCodeStatus('invalid') + setInviteCodeMessage('Failed to validate invite code') + } + } + const handleSubmit = async (e: React.FormEvent) => { e.preventDefault() setLocalError('') clearError() + if (!inviteCode.trim()) { + setLocalError('Invite code is required') + return + } + + if (inviteCodeStatus !== 'valid') { + setLocalError('Please enter a valid invite code') + return + } + if (!name || !email || !password) { setLocalError('Please fill in all fields') return @@ -34,7 +66,7 @@ export function RegisterPage() { } try { - await register({ email, password, name }) + await register({ email, password, name, invite_code: inviteCode.trim() }) navigate('/trees', { replace: true }) } catch { // Error is set in the store @@ -57,6 +89,43 @@ export function RegisterPage() { )} +
+ + { + setInviteCode(e.target.value.toUpperCase()) + setInviteCodeStatus('idle') + }} + onBlur={(e) => validateInviteCode(e.target.value)} + className={cn( + 'mt-1 block w-full rounded-md border bg-background px-3 py-2 font-mono tracking-wider', + 'text-foreground placeholder:text-muted-foreground', + 'focus:outline-none focus:ring-1', + inviteCodeStatus === 'valid' && 'border-green-500 focus:border-green-500 focus:ring-green-500', + inviteCodeStatus === 'invalid' && 'border-destructive focus:border-destructive focus:ring-destructive', + inviteCodeStatus === 'idle' && 'border-input focus:border-primary focus:ring-primary', + inviteCodeStatus === 'checking' && 'border-input focus:border-primary focus:ring-primary' + )} + placeholder="ABCD1234" + /> + {inviteCodeStatus === 'checking' && ( +

Validating...

+ )} + {inviteCodeStatus === 'valid' && ( +

{inviteCodeMessage}

+ )} + {inviteCodeStatus === 'invalid' && ( +

{inviteCodeMessage}

+ )} +
+