Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/openshell-sandbox/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ webpki-roots = { workspace = true }
# HTTP
bytes = { workspace = true }

# Encoding
base64 = { workspace = true }

# IP network / CIDR parsing
ipnet = "2"

Expand Down
344 changes: 343 additions & 1 deletion crates/openshell-sandbox/src/secrets.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use base64::Engine as _;
use std::collections::HashMap;

const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:";
Expand Down Expand Up @@ -35,17 +36,59 @@ impl SecretResolver {
}

pub(crate) fn rewrite_header_value(&self, value: &str) -> Option<String> {
// Direct placeholder match: `x-api-key: openshell:resolve:env:KEY`
if let Some(secret) = self.resolve_placeholder(value.trim()) {
return Some(secret.to_string());
}

let trimmed = value.trim();

// Basic auth decoding: `Basic <base64>` where the decoded content
// contains a placeholder (e.g. `user:openshell:resolve:env:PASS`).
// Decode, rewrite placeholders in the decoded string, re-encode.
if let Some(encoded) = trimmed.strip_prefix("Basic ").map(str::trim) {
if let Some(rewritten) = self.rewrite_basic_auth_token(encoded) {
return Some(format!("Basic {rewritten}"));
}
}

// Prefixed placeholder: `Bearer openshell:resolve:env:KEY`
let split_at = trimmed.find(char::is_whitespace)?;
let prefix = &trimmed[..split_at];
let candidate = trimmed[split_at..].trim();
let secret = self.resolve_placeholder(candidate)?;
Some(format!("{prefix} {secret}"))
}

/// Decode a Base64-encoded Basic auth token, resolve any placeholders in
/// the decoded `username:password` string, and re-encode.
///
/// Returns `None` if decoding fails or no placeholders are found.
fn rewrite_basic_auth_token(&self, encoded: &str) -> Option<String> {
let b64 = base64::engine::general_purpose::STANDARD;
let decoded_bytes = b64.decode(encoded.trim()).ok()?;
let decoded = std::str::from_utf8(&decoded_bytes).ok()?;

// Check if the decoded string contains any placeholder
if !decoded.contains(PLACEHOLDER_PREFIX) {
return None;
}

// Rewrite all placeholder occurrences in the decoded string
let mut rewritten = decoded.to_string();
for (placeholder, secret) in &self.by_placeholder {
if rewritten.contains(placeholder.as_str()) {
rewritten = rewritten.replace(placeholder.as_str(), secret);
}
}

// Only return if we actually changed something
if rewritten == decoded {
return None;
}

Some(b64.encode(rewritten.as_bytes()))
}
}

pub(crate) fn placeholder_for_env_key(key: &str) -> String {
Expand All @@ -68,7 +111,7 @@ pub(crate) fn rewrite_http_header_block(raw: &[u8], resolver: Option<&SecretReso
};

let mut output = Vec::with_capacity(raw.len());
output.extend_from_slice(request_line.as_bytes());
output.extend_from_slice(rewrite_request_line(request_line, resolver).as_bytes());
output.extend_from_slice(b"\r\n");

for line in lines {
Expand Down Expand Up @@ -96,6 +139,117 @@ pub(crate) fn rewrite_header_line(line: &str, resolver: &SecretResolver) -> Stri
}
}

/// Rewrite credential placeholders in the request line's URL query parameters.
///
/// Given a request line like `GET /api?key=openshell:resolve:env:API_KEY HTTP/1.1`,
/// resolves placeholders in query parameter values and percent-encodes the
/// resolved secret. Handles URLs with multiple query parameters and preserves
/// parameters that don't contain placeholders.
fn rewrite_request_line(line: &str, resolver: &SecretResolver) -> String {
// Request line format: METHOD SP REQUEST-URI SP HTTP-VERSION
let mut parts = line.splitn(3, ' ');
let method = match parts.next() {
Some(m) => m,
None => return line.to_string(),
};
let uri = match parts.next() {
Some(u) => u,
None => return line.to_string(),
};
let version = match parts.next() {
Some(v) => v,
None => return line.to_string(),
};

// Only rewrite if the URI contains a placeholder
if !uri.contains(PLACEHOLDER_PREFIX) {
return line.to_string();
}

let rewritten_uri = rewrite_uri_query_params(uri, resolver);
format!("{method} {rewritten_uri} {version}")
}

