diff --git a/crates/core/src/route_handler.rs b/crates/core/src/route_handler.rs index db6c91b..e079bb7 100644 --- a/crates/core/src/route_handler.rs +++ b/crates/core/src/route_handler.rs @@ -56,6 +56,20 @@ pub enum HandlerAction { NeedsBody(PendingRequest), } +impl HandlerAction { + /// Returns a mutable reference to the response headers, if available. + /// + /// Returns `Some(&mut HeaderMap)` for `Response` and `Forward` variants, + /// `None` for `NeedsBody` (which has no response yet). + pub fn response_headers_mut(&mut self) -> Option<&mut HeaderMap> { + match self { + HandlerAction::Response(result) => Some(&mut result.headers), + HandlerAction::Forward(fwd) => Some(&mut fwd.headers), + HandlerAction::NeedsBody(_) => None, + } + } +} + /// A presigned URL request for the runtime to execute. pub struct ForwardRequest { /// HTTP method for the backend request. @@ -271,6 +285,53 @@ pub trait RouteHandler: MaybeSend + MaybeSync { mod tests { use super::*; + #[test] + fn test_response_headers_mut_on_response() { + let mut action = HandlerAction::Response(ProxyResult { + status: 200, + headers: HeaderMap::new(), + body: ProxyResponseBody::Empty, + }); + let headers = action.response_headers_mut().unwrap(); + headers.insert("x-custom", "value".parse().unwrap()); + if let HandlerAction::Response(result) = &action { + assert_eq!(result.headers.get("x-custom").unwrap(), "value"); + } + } + + #[test] + fn test_response_headers_mut_on_forward() { + let mut action = HandlerAction::Forward(ForwardRequest { + method: Method::GET, + url: "https://example.com".parse().unwrap(), + headers: HeaderMap::new(), + request_id: String::new(), + }); + assert!(action.response_headers_mut().is_some()); + } + + #[test] + fn test_response_headers_mut_on_needs_body() { + use crate::types::{BucketConfig, S3Operation}; + let mut action = HandlerAction::NeedsBody(PendingRequest { + operation: S3Operation::CreateMultipartUpload { + bucket: "b".into(), + key: "k".into(), + }, + bucket_config: BucketConfig { + name: String::new(), + backend_type: "s3".into(), + backend_prefix: None, + anonymous_access: false, + backend_options: Default::default(), + allowed_roles: Vec::new(), + }, + original_headers: HeaderMap::new(), + request_id: String::new(), + }); + assert!(action.response_headers_mut().is_none()); + } + #[test] fn test_blocks_hop_by_hop_headers() { let mut headers = http::HeaderMap::new();