feat(auth): add oauth_identities table for Google/Microsoft sign-in

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 03:18:00 -04:00
parent ab0d40c1e2
commit 143c979975
4 changed files with 116 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
"""add oauth_identities
Revision ID: b1fad5ddf357
Revises: c0f3a4b7e91d
Create Date: 2026-05-06 07:17:11.374555
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = 'b1fad5ddf357'
down_revision: Union[str, None] = 'c0f3a4b7e91d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"oauth_identities",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("provider", sa.String(20), nullable=False),
sa.Column("provider_subject", sa.String(255), nullable=False),
sa.Column("provider_email_at_link", sa.String(255), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.UniqueConstraint("provider", "provider_subject", name="uq_oauth_identities_provider_subject"),
)
op.create_index("ix_oauth_identities_user_id", "oauth_identities", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_oauth_identities_user_id", table_name="oauth_identities")
op.drop_table("oauth_identities")

View File

@@ -62,6 +62,7 @@ from .session_fact import SessionFact
from .session_suggested_fix import SessionSuggestedFix
from .draft_template import DraftTemplate
from .account_settings import AccountSettings
from .oauth_identity import OAuthIdentity # noqa: F401
__all__ = [
"User",
@@ -138,4 +139,5 @@ __all__ = [
"SessionSuggestedFix",
"DraftTemplate",
"AccountSettings",
"OAuthIdentity",
]

View File

@@ -0,0 +1,36 @@
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, UniqueConstraint, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
class OAuthIdentity(Base):
__tablename__ = "oauth_identities"
__table_args__ = (
UniqueConstraint("provider", "provider_subject", name="uq_oauth_identities_provider_subject"),
Index("ix_oauth_identities_user_id", "user_id"),
)
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
provider: Mapped[str] = mapped_column(String(20), nullable=False)
provider_subject: Mapped[str] = mapped_column(String(255), nullable=False)
provider_email_at_link: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
user: Mapped["User"] = relationship("User", backref="oauth_identities")

View File

@@ -0,0 +1,39 @@
import uuid
import pytest
from sqlalchemy import select
from app.models.oauth_identity import OAuthIdentity
@pytest.mark.asyncio
async def test_oauth_identity_unique_provider_subject(test_db, test_user):
"""Two rows with same provider+subject should violate uniqueness."""
user_id = uuid.UUID(test_user["user_data"]["id"])
row1 = OAuthIdentity(
user_id=user_id,
provider="google",
provider_subject="abc-123",
provider_email_at_link="alex@acmemsp.com",
)
test_db.add(row1)
await test_db.commit()
row2 = OAuthIdentity(
user_id=user_id,
provider="google",
provider_subject="abc-123",
provider_email_at_link="alex@acmemsp.com",
)
test_db.add(row2)
with pytest.raises(Exception): # IntegrityError
await test_db.commit()
await test_db.rollback()
rows = (
await test_db.execute(
select(OAuthIdentity).where(OAuthIdentity.user_id == user_id)
)
).scalars().all()
assert len(rows) == 1