diff --git a/Cargo.lock b/Cargo.lock index 91475b9..a40a761 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,6 +77,7 @@ dependencies = [ "shuttle-runtime", "sqlx", "tokio", + "tracing", "utoipa", "utoipa-axum", "utoipa-swagger-ui", diff --git a/Cargo.toml b/Cargo.toml index df06ded..79fd70e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,8 @@ edition = "2021" [features] default = ["shuttle-axum", "shuttle-runtime"] -local = ["dotenv", "tokio"] +local = ["dotenv", "tokio", "dummy-auth"] +dummy-auth = [] [dependencies] axum = { version = "0.8.1", features = ["macros", "multipart"] } @@ -22,5 +23,6 @@ tokio = {version = "1.44.2", optional = true, features = ["full"] } utoipa = { version = "5.3.1", features = ["axum_extras", "chrono", "uuid", "decimal"] } utoipa-axum = "0.2.0" utoipa-swagger-ui = { version = "9.0.1", features = ["axum"] } +tracing = "0.1.41" uuid = { version = "1.14.0", features = ["v4", "fast-rng", "serde"] } diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..ae33b6e --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,75 @@ +use axum::{ + extract::FromRequestParts, + http::{header, request::Parts, StatusCode}, + response::{IntoResponse, Response}, + Json, +}; +use serde::Serialize; +use tracing::error; + +/// Represents an authenticated user, extracted from the Authentication header. +#[derive(Debug, Clone)] +pub struct AuthUser { + /// The ID of the authenticated user. + pub user_id: String, // Or Uuid if you prefer and add the uuid crate +} + +/// Custom rejection type for authentication errors. +#[derive(Debug, Serialize)] +pub struct AuthError { + message: String, +} + +impl IntoResponse for AuthError { + fn into_response(self) -> Response { + (StatusCode::UNAUTHORIZED, Json(self)).into_response() + } +} + +impl FromRequestParts for AuthUser +where + S: Send + Sync, +{ + type Rejection = AuthError; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + // Extract the Authentication header + let auth_header = parts + .headers + .get(header::AUTHORIZATION) // Using standard AUTHORIZATION header for convention + .and_then(|value| value.to_str().ok()); + + match auth_header { + Some(header_value) => { + // Check if the header starts with "Bearer " + if let Some(token) = header_value.strip_prefix("Bearer ") { + let token = token.trim(); + if token.is_empty() { + error!("Bearer token is empty"); + Err(AuthError { + message: "Invalid Bearer token".to_string(), + }) + } else { + // In a real scenario, you'd validate this token. + // For this dummy implementation, the token itself is the user_id. + Ok(AuthUser { + user_id: token.to_string(), + }) + } + } else { + error!("Invalid authentication header format. Expected Bearer token."); + Err(AuthError { + message: "Invalid authentication header format. Expected Bearer token." + .to_string(), + }) + } + } + None => { + error!("Authorization header missing"); + Err(AuthError { + message: "Authorization header required".to_string(), + }) + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index daf5a36..eae9670 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,9 @@ pub mod types; pub mod user; pub mod utils; +#[cfg(feature = "dummy-auth")] +pub mod auth; + pub mod tags { pub static USER: &str = "User"; pub static TAGS: &str = "Tags"; diff --git a/src/user/handlers.rs b/src/user/handlers.rs index 91b0609..f55ed88 100644 --- a/src/user/handlers.rs +++ b/src/user/handlers.rs @@ -1,9 +1,12 @@ +#[cfg(feature = "dummy-auth")] +use crate::auth::AuthUser; use axum::{ extract::{Path, Query, State}, http::StatusCode, Json, }; -use firebase_auth::FirebaseUser; +#[cfg(not(feature = "dummy-auth"))] +use firebase_auth::FirebaseUser as AuthUser; use serde::{Deserialize, Serialize}; use sqlx::types::chrono::{DateTime, Utc}; use utoipa::{IntoParams, ToSchema}; @@ -123,7 +126,7 @@ pub struct NemesisTag { )] pub async fn get_current_user( State(state): State, - user: FirebaseUser, + user: AuthUser, ) -> Result, Error> { let user_id = user.user_id; @@ -177,7 +180,7 @@ pub async fn get_current_user( pub async fn get_user( State(state): State, Path(user_id): Path, - _user: FirebaseUser, // Ensure authenticated + _user: AuthUser, // Ensure authenticated ) -> Result, Error> { let maybe_user = utils::get_user_by_id(&state.pool, &user_id).await?; @@ -225,7 +228,7 @@ pub async fn get_user( )] pub async fn update_current_user( State(state): State, - user: FirebaseUser, + user: AuthUser, Json(update): Json, ) -> Result, Error> { let user_id = user.user_id; @@ -270,7 +273,7 @@ pub async fn update_current_user( )] pub async fn get_all_tags( State(state): State, - _user: FirebaseUser, // Ensure authenticated + _user: AuthUser, // Ensure authenticated ) -> Result>, Error> { let tags = utils::get_all_tags(&state.pool).await?; Ok(Json(tags)) @@ -309,7 +312,7 @@ pub async fn get_all_tags( )] pub async fn get_nemesis_tags( State(state): State, - _user: FirebaseUser, // Ensure authenticated + _user: AuthUser, // Ensure authenticated Path(tag_name): Path, Query(pagination): Query, ) -> Result>, Error> { @@ -355,7 +358,7 @@ pub async fn get_nemesis_tags( pub async fn get_user_tags( State(state): State, Path(user_id): Path, - _user: FirebaseUser, // Ensure authenticated + _user: AuthUser, // Ensure authenticated ) -> Result>, Error> { // Verify user exists let maybe_user = utils::get_user_by_id(&state.pool, &user_id).await?; @@ -395,7 +398,7 @@ pub async fn get_user_tags( )] pub async fn get_current_user_tags( State(state): State, - user: FirebaseUser, // Ensure authenticated + user: AuthUser, // Ensure authenticated ) -> Result>, Error> { // Verify user exists let maybe_user = utils::get_user_by_id(&state.pool, &user.user_id).await?; @@ -443,7 +446,7 @@ pub async fn get_current_user_tags( )] pub async fn add_current_user_tag( State(state): State, - user: FirebaseUser, + user: AuthUser, Json(request): Json, ) -> Result, Error> { let user_id = user.user_id; @@ -491,7 +494,7 @@ pub async fn add_current_user_tag( )] pub async fn remove_current_user_tags( State(state): State, - user: FirebaseUser, + user: AuthUser, Json(names): Json>, ) -> Result { let user_id = user.user_id; @@ -531,7 +534,7 @@ pub async fn remove_current_user_tags( )] pub async fn get_potential_nemeses( State(state): State, - user: FirebaseUser, + user: AuthUser, Query(pagination): Query, ) -> Result>, Error> { let user_id = user.user_id; @@ -590,7 +593,7 @@ pub async fn get_potential_nemeses( )] pub async fn like_user( State(state): State, - user: FirebaseUser, + user: AuthUser, Path(target_user_id): Path, ) -> Result, Error> { let user_id = user.user_id; @@ -649,7 +652,7 @@ pub async fn like_user( )] pub async fn dislike_user( State(state): State, - user: FirebaseUser, + user: AuthUser, Path(target_user_id): Path, ) -> Result, Error> { let user_id = user.user_id; @@ -712,7 +715,7 @@ pub async fn dislike_user( )] pub async fn dislike_user_with_tags( State(state): State, - user: FirebaseUser, + user: AuthUser, Path(target_user_id): Path, Json(request): Json, ) -> Result>, Error> { @@ -779,7 +782,7 @@ pub async fn dislike_user_with_tags( )] pub async fn get_liked_users( State(state): State, - user: FirebaseUser, + user: AuthUser, Query(pagination): Query, ) -> Result>, Error> { let user_id = user.user_id; @@ -828,7 +831,7 @@ pub async fn get_liked_users( )] pub async fn get_disliked_users( State(state): State, - user: FirebaseUser, + user: AuthUser, Query(pagination): Query, ) -> Result>, Error> { let user_id = user.user_id;