Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ edition = "2021"

[features]
default = ["shuttle-axum", "shuttle-runtime"]
local = ["dotenv", "tokio"]
local = ["dotenv", "tokio", "dummy-auth"]
dummy-auth = []

[dependencies]
axum = { version = "0.8.1", features = ["macros", "multipart"] }
Expand All @@ -22,5 +23,6 @@ tokio = {version = "1.44.2", optional = true, features = ["full"] }
utoipa = { version = "5.3.1", features = ["axum_extras", "chrono", "uuid", "decimal"] }
utoipa-axum = "0.2.0"
utoipa-swagger-ui = { version = "9.0.1", features = ["axum"] }
tracing = "0.1.41"
uuid = { version = "1.14.0", features = ["v4", "fast-rng", "serde"] }

75 changes: 75 additions & 0 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use axum::{
extract::FromRequestParts,
http::{header, request::Parts, StatusCode},
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use tracing::error;

/// Represents an authenticated user, extracted from the Authentication header.
#[derive(Debug, Clone)]
pub struct AuthUser {
/// The ID of the authenticated user.
pub user_id: String, // Or Uuid if you prefer and add the uuid crate
}

/// Custom rejection type for authentication errors.
#[derive(Debug, Serialize)]
pub struct AuthError {
message: String,
}

impl IntoResponse for AuthError {
fn into_response(self) -> Response {
(StatusCode::UNAUTHORIZED, Json(self)).into_response()
}
}

impl<S> FromRequestParts<S> for AuthUser
where
S: Send + Sync,
{
type Rejection = AuthError;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the Authentication header
let auth_header = parts
.headers
.get(header::AUTHORIZATION) // Using standard AUTHORIZATION header for convention
.and_then(|value| value.to_str().ok());

match auth_header {
Some(header_value) => {
// Check if the header starts with "Bearer "
if let Some(token) = header_value.strip_prefix("Bearer ") {
let token = token.trim();
if token.is_empty() {
error!("Bearer token is empty");
Err(AuthError {
message: "Invalid Bearer token".to_string(),
})
} else {
// In a real scenario, you'd validate this token.
// For this dummy implementation, the token itself is the user_id.
Ok(AuthUser {
user_id: token.to_string(),
})
}
} else {
error!("Invalid authentication header format. Expected Bearer token.");
Err(AuthError {
message: "Invalid authentication header format. Expected Bearer token."
.to_string(),
})
}
}
None => {
error!("Authorization header missing");
Err(AuthError {
message: "Authorization header required".to_string(),
})
}
}
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ pub mod types;
pub mod user;
pub mod utils;

#[cfg(feature = "dummy-auth")]
pub mod auth;

pub mod tags {
pub static USER: &str = "User";
pub static TAGS: &str = "Tags";
Expand Down
35 changes: 19 additions & 16 deletions src/user/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#[cfg(feature = "dummy-auth")]
use crate::auth::AuthUser;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
Json,
};
use firebase_auth::FirebaseUser;
#[cfg(not(feature = "dummy-auth"))]
use firebase_auth::FirebaseUser as AuthUser;
use serde::{Deserialize, Serialize};
use sqlx::types::chrono::{DateTime, Utc};
use utoipa::{IntoParams, ToSchema};
Expand Down Expand Up @@ -123,7 +126,7 @@ pub struct NemesisTag {
)]
pub async fn get_current_user(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
) -> Result<Json<User>, Error> {
let user_id = user.user_id;

Expand Down Expand Up @@ -177,7 +180,7 @@ pub async fn get_current_user(
pub async fn get_user(
State(state): State<ArchenemyState>,
Path(user_id): Path<String>,
_user: FirebaseUser, // Ensure authenticated
_user: AuthUser, // Ensure authenticated
) -> Result<Json<User>, Error> {
let maybe_user = utils::get_user_by_id(&state.pool, &user_id).await?;

Expand Down Expand Up @@ -225,7 +228,7 @@ pub async fn get_user(
)]
pub async fn update_current_user(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Json(update): Json<UpdateUserRequest>,
) -> Result<Json<User>, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -270,7 +273,7 @@ pub async fn update_current_user(
)]
pub async fn get_all_tags(
State(state): State<ArchenemyState>,
_user: FirebaseUser, // Ensure authenticated
_user: AuthUser, // Ensure authenticated
) -> Result<Json<Vec<TagCount>>, Error> {
let tags = utils::get_all_tags(&state.pool).await?;
Ok(Json(tags))
Expand Down Expand Up @@ -309,7 +312,7 @@ pub async fn get_all_tags(
)]
pub async fn get_nemesis_tags(
State(state): State<ArchenemyState>,
_user: FirebaseUser, // Ensure authenticated
_user: AuthUser, // Ensure authenticated
Path(tag_name): Path<String>,
Query(pagination): Query<PaginationParams>,
) -> Result<Json<Vec<NemesisTag>>, Error> {
Expand Down Expand Up @@ -355,7 +358,7 @@ pub async fn get_nemesis_tags(
pub async fn get_user_tags(
State(state): State<ArchenemyState>,
Path(user_id): Path<String>,
_user: FirebaseUser, // Ensure authenticated
_user: AuthUser, // Ensure authenticated
) -> Result<Json<Vec<UserTag>>, Error> {
// Verify user exists
let maybe_user = utils::get_user_by_id(&state.pool, &user_id).await?;
Expand Down Expand Up @@ -395,7 +398,7 @@ pub async fn get_user_tags(
)]
pub async fn get_current_user_tags(
State(state): State<ArchenemyState>,
user: FirebaseUser, // Ensure authenticated
user: AuthUser, // Ensure authenticated
) -> Result<Json<Vec<UserTag>>, Error> {
// Verify user exists
let maybe_user = utils::get_user_by_id(&state.pool, &user.user_id).await?;
Expand Down Expand Up @@ -443,7 +446,7 @@ pub async fn get_current_user_tags(
)]
pub async fn add_current_user_tag(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Json(request): Json<AddTagRequest>,
) -> Result<Json<UserTag>, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -491,7 +494,7 @@ pub async fn add_current_user_tag(
)]
pub async fn remove_current_user_tags(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Json(names): Json<Vec<String>>,
) -> Result<StatusCode, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -531,7 +534,7 @@ pub async fn remove_current_user_tags(
)]
pub async fn get_potential_nemeses(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Query(pagination): Query<PaginationParams>,
) -> Result<Json<Vec<UserWithTags>>, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -590,7 +593,7 @@ pub async fn get_potential_nemeses(
)]
pub async fn like_user(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Path(target_user_id): Path<String>,
) -> Result<Json<UserLike>, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -649,7 +652,7 @@ pub async fn like_user(
)]
pub async fn dislike_user(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Path(target_user_id): Path<String>,
) -> Result<Json<UserDislike>, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -712,7 +715,7 @@ pub async fn dislike_user(
)]
pub async fn dislike_user_with_tags(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Path(target_user_id): Path<String>,
Json(request): Json<AddTagsRequest>,
) -> Result<Json<Vec<UserDislikeTag>>, Error> {
Expand Down Expand Up @@ -779,7 +782,7 @@ pub async fn dislike_user_with_tags(
)]
pub async fn get_liked_users(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Query(pagination): Query<PaginationParams>,
) -> Result<Json<Vec<UserWithLikedAt>>, Error> {
let user_id = user.user_id;
Expand Down Expand Up @@ -828,7 +831,7 @@ pub async fn get_liked_users(
)]
pub async fn get_disliked_users(
State(state): State<ArchenemyState>,
user: FirebaseUser,
user: AuthUser,
Query(pagination): Query<PaginationParams>,
) -> Result<Json<Vec<UserWithDislikedAt>>, Error> {
let user_id = user.user_id;
Expand Down
Loading