diff --git a/lib/wreq_ruby/http.rb b/lib/wreq_ruby/http.rb index 5930f66..443f5eb 100644 --- a/lib/wreq_ruby/http.rb +++ b/lib/wreq_ruby/http.rb @@ -49,6 +49,20 @@ class Version def to_s end end + + # Compares HTTP versions by semantic value, not object identity. + # + # This method is implemented by the native extension. + # When comparing with non-{Wreq::Version} objects, it returns false. + # + # @param other [Object] object to compare against + # @return [Boolean] true when both represent the same HTTP version + # @example + # Wreq::Version::HTTP_11 == response.version + unless method_defined?(:==) + def ==(other) + end + end end # HTTP status code wrapper. diff --git a/src/client/req.rs b/src/client/req.rs index 002a338..6cd4791 100644 --- a/src/client/req.rs +++ b/src/client/req.rs @@ -3,7 +3,7 @@ use std::{net::IpAddr, time::Duration}; use http::header; use magnus::{RHash, TryConvert, typed_data::Obj, value::ReprValue}; use serde::Deserialize; -use wreq::{Client, Proxy, Version, header::OrigHeaderMap}; +use wreq::{Client, Proxy, header::OrigHeaderMap}; use super::body::{Body, Form, Json}; use crate::{ @@ -13,7 +13,7 @@ use crate::{ error::wreq_error_to_magnus, extractor::Extractor, header::Headers, - http::Method, + http::{Method, Version}, rt, }; @@ -112,6 +112,10 @@ impl Request { builder.emulation = Some((*obj).clone()); } + if let Some(v) = hash.get(ruby.to_symbol(stringify!(version))) { + builder.version = Some(Version::try_convert(v)?); + } + if let Some(v) = hash.get(ruby.to_symbol(stringify!(headers))) { builder.headers = Some(Headers::try_convert(v)?); } @@ -125,7 +129,6 @@ impl Request { } builder.proxy = Extractor::::try_convert(keyword)?.into_inner(); - builder.version = Extractor::::try_convert(keyword)?.into_inner(); builder.orig_headers = Extractor::::try_convert(keyword)?.into_inner(); Ok(builder) @@ -145,7 +148,13 @@ pub fn execute_request>( apply_option!(set_if_some_inner, builder, request.emulation, emulation); // Version options. - apply_option!(set_if_some, builder, request.version, version); + apply_option!( + set_if_some_map, + builder, + request.version, + version, + Version::into_ffi + ); // Timeout options. apply_option!( diff --git a/src/http.rs b/src/http.rs index 5bb02b0..31cb7d2 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,4 @@ -use magnus::{Error, Module, RModule, Ruby, method, typed_data::Inspect}; +use magnus::{Error, Module, RModule, Ruby, TryConvert, Value, method, typed_data::Inspect}; define_ruby_enum!( /// An HTTP version. @@ -41,6 +41,20 @@ impl Version { pub fn to_s(&self) -> String { self.into_ffi().inspect() } + + /// Value-based equality for Ruby (`==`). + #[inline] + pub fn equals(&self, other: Value) -> bool { + <&Version>::try_convert(other) + .map(|other| *self == *other) + .unwrap_or(false) + } +} + +impl TryConvert for Version { + fn try_convert(value: magnus::Value) -> Result { + <&Version>::try_convert(value).cloned() + } } // ===== impl StatusCode ===== @@ -113,6 +127,7 @@ pub fn include(ruby: &Ruby, gem_module: &RModule) -> Result<(), Error> { version_class.const_set("HTTP_2", Version::HTTP_2)?; version_class.const_set("HTTP_3", Version::HTTP_3)?; version_class.define_method("to_s", method!(Version::to_s, 0))?; + version_class.define_method("==", method!(Version::equals, 1))?; let status_code_class = gem_module.define_class("StatusCode", ruby.class_object())?; status_code_class.define_method("as_int", method!(StatusCode::as_int, 0))?; diff --git a/test/request_test.rb b/test/request_test.rb index ac9cbaf..64d377c 100644 --- a/test/request_test.rb +++ b/test/request_test.rb @@ -6,6 +6,16 @@ def setup @client = Wreq::Client.new(timeout: 30) end + def test_spec_http1_version + response = Wreq.get("https://tls.browserleaks.com", version: Wreq::Version::HTTP_11) + assert_equal response.version, Wreq::Version::HTTP_11 + end + + def test_spec_http2_version + response = Wreq.get("https://tls.browserleaks.com", version: Wreq::Version::HTTP_2) + assert_equal response.version, Wreq::Version::HTTP_2 + end + def test_module_get_method response = Wreq.get("http://localhost:8080/get") assert_equal 200, response.code