Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions app/api/v1/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from authlib.integrations.starlette_client import OAuth
from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from starlette import status
Expand All @@ -15,6 +15,8 @@
from app.domains import UserService
from app.domains.auth.auth_service import AuthService
from app.literals.auth import OAuthProvider
from app.literals.users import OnboardingStep
from app.models import User
from app.schemas import LoginRequest, Token, UserRegister
from app.schemas.auth import (
MessageResponse,
Expand Down Expand Up @@ -77,31 +79,25 @@ async def logout_all(auth_service: AuthService = Depends(get_auth_service), toke
# but maybe this is better approach, I dunno
# @rate_limit(max_requests=5, window_seconds=3600, per_user=False)
async def signup(
request: Request,
data: UserRegister,
user_service: UserService = Depends(get_user_service),
auth_service: AuthService = Depends(get_auth_service),
request: Request,
data: UserRegister,
user_service: UserService = Depends(get_user_service),
auth_service: AuthService = Depends(get_auth_service),
):
"""
Register a new user, send verification email and auto-login.
"""
client_ip = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")

user_service.register(
data=data,
ip_address=client_ip,
user_agent=user_agent
)
user_service.register(data=data, ip_address=client_ip, user_agent=user_agent)
await auth_service.send_verification_email(data.email)
login_req = LoginRequest(
username=data.email,
password=data.password
)
login_req = LoginRequest(username=data.email, password=data.password)
token = await auth_service.authenticate_user(login_req)

return token


@router.post("/verify/send", response_model=MessageResponse)
@rate_limit(
max_requests=3,
Expand Down Expand Up @@ -198,6 +194,22 @@ async def change_password(
####################################################


@router.patch("/onboarding/step")
async def update_onboarding_step(
step: OnboardingStep,
current_user: TokenData = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Update user's current onboarding step."""
user = db.get(User, current_user.id)
if not user:
raise HTTPException(status_code=404, detail="User not found")

user.onboarding_step = step.value
db.commit()
return {"onboarding_step": step}


@router.get("/{provider}", response_class=RedirectResponse, include_in_schema=True)
async def login_oauth(
provider: OAuthProvider,
Expand Down
5 changes: 2 additions & 3 deletions app/core/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import atexit
import os
import tempfile

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
Expand All @@ -9,8 +8,8 @@

db_fd, db_path = None, None
if settings.ENVIRONMENT == "dev":
db_fd, db_path = tempfile.mkstemp(suffix=".db")
print("Created database at {}".format(db_path))
db_path = "dev.db"
print(f"Using persistent database at {db_path}")
engine = create_engine(f"sqlite:///{db_path}", pool_pre_ping=True)
else:
engine = create_engine(
Expand Down
3 changes: 2 additions & 1 deletion app/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from pydantic import BaseModel

from app.literals.users import Role
from app.literals.users import OnboardingStep, Role


class TokenData(BaseModel):
id: uuid.UUID
username: str
email: str
role: Role = Role.BASIC
onboarding_step: str = OnboardingStep.NOT_STARTED
45 changes: 41 additions & 4 deletions app/domains/auth/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ async def oauth_callback(
oauth: OAuth,
) -> Token:
"""Handle OAuth callback and authenticate user."""
from app.core.middleware import get_client_ip
from app.domains.user.user_service import UserService

provider_str = provider.value
oauth_client = oauth.create_client(provider_str)
token = await oauth_client.authorize_access_token(request)
Expand All @@ -298,9 +301,33 @@ async def oauth_callback(
detail="Could not validate credentials",
)
email = user_info["email"]
first_name = user_info.get("given_name")
last_name = user_info.get("family_name")
avatar_url = user_info.get("picture")

elif provider == OAuthProvider.GITHUB:
email_resp = await oauth_client.get("user/emails", token=token)
email = next(e["email"] for e in email_resp.json() if e["primary"])
email_data = email_resp.json()
email = next((e["email"] for e in email_data if e.get("primary")), None)
if not email:
email = email_data[0]["email"] if email_data else None

if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No email found from GitHub provider",
)

user_resp = await oauth_client.get("user", token=token)
user_data = user_resp.json()
full_name = user_data.get("name") or user_data.get("login")
if " " in full_name:
first_name, last_name = full_name.split(" ", 1)
else:
first_name = full_name
last_name = ""
avatar_url = user_data.get("avatar_url")

else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -309,10 +336,20 @@ async def oauth_callback(

user = self.user_repository.get_by_email(email)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found. Please register first.",
client_ip = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
user_service = UserService(self.db)

user_service.register_oauth(
email=email,
first_name=first_name,
last_name=last_name,
avatar_url=avatar_url,
provider=provider_str,
ip_address=client_ip,
user_agent=user_agent,
)
user = self.user_repository.get_by_email(email)

if not user.is_active:
raise HTTPException(
Expand Down
114 changes: 99 additions & 15 deletions app/domains/user/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from typing import List, Optional

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import Session
from starlette import status

from app.core.security import hash_password, verify_password
from app.core.utils import extract_constraint_info
from app.domains.terms.terms_repository import TermsRepository
from app.domains.user.user_repository import UserRepository
from app.literals.users import Role
from app.models import ConnectionTableModel, TermsTableModel, User, UserTermsAcceptanceTableModel
from app.models import ConnectionTableModel, User, UserTermsAcceptanceTableModel
from app.schemas.user import (
UserCreate,
UserDetail,
Expand All @@ -32,6 +32,7 @@ class UserService:
def __init__(self, db: Session):
self.db = db
self.repository = UserRepository(db)
self.terms_repository = TermsRepository(db)

def create_user(self, user_in: UserCreate) -> UserRead:
"""
Expand Down Expand Up @@ -264,10 +265,102 @@ def unban_user(self, user_id: uuid.UUID) -> None:

self.db.commit()

def _generate_unique_username(self, base_username: str) -> str:
"""Generate a unique username based on the provided one."""
username = base_username
counter = 1
while self.repository.exists_by_username(username):
suffix = str(counter)
if len(base_username) + len(suffix) > 50:
username = base_username[: 50 - len(suffix)] + suffix
else:
username = f"{base_username}{suffix}"
counter += 1
return username

def register_oauth(
self,
email: str,
first_name: str | None,
last_name: str | None,
avatar_url: str | None,
provider: str,
ip_address: str,
user_agent: str,
) -> UserRead:
"""
Register a new user via OAuth.
"""
email = email.lower()

if self.repository.exists_by_email(email):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email already registered",
)

terms = self.terms_repository.get_latest()
if not terms:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No active terms found in the system.",
)

base_username = email.split("@")[0]
base_username = "".join(c for c in base_username if c.isalnum() or c in ".-_")
if not base_username:
base_username = f"user{random.randint(1000, 9999)}"

username = self._generate_unique_username(base_username)

new_referral_code = self._generate_referral_code()
random_password = "".join(random.choices(string.ascii_letters + string.digits + "!@#$%", k=32))
hashed_pw = hash_password(random_password)

try:
new_user = User(
username=username,
email=email,
password=hashed_pw,
first_name=first_name or "New",
last_name=last_name or "User",
avatar_url=avatar_url,
referral_code=new_referral_code,
created_ip=ip_address,
user_agent=user_agent,
role=Role.BASIC,
provider=provider,
is_active=True,
is_verified=True,
verified_at=datetime.datetime.now(datetime.UTC),
)
self.db.add(new_user)
self.db.flush()

acceptance = UserTermsAcceptanceTableModel(user_id=new_user.id, terms_id=terms.id)
self.db.add(acceptance)

connection = ConnectionTableModel(user_id=new_user.id, ip_address=ip_address)
self.db.add(connection)

self.db.commit()
self.db.refresh(new_user)

return UserRead.model_validate(new_user)

except IntegrityError as e:
self.db.rollback()
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=extract_constraint_info(e),
)
except Exception as e:
self.db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"OAuth registration failed: {str(e)}")

def _generate_referral_code(self, length=5) -> str:
while True:
code = "".join(random.choices(string.ascii_uppercase + string.digits, k=length))
# check DB for uniqueness
if not self.repository.get_by_referral_code(code):
return code

Expand All @@ -282,28 +375,24 @@ def register(self, data: UserRegister, ip_address: str, user_agent: str) -> User
5. Commit all or rollback.
"""

# Email uniqueness
if self.repository.exists_by_email(str(data.email)):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email already registered",
)

# Username uniqueness
if self.repository.exists_by_username(data.username):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Username already taken",
)

# Accepting terms, search by the version
terms = self.db.scalar(select(TermsTableModel).where(TermsTableModel.version == data.accepted_terms_version))
terms = self.terms_repository.get_latest()
if not terms:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid terms version: {data.accepted_terms_version}"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No active terms found in the system."
)

# Optionally validating provided referral code
referrer_id = None
if data.referral_code:
referrer = self.repository.get_by_referral_code(data.referral_code)
Expand All @@ -314,12 +403,10 @@ def register(self, data: UserRegister, ip_address: str, user_agent: str) -> User
new_referral_code = self._generate_referral_code()
hashed_pw = hash_password(data.password)

# atomic transaction
try:
# A. creating User
new_user = User(
username=data.username,
email=data.email, # pydantic validator make it lower().
email=data.email,
password=hashed_pw,
first_name=data.first_name,
last_name=data.last_name,
Expand All @@ -336,15 +423,12 @@ def register(self, data: UserRegister, ip_address: str, user_agent: str) -> User
self.db.add(new_user)
self.db.flush()

# B. Creating record of accepted terms
acceptance = UserTermsAcceptanceTableModel(user_id=new_user.id, terms_id=terms.id)
self.db.add(acceptance)

# C. Saving connection log (ip, user)
connection = ConnectionTableModel(user_id=new_user.id, ip_address=ip_address)
self.db.add(connection)

# commit
self.db.commit()
self.db.refresh(new_user)

Expand Down
8 changes: 8 additions & 0 deletions app/literals/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,11 @@ class Role(str, Enum):
Role.RECRUITER: 2,
Role.BASIC: 3,
}


class OnboardingStep(str, Enum):
NOT_STARTED = "not_started"
PERSONAL_INFO = "personal_info"
ACADEMIC_INFO = "academic_info"
PREFERENCES = "preferences"
COMPLETED = "completed"
Loading