Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 87 additions & 46 deletions crates/tauri/src/protocol/tauri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,34 +42,48 @@ pub fn get<R: Runtime>(
};

let window_origin = window_origin.to_string();
let web_resource_request_handler = web_resource_request_handler.map(Arc::new);

#[cfg(all(dev, mobile))]
let response_cache = Arc::new(Mutex::new(HashMap::new()));
let response_cache = Arc::new(Mutex::new(HashMap::<String, CachedResponse>::new()));

Box::new(move |_, request, responder| {
match get_response(
request,
&manager,
&window_origin,
web_resource_request_handler.as_deref(),
#[cfg(all(dev, mobile))]
(&url, &response_cache),
) {
Ok(response) => responder.respond(response),
Err(e) => responder.respond(
HttpResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
.header("Access-Control-Allow-Origin", &window_origin)
.body(e.to_string().into_bytes())
.unwrap(),
),
}
let manager = manager.clone();
let window_origin = window_origin.clone();
let web_resource_request_handler = web_resource_request_handler.clone();

#[cfg(all(dev, mobile))]
let url = url.clone();
#[cfg(all(dev, mobile))]
let response_cache = response_cache.clone();

crate::async_runtime::spawn(async move {
match get_response(
request,
&manager,
&window_origin,
web_resource_request_handler.as_deref().map(|h| &**h),
#[cfg(all(dev, mobile))]
(&url, &response_cache),
)
.await
{
Ok(response) => responder.respond(response),
Err(e) => responder.respond(
HttpResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
.header("Access-Control-Allow-Origin", &window_origin)
.body(e.to_string().into_bytes())
.unwrap(),
),
}
});
})
}

fn get_response<R: Runtime>(
#[allow(unused_mut)] mut request: Request<Vec<u8>>,
async fn get_response<R: Runtime>(
request: Request<Vec<u8>>,
#[allow(unused_variables)] manager: &AppManager<R>,
window_origin: &str,
web_resource_request_handler: Option<&WebResourceRequestHandler>,
Expand Down Expand Up @@ -119,10 +133,12 @@ fn get_response<R: Runtime>(
let _ = rustls::crypto::ring::default_provider().install_default();
}

#[allow(unused_mut)]
let mut client = reqwest::ClientBuilder::new();

if url.starts_with("https://") {
// we can't load env vars at runtime, gotta embed them in the lib
#[allow(unused_variables)]
if let Some(cert_pem) = option_env!("TAURI_DEV_ROOT_CERTIFICATE") {
#[cfg(any(
feature = "native-tls",
Expand Down Expand Up @@ -157,39 +173,61 @@ fn get_response<R: Runtime>(
.build()
.unwrap()
.request(request.method().clone(), &url);
proxy_builder = proxy_builder.body(std::mem::take(request.body_mut()));
for (name, value) in request.headers() {
proxy_builder = proxy_builder.header(name, value);
}
proxy_builder = proxy_builder.body(request.body().clone());
match crate::async_runtime::safe_block_on(proxy_builder.send()) {
Ok(r) => {
let mut response_cache_ = response_cache.lock().unwrap();
let mut response = None;
if r.status() == http::StatusCode::NOT_MODIFIED {
response = response_cache_.get(&url);
}
let response = if let Some(r) = response {
r

match async {
let r = proxy_builder.send().await?;
let status = r.status();
let headers = r.headers().clone();

Ok::<_, reqwest::Error>(if status == http::StatusCode::NOT_MODIFIED {
if let Some(response) = response_cache.lock().unwrap().get(&url).cloned() {
for (name, value) in &response.headers {
builder = builder.header(name, value);
}

builder
.status(response.status)
.body(response.body.to_vec().into())
.unwrap()
} else {
let status = r.status();
let headers = r.headers().clone();
let body = crate::async_runtime::safe_block_on(r.bytes())?;
let response = CachedResponse {
status,
headers,
body,
};
response_cache_.insert(url.clone(), response);
response_cache_.get(&url).unwrap()
for (name, value) in &headers {
builder = builder.header(name, value);
}

builder.status(status).body(Vec::new().into()).unwrap()
}
} else {
let body = r.bytes().await?;
let response = CachedResponse {
status,
headers,
body,
};

{
response_cache
.lock()
.unwrap()
.insert(url.clone(), response.clone());
}

for (name, value) in &response.headers {
builder = builder.header(name, value);
}

builder
.status(response.status)
.body(response.body.to_vec().into())?
}
.body(response.body.to_vec().into())
.unwrap()
})
}
.await
{
Ok(response) => response,
Err(e) => {
let error_message = format!(
"Failed to request {}: {}{}",
Expand All @@ -211,15 +249,18 @@ fn get_response<R: Runtime>(

#[cfg(not(all(dev, mobile)))]
let mut response = {
let use_https_scheme = request.uri().scheme() == Some(&http::uri::Scheme::HTTPS);
let asset = manager.get_asset(path, use_https_scheme)?;
let asset = manager.get_asset(
path,
request.uri().scheme() == Some(&http::uri::Scheme::HTTPS),
)?;
builder = builder.header(CONTENT_TYPE, &asset.mime_type);
if let Some(csp) = &asset.csp_header {
builder = builder.header("Content-Security-Policy", csp);
}
builder.body(asset.bytes.into())?
};
if let Some(handler) = &web_resource_request_handler {

if let Some(handler) = web_resource_request_handler {
handler(request, &mut response);
}

Expand Down
Loading