diff --git a/lib/wreq.rb b/lib/wreq.rb index da498ff..b204f0b 100644 --- a/lib/wreq.rb +++ b/lib/wreq.rb @@ -25,7 +25,7 @@ module Wreq # @param method [Wreq::Method] HTTP method to use # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -56,7 +56,7 @@ def self.request(method, url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -87,7 +87,7 @@ def self.get(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -118,7 +118,7 @@ def self.head(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -149,7 +149,7 @@ def self.post(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -180,7 +180,7 @@ def self.put(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -211,7 +211,7 @@ def self.delete(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -242,7 +242,7 @@ def self.options(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -273,7 +273,7 @@ def self.trace(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value diff --git a/lib/wreq_ruby/client.rb b/lib/wreq_ruby/client.rb index 8f4771a..7718c01 100644 --- a/lib/wreq_ruby/client.rb +++ b/lib/wreq_ruby/client.rb @@ -38,6 +38,8 @@ class Client # @param headers [Wreq::Headers, Hash{String=>String}, nil] Default headers to include # in every request. Header names are case-insensitive. These headers # can be overridden on a per-request basis. + # @param orig_headers [Array, nil] Original header names used to + # preserve raw header order and HTTP/1 case-sensitive header handling. # # @param referer [Boolean, nil] Whether to automatically send Referer # headers when following redirects. When true, the previous URL will @@ -237,7 +239,7 @@ def self.new(**options) # @param method [Wreq::Method] HTTP method to use # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -268,7 +270,7 @@ def request(method, url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -299,7 +301,7 @@ def get(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -330,7 +332,7 @@ def head(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -361,7 +363,7 @@ def post(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -392,7 +394,7 @@ def put(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -423,7 +425,7 @@ def delete(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -454,7 +456,7 @@ def options(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value @@ -485,7 +487,7 @@ def trace(url, **options) # # @param url [String] Target URL # @param headers [Wreq::Headers, Hash{String=>String}, nil] Custom headers for this request - # @param orig_headers [Hash{String=>String}, nil] Original headers (raw, unmodified) + # @param orig_headers [Array, nil] Original header names used to preserve raw header order and HTTP/1 case-sensitive header handling # @param default_headers [Hash{String=>String}, nil] Default headers to merge # @param query [Hash, nil] URL query parameters # @param auth [String, nil] Authorization header value diff --git a/src/client.rs b/src/client.rs index e945bfa..436c3a0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,10 +10,7 @@ use magnus::{ Module, Object, RHash, RModule, Ruby, TryConvert, Value, function, method, typed_data::Obj, }; use serde::Deserialize; -use wreq::{ - Proxy, - header::{HeaderValue, OrigHeaderMap}, -}; +use wreq::{Proxy, header::HeaderValue}; use crate::{ client::{req::execute_request, resp::Response}, @@ -22,7 +19,7 @@ use crate::{ error::wreq_error_to_magnus, extractor::Extractor, gvl, - header::Headers, + header::{Headers, OrigHeaders}, http::Method, }; @@ -40,7 +37,7 @@ struct Builder { headers: Option, /// The original headers to use for the client. #[serde(skip)] - orig_headers: Option, + orig_headers: Option, /// Whether to use referer. referer: Option, /// Whether to allow redirects. @@ -143,6 +140,10 @@ impl Builder { builder.headers = Some(Headers::try_convert(v)?); } + if let Some(v) = hash.get(ruby.to_symbol(stringify!(orig_headers))) { + builder.orig_headers = Some(OrigHeaders::try_convert(v)?); + } + if let Some(v) = hash.get(ruby.to_symbol(stringify!(cookie_provider))) { builder.cookie_provider = Some((*Obj::::try_convert(v)?).clone()); } @@ -152,7 +153,6 @@ impl Builder { } builder.user_agent = Extractor::::try_convert(*keyword)?.into_inner(); - builder.orig_headers = Extractor::::try_convert(*keyword)?.into_inner(); builder.proxy = Extractor::::try_convert(*keyword)?.into_inner(); Ok(builder) @@ -175,14 +175,19 @@ impl Client { // User agent options. apply_option!(set_if_some, builder, params.user_agent, user_agent); - // Default headers options. + // Headers options. apply_option!( set_if_some_into_inner, builder, params.headers, default_headers ); - apply_option!(set_if_some, builder, params.orig_headers, orig_headers); + apply_option!( + set_if_some_inner, + builder, + params.orig_headers, + orig_headers + ); // Allow redirects options. apply_option!(set_if_some, builder, params.referer, referer); diff --git a/src/client/req.rs b/src/client/req.rs index 6cd4791..45b4547 100644 --- a/src/client/req.rs +++ b/src/client/req.rs @@ -3,7 +3,7 @@ use std::{net::IpAddr, time::Duration}; use http::header; use magnus::{RHash, TryConvert, typed_data::Obj, value::ReprValue}; use serde::Deserialize; -use wreq::{Client, Proxy, header::OrigHeaderMap}; +use wreq::{Client, Proxy}; use super::body::{Body, Form, Json}; use crate::{ @@ -12,7 +12,7 @@ use crate::{ emulate::Emulation, error::wreq_error_to_magnus, extractor::Extractor, - header::Headers, + header::{Headers, OrigHeaders}, http::{Method, Version}, rt, }; @@ -54,7 +54,7 @@ pub struct Request { /// The original headers to use for the request. #[serde(skip)] - orig_headers: Option, + orig_headers: Option, /// The cookies to use for the request. #[serde(skip)] @@ -120,6 +120,10 @@ impl Request { builder.headers = Some(Headers::try_convert(v)?); } + if let Some(v) = hash.get(ruby.to_symbol(stringify!(orig_headers))) { + builder.orig_headers = Some(OrigHeaders::try_convert(v)?); + } + if let Some(v) = hash.get(ruby.to_symbol(stringify!(cookies))) { builder.cookies = Some(Cookies::try_convert(v)?); } @@ -129,7 +133,6 @@ impl Request { } builder.proxy = Extractor::::try_convert(keyword)?.into_inner(); - builder.orig_headers = Extractor::::try_convert(keyword)?.into_inner(); Ok(builder) } @@ -179,7 +182,12 @@ pub fn execute_request>( // Headers options. apply_option!(set_if_some_into_inner, builder, request.headers, headers); - apply_option!(set_if_some, builder, request.orig_headers, orig_headers); + apply_option!( + set_if_some_inner, + builder, + request.orig_headers, + orig_headers + ); apply_option!( set_if_some, builder, diff --git a/src/error.rs b/src/error.rs index 6cf9521..80802ea 100644 --- a/src/error.rs +++ b/src/error.rs @@ -117,6 +117,14 @@ pub fn header_value_error_to_magnus(err: wreq::header::InvalidHeaderValue) -> Ma ) } +/// Map type/value errors to corresponding [`magnus::Error`] +pub fn type_value_error_to_magnus(err: &str) -> MagnusError { + MagnusError::new( + ruby!().get_inner(&BUILDER_ERROR), + format!("type error: {err}"), + ) +} + /// Map [`wreq::Error`] to corresponding [`magnus::Error`] pub fn wreq_error_to_magnus(err: wreq::Error) -> MagnusError { let error_msg = err.to_string(); diff --git a/src/extractor.rs b/src/extractor.rs index e16d548..63df29a 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,6 +1,6 @@ use magnus::{RArray, RHash, RString, Ruby, TryConvert, r_hash::ForEach}; use wreq::{ - Proxy, Version, + Proxy, header::{HeaderMap, HeaderName, HeaderValue, OrigHeaderMap}, }; @@ -32,27 +32,6 @@ where } } -// ===== impl Extractor ===== - -impl ExtractorName for Version { - const NAME: &str = "version"; -} - -impl TryConvert for Extractor { - fn try_convert(value: magnus::Value) -> Result { - let keyword = RHash::try_convert(value)?; - if let Some(version_val) = keyword.get(Version::NAME) { - return <&crate::http::Version>::try_convert(version_val) - .cloned() - .map(crate::http::Version::into_ffi) - .map(Some) - .map(Extractor); - } - - Ok(Extractor(None)) - } -} - // ===== impl Extractor ===== impl ExtractorName for HeaderValue { diff --git a/src/header.rs b/src/header.rs index aa1cf33..cf8bd6f 100644 --- a/src/header.rs +++ b/src/header.rs @@ -9,8 +9,11 @@ use magnus::{ r_hash::ForEach, typed_data::{Inspect, Obj}, }; +use wreq::header::OrigHeaderMap; -use crate::error::{header_name_error_to_magnus, header_value_error_to_magnus}; +use crate::error::{ + header_name_error_to_magnus, header_value_error_to_magnus, type_value_error_to_magnus, +}; /// HTTP headers collection with read and write operations. /// @@ -20,6 +23,9 @@ use crate::error::{header_name_error_to_magnus, header_value_error_to_magnus}; #[magnus::wrap(class = "Wreq::Headers", free_immediately, size)] pub struct Headers(pub RefCell); +/// A map from header names to their original casing as received in an HTTP message. +pub struct OrigHeaders(pub OrigHeaderMap); + struct HeaderIter { inner: http::header::IntoIter, next_name: Option, @@ -166,6 +172,23 @@ impl TryConvert for Headers { } } +// ===== impl OrigHeaders ===== + +impl TryConvert for OrigHeaders { + fn try_convert(value: magnus::Value) -> Result { + let mut map = OrigHeaderMap::new(); + + let rarray = RArray::from_value(value) + .ok_or_else(|| type_value_error_to_magnus("Expected an array of strings"))?; + + for value in rarray.into_iter().flat_map(RString::from_value) { + map.insert(value.to_bytes()); + } + + Ok(Self(map)) + } +} + // ===== impl HeaderIter ===== impl Iterator for HeaderIter { diff --git a/test/orig_header_test.rb b/test/orig_header_test.rb new file mode 100644 index 0000000..a45b420 --- /dev/null +++ b/test/orig_header_test.rb @@ -0,0 +1,115 @@ +require "test_helper" + +class OrigHeaderTest < Minitest::Test + URL = "https://tls.browserleaks.com/http1" + + CASES = [ + { + name: "mixed_case_descending", + headers: { + "X-Zeta-Token" => "zeta", + "x-alpha-key" => "alpha", + "X-MiXeD-CaSe" => "mixed" + }, + orig_headers: ["X-Zeta-Token", "x-alpha-key", "X-MiXeD-CaSe"] + }, + { + name: "reverse_alpha_order", + headers: { + "X-Third" => "3", + "X-Second" => "2", + "X-First" => "1" + }, + orig_headers: ["X-Third", "X-Second", "X-First"] + }, + { + name: "preserve_weird_casing", + headers: { + "x-a" => "a", + "X-B" => "b", + "x-C" => "c" + }, + orig_headers: ["x-C", "x-a", "X-B"] + }, + { + name: "interleaved_tokens", + headers: { + "X-Token-3" => "v3", + "X-Token-1" => "v1", + "X-Token-2" => "v2" + }, + orig_headers: ["X-Token-1", "X-Token-2", "X-Token-3"] + } + ].freeze + + def test_client_default_orig_headers_preserves_header_order_in_multiple_shuffled_cases + CASES.each do |kase| + client = Wreq::Client.new( + headers: kase[:headers], + orig_headers: kase[:orig_headers] + ) + + response = client.get(URL, version: Wreq::Version::HTTP_11) + assert_equal 200, response.code, "case=#{kase[:name]}" + + echoed_headers = extract_http1_headers(response.json, kase[:name]) + assert_header_order(echoed_headers, kase[:orig_headers], kase[:name]) + assert_header_values(echoed_headers, kase[:headers], kase[:name]) + end + end + + def test_module_request_orig_headers_preserves_header_order_in_multiple_shuffled_cases + CASES.each do |kase| + response = Wreq.get( + URL, + headers: kase[:headers], + orig_headers: kase[:orig_headers], + version: Wreq::Version::HTTP_11 + ) + assert_equal 200, response.code, "case=#{kase[:name]}" + + echoed_headers = extract_http1_headers(response.json, kase[:name]) + assert_header_order(echoed_headers, kase[:orig_headers], kase[:name]) + assert_header_values(echoed_headers, kase[:headers], kase[:name]) + end + end + + private + + def extract_http1_headers(json, case_name) + http1 = fetch_by_name(json, "http1") + refute_nil http1, "case=#{case_name}: expected JSON key 'http1', got #{json.keys.inspect}" + + headers = fetch_by_name(http1, "headers") + refute_nil headers, "case=#{case_name}: expected JSON key 'http1.headers'" + headers + end + + def fetch_by_name(hash_like, key_name) + return hash_like[key_name] if hash_like.respond_to?(:key?) && hash_like.key?(key_name) + return hash_like[key_name.to_sym] if hash_like.respond_to?(:key?) && hash_like.key?(key_name.to_sym) + + pair = hash_like.find { |k, _| k.to_s == key_name } + pair&.last + end + + def assert_header_order(echoed_headers, ordered_names, case_name) + echoed_keys = echoed_headers.keys + positions = ordered_names.map do |expected_name| + index = echoed_keys.index(expected_name) + refute_nil index, "case=#{case_name}: expected header to exist in echo: #{expected_name}" + index + end + + assert_equal positions.sort, positions, + "case=#{case_name}: expected header order #{ordered_names.inspect}, got keys #{echoed_keys.inspect}" + end + + def assert_header_values(echoed_headers, expected_headers, case_name) + expected_headers.each do |name, expected_value| + assert echoed_headers.key?(name), + "case=#{case_name}: expected exact-case header name #{name}, got #{echoed_headers.keys.inspect}" + assert_equal expected_value, echoed_headers[name] + end + end +end