-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.py
More file actions
108 lines (85 loc) · 2.97 KB
/
auth.py
File metadata and controls
108 lines (85 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from ast import mod
from datetime import UTC, datetime, timedelta
from unittest import result
import jwt
from fastapi.security import OAuth2PasswordBearer
from pwdlib import PasswordHash
import hashlib
import secrets
from config import settings
from typing import Annotated
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from sqlalchemy import select
import models
# Password hasher built from pwdlib's "recommended" configuration; shared by
# hash_password and verify_password so hashing and checking always agree.
password_hash: PasswordHash = PasswordHash.recommended()
# FastAPI security dependency that get_current_user relies on to obtain the
# bearer token. tokenUrl is the path advertised to clients as the endpoint that
# issues tokens — NOTE(review): presumably served by the users router; confirm.
oauth2_scheme: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="/api/users/token")
def hash_password(password: str) -> str:
    """Hash *password* with the module-wide pwdlib hasher.

    The returned string is the form that :func:`verify_password` accepts as
    its ``hashed_password`` argument.
    """
    hashed = password_hash.hash(password)
    return hashed
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Report whether *plain_password* matches a stored *hashed_password*.

    The check is delegated to the same pwdlib hasher that
    :func:`hash_password` uses, and its boolean verdict is returned unchanged.
    """
    is_match = password_hash.verify(plain_password, hashed_password)
    return is_match
def generate_reset_token() -> str:
    """Return a new random, URL-safe reset token.

    The token is built from 32 bytes drawn from the ``secrets`` module and
    base64url-encoded, so it can be embedded directly in a URL.
    """
    random_byte_count = 32
    token = secrets.token_urlsafe(random_byte_count)
    return token
def hash_reset_token(token: str) -> str:
    """Return the hexadecimal SHA-256 digest of *token* encoded as UTF-8.

    The digest is fast and unsalted. That is acceptable for high-entropy
    random values such as those from :func:`generate_reset_token`. It would
    not be appropriate for low-entropy secrets like user passwords, which go
    through :func:`hash_password` instead.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is shallow-copied, so the
            caller's object is not mutated. An ``exp`` key already present in
            it is overwritten.
        expires_delta: Token lifetime measured from now. ``None`` (the default)
            means ``settings.access_token_expire_minutes``. Any other value,
            including ``timedelta(0)`` or a negative delta, is used as given.

    Returns:
        The encoded JWT, signed with ``settings.secret_key`` using
        ``settings.algorithm``.
    """
    if expires_delta is None:
        # Compare against None rather than testing truthiness: timedelta(0)
        # is falsy, so `if expires_delta:` would silently swap an explicit
        # zero lifetime for the configured default.
        expires_delta = timedelta(minutes=settings.access_token_expire_minutes)

    to_encode = data.copy()
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    return jwt.encode(
        to_encode,
        settings.secret_key.get_secret_value(),
        algorithm=settings.algorithm,
    )
def verify_access_token(token: str) -> str | None:
    """Decode *token* and return its ``sub`` claim, or ``None`` if it is invalid.

    The token must verify against ``settings.secret_key`` using
    ``settings.algorithm`` and must carry both ``exp`` and ``sub`` claims.
    Any ``jwt.InvalidTokenError`` raised while decoding is turned into a
    ``None`` return instead of being propagated. Per the PyJWT docs, this
    includes signature, expiry, and missing-required-claim failures.
    """
    signing_key = settings.secret_key.get_secret_value()
    try:
        claims = jwt.decode(
            token,
            signing_key,
            algorithms=[settings.algorithm],
            options={"require": ["exp", "sub"]},
        )
    except jwt.InvalidTokenError:
        # All decode/validation failures collapse to one "not authenticated"
        # outcome so callers only need a single None check.
        return None
    return claims.get("sub")
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> models.User:
    """Resolve the request's bearer token to its ``models.User`` row.

    FastAPI supplies ``token`` via ``oauth2_scheme`` and ``db`` via ``get_db``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token fails verification, when its subject is not an integer id,
            or when no user with that id exists.
    """

    def unauthorized(detail: str) -> HTTPException:
        # All three failure paths share the status code and challenge header;
        # only the detail text differs. A fresh exception is built per raise.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    subject = verify_access_token(token)
    if subject is None:
        raise unauthorized("Invalid or expired token")

    # verify_access_token yields the subject as text; the query below compares
    # against models.User.id as an int, so a non-numeric subject is rejected.
    try:
        user_id_int = int(subject)
    except (TypeError, ValueError):
        raise unauthorized("Invalid or expired token")

    query_result = await db.execute(
        select(models.User).where(models.User.id == user_id_int)
    )
    user = query_result.scalars().first()
    if not user:
        raise unauthorized("User not found")
    return user
# Annotated alias for FastAPI endpoint parameters: declaring a parameter as
# CurrentUser routes it through the get_current_user dependency, so the
# endpoint receives the authenticated models.User (or the 401 it raises).
CurrentUser = Annotated[models.User, Depends(get_current_user)]