/// Rewrite placeholders in query parameter values of a URI.
///
/// Splits the URI at `?`, parses key=value pairs from the query string,
/// resolves any placeholder values, and percent-encodes the resolved secrets.
/// Parameters without placeholders are preserved verbatim.
fn rewrite_uri_query_params(uri: &str, resolver: &SecretResolver) -> String {
let Some((path, query)) = uri.split_once('?') else {
return uri.to_string();
};

let mut rewritten_params = Vec::new();
for param in query.split('&') {
if let Some((key, value)) = param.split_once('=') {
// Percent-decode the value before checking for placeholder
let decoded_value = percent_decode(value);
if let Some(secret) = resolver.resolve_placeholder(&decoded_value) {
rewritten_params.push(format!("{key}={}", percent_encode(secret)));
} else {
rewritten_params.push(param.to_string());
}
} else {
rewritten_params.push(param.to_string());
}
}

format!("{path}?{}", rewritten_params.join("&"))
}

/// Percent-encode a string for safe use in URL query parameter values.
///
/// Encodes all characters except unreserved characters (RFC 3986 Section 2.3):
/// ALPHA / DIGIT / "-" / "." / "_" / "~"
fn percent_encode(input: &str) -> String {
let mut encoded = String::with_capacity(input.len());
for byte in input.bytes() {
match byte {
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
encoded.push(byte as char);
}
_ => {
encoded.push_str(&format!("%{byte:02X}"));
}
}
}
encoded
}

