From 143c979975314c78fdf3f56ed7eec92e09ffcf3f Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 03:18:00 -0400 Subject: [PATCH] feat(auth): add oauth_identities table for Google/Microsoft sign-in Co-Authored-By: Claude Opus 4.7 --- .../b1fad5ddf357_add_oauth_identities.py | 39 +++++++++++++++++++ backend/app/models/__init__.py | 2 + backend/app/models/oauth_identity.py | 36 +++++++++++++++++ backend/tests/test_oauth_identity_model.py | 39 +++++++++++++++++++ 4 files changed, 116 insertions(+) create mode 100644 backend/alembic/versions/b1fad5ddf357_add_oauth_identities.py create mode 100644 backend/app/models/oauth_identity.py create mode 100644 backend/tests/test_oauth_identity_model.py diff --git a/backend/alembic/versions/b1fad5ddf357_add_oauth_identities.py b/backend/alembic/versions/b1fad5ddf357_add_oauth_identities.py new file mode 100644 index 00000000..da4242c7 --- /dev/null +++ b/backend/alembic/versions/b1fad5ddf357_add_oauth_identities.py @@ -0,0 +1,39 @@ +"""add oauth_identities + +Revision ID: b1fad5ddf357 +Revises: c0f3a4b7e91d +Create Date: 2026-05-06 07:17:11.374555 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = 'b1fad5ddf357' +down_revision: Union[str, None] = 'c0f3a4b7e91d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "oauth_identities", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + sa.Column("provider", sa.String(20), nullable=False), + sa.Column("provider_subject", sa.String(255), nullable=False), + sa.Column("provider_email_at_link", sa.String(255), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), + sa.UniqueConstraint("provider", "provider_subject", name="uq_oauth_identities_provider_subject"), + ) + op.create_index("ix_oauth_identities_user_id", "oauth_identities", ["user_id"]) + + +def downgrade() -> None: + op.drop_index("ix_oauth_identities_user_id", table_name="oauth_identities") + op.drop_table("oauth_identities") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 52d90e19..e130224b 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -62,6 +62,7 @@ from .session_fact import SessionFact from .session_suggested_fix import SessionSuggestedFix from .draft_template import DraftTemplate from .account_settings import AccountSettings +from .oauth_identity import OAuthIdentity # noqa: F401 __all__ = [ "User", @@ -138,4 +139,5 @@ __all__ = [ "SessionSuggestedFix", "DraftTemplate", "AccountSettings", + "OAuthIdentity", ] diff --git a/backend/app/models/oauth_identity.py b/backend/app/models/oauth_identity.py new file mode 100644 index 00000000..07c4dbf4 --- /dev/null +++ b/backend/app/models/oauth_identity.py @@ -0,0 +1,36 @@ +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey, UniqueConstraint, Index +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.user import User + + +class OAuthIdentity(Base): + __tablename__ = "oauth_identities" + __table_args__ = ( + UniqueConstraint("provider", "provider_subject", name="uq_oauth_identities_provider_subject"), + Index("ix_oauth_identities_user_id", "user_id"), + ) + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + provider: Mapped[str] = mapped_column(String(20), nullable=False) + provider_subject: Mapped[str] = mapped_column(String(255), nullable=False) + provider_email_at_link: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + + user: Mapped["User"] = relationship("User", backref="oauth_identities") diff --git a/backend/tests/test_oauth_identity_model.py b/backend/tests/test_oauth_identity_model.py new file mode 100644 index 00000000..cccaf7ec --- /dev/null +++ b/backend/tests/test_oauth_identity_model.py @@ -0,0 +1,39 @@ +import uuid + +import pytest +from sqlalchemy import select + +from app.models.oauth_identity import OAuthIdentity + + +@pytest.mark.asyncio +async def test_oauth_identity_unique_provider_subject(test_db, test_user): + """Two rows with same provider+subject should violate uniqueness.""" + user_id = uuid.UUID(test_user["user_data"]["id"]) + + row1 = OAuthIdentity( + user_id=user_id, + provider="google", + provider_subject="abc-123", + provider_email_at_link="alex@acmemsp.com", + ) + test_db.add(row1) + await test_db.commit() + + row2 = OAuthIdentity( + user_id=user_id, + provider="google", + provider_subject="abc-123", + provider_email_at_link="alex@acmemsp.com", + ) + test_db.add(row2) + with pytest.raises(Exception): # IntegrityError + await test_db.commit() + await test_db.rollback() + + rows = ( + await test_db.execute( + select(OAuthIdentity).where(OAuthIdentity.user_id == user_id) + ) + ).scalars().all() + assert len(rows) == 1