/// Percent-decode a URL-encoded string.
fn percent_decode(input: &str) -> String {
let mut decoded = Vec::with_capacity(input.len());
let mut bytes = input.bytes();
while let Some(b) = bytes.next() {
if b == b'%' {
let hi = bytes.next();
let lo = bytes.next();
if let (Some(h), Some(l)) = (hi, lo) {
let hex = [h, l];
if let Ok(s) = std::str::from_utf8(&hex) {
if let Ok(val) = u8::from_str_radix(s, 16) {
decoded.push(val);
continue;
}
}
// Invalid percent encoding — preserve verbatim
decoded.push(b'%');
decoded.push(h);
decoded.push(l);
} else {
decoded.push(b'%');
if let Some(h) = hi {
decoded.push(h);
}
}
} else {
decoded.push(b);
}
}
String::from_utf8_lossy(&decoded).into_owned()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -259,4 +413,192 @@ mod tests {
let rewritten = rewrite_http_header_block(raw, None);
assert_eq!(raw.as_slice(), rewritten.as_slice());
}

// --- Query parameter rewriting tests ---

#[test]
fn rewrites_query_param_placeholder_in_request_line() {
let (child_env, resolver) = SecretResolver::from_provider_env(
[("YOUTUBE_API_KEY".to_string(), "AIzaSy-secret".to_string())]
.into_iter()
.collect(),
);
let placeholder = child_env.get("YOUTUBE_API_KEY").unwrap();

let raw = format!(
"GET /youtube/v3/search?part=snippet&key={placeholder} HTTP/1.1\r\n\
Host: www.googleapis.com\r\n\r\n"
);
let rewritten = rewrite_http_header_block(raw.as_bytes(), resolver.as_ref());
let rewritten = String::from_utf8(rewritten).expect("utf8");

assert!(
rewritten.starts_with("GET /youtube/v3/search?part=snippet&key=AIzaSy-secret HTTP/1.1\r\n"),
"Expected query param rewritten, got: {rewritten}"
);
assert!(!rewritten.contains("openshell:resolve:env:"));
}

#[test]
fn rewrites_query_param_with_special_chars_percent_encoded() {
let (child_env, resolver) = SecretResolver::from_provider_env(
[("API_KEY".to_string(), "key with spaces&symbols=yes".to_string())]
.into_iter()
.collect(),
);
let placeholder = child_env.get("API_KEY").unwrap();

let raw = format!(
"GET /api?token={placeholder} HTTP/1.1\r\nHost: x\r\n\r\n"
);
let rewritten = rewrite_http_header_block(raw.as_bytes(), resolver.as_ref());
let rewritten = String::from_utf8(rewritten).expect("utf8");

// Secret should be percent-encoded
assert!(
rewritten.contains("token=key%20with%20spaces%26symbols%3Dyes"),
"Expected percent-encoded secret, got: {rewritten}"
);
}

#[test]
fn rewrites_query_param_only_placeholder_first_param() {
let (child_env, resolver) = SecretResolver::from_provider_env(
[("KEY".to_string(), "secret123".to_string())]
.into_iter()
.collect(),
);
let placeholder = child_env.get("KEY").unwrap();

let raw = format!(
"GET /api?key={placeholder}&format=json HTTP/1.1\r\nHost: x\r\n\r\n"
);
let rewritten = rewrite_http_header_block(raw.as_bytes(), resolver.as_ref());
let rewritten = String::from_utf8(rewritten).expect("utf8");

assert!(
rewritten.starts_with("GET /api?key=secret123&format=json HTTP/1.1"),
"Expected first param rewritten, got: {rewritten}"
);
}

#[test]
fn no_query_param_rewrite_without_placeholder() {
let (_, resolver) = SecretResolver::from_provider_env(
[("KEY".to_string(), "secret".to_string())]
.into_iter()
.collect(),
);

let raw = b"GET /api?key=normalvalue HTTP/1.1\r\nHost: x\r\n\r\n";
let rewritten = rewrite_http_header_block(raw, resolver.as_ref());
assert_eq!(raw.as_slice(), rewritten.as_slice());
}

// --- Basic Authorization header encoding tests ---

#[test]
fn rewrites_basic_auth_placeholder_in_decoded_token() {
use base64::Engine as _;
let b64 = base64::engine::general_purpose::STANDARD;

let (child_env, resolver) = SecretResolver::from_provider_env(
[("DB_PASSWORD".to_string(), "s3cret!".to_string())]
.into_iter()
.collect(),
);
let resolver = resolver.expect("resolver");
let placeholder = child_env.get("DB_PASSWORD").unwrap();

// Simulate: agent constructs Basic auth with placeholder password
let credentials = format!("admin:{placeholder}");
let encoded = b64.encode(credentials.as_bytes());

let header_line = format!("Authorization: Basic {encoded}");
let rewritten = rewrite_header_line(&header_line, &resolver);

// Decode the rewritten token to verify
let rewritten_token = rewritten.strip_prefix("Authorization: Basic ").unwrap();
let decoded = b64.decode(rewritten_token).unwrap();
let decoded_str = std::str::from_utf8(&decoded).unwrap();

assert_eq!(decoded_str, "admin:s3cret!");
assert!(!rewritten.contains("openshell:resolve:env:"));
}

#[test]
fn basic_auth_without_placeholder_unchanged() {
let (_, resolver) = SecretResolver::from_provider_env(
[("KEY".to_string(), "secret".to_string())]
.into_iter()
.collect(),
);
let resolver = resolver.expect("resolver");

// Normal Basic auth token without any placeholder
use base64::Engine as _;
let b64 = base64::engine::general_purpose::STANDARD;
let encoded = b64.encode(b"user:password");
let header_line = format!("Authorization: Basic {encoded}");

let rewritten = rewrite_header_line(&header_line, &resolver);
assert_eq!(rewritten, header_line, "Should not modify non-placeholder Basic auth");
}

#[test]
fn basic_auth_full_round_trip_header_block() {
use base64::Engine as _;
let b64 = base64::engine::general_purpose::STANDARD;

let (child_env, resolver) = SecretResolver::from_provider_env(
[("REGISTRY_PASS".to_string(), "hunter2".to_string())]
.into_iter()
.collect(),
);
let placeholder = child_env.get("REGISTRY_PASS").unwrap();
let credentials = format!("deploy:{placeholder}");
let encoded = b64.encode(credentials.as_bytes());

let raw = format!(
"GET /v2/_catalog HTTP/1.1\r\n\
Host: registry.example.com\r\n\
Authorization: Basic {encoded}\r\n\
Accept: application/json\r\n\r\n"
);

let rewritten = rewrite_http_header_block(raw.as_bytes(), resolver.as_ref());
let rewritten = String::from_utf8(rewritten).expect("utf8");

// Extract and decode the rewritten Basic token
let auth_line = rewritten.lines().find(|l| l.starts_with("Authorization:")).unwrap();
let token = auth_line.strip_prefix("Authorization: Basic ").unwrap();
let decoded = b64.decode(token).unwrap();
assert_eq!(std::str::from_utf8(&decoded).unwrap(), "deploy:hunter2");

// Other headers preserved
assert!(rewritten.contains("Host: registry.example.com\r\n"));
assert!(rewritten.contains("Accept: application/json\r\n"));
assert!(!rewritten.contains("openshell:resolve:env:"));
}

// --- Percent encoding tests ---

#[test]
fn percent_encode_preserves_unreserved() {
assert_eq!(percent_encode("abc123-._~"), "abc123-._~");
}

#[test]
fn percent_encode_encodes_special_chars() {
assert_eq!(percent_encode("a b"), "a%20b");
assert_eq!(percent_encode("key=val&x"), "key%3Dval%26x");
}

#[test]
fn percent_decode_round_trips() {
let original = "hello world & more=stuff";
let encoded = percent_encode(original);
let decoded = percent_decode(&encoded);
assert_eq!(decoded, original);
}
}
Loading