Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ This is intentional behavior.

Alice works with Go 1.0 and higher.

### Bob
Just as Alice provides a convenient mechanism of chaining HTTP handlers, Bob provides the same capability for the client-side HTTP round trippers (`http.RoundTripper`).

For all intents and purposes, Bob works exactly the same way as Alice does. Replace all references to `http.Handler` with `http.RoundTripper`, and that's how Bob works.

The only difference is introducing `bob.RoundTripperFunc` (akin to `http.HandlerFunc`) as it is not provided by the Go framework.

### Contributing

0. Find an issue that bugs you / open a new one.
Expand Down
122 changes: 122 additions & 0 deletions bob/bob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Package bob provides a convenient way to chain http round trippers.
package bob

import "net/http"

// RoundTripperFunc is to RoundTripper what HandlerFunc is to Handler.
// It is a higher-order function that enables chaining of RoundTrippers
// with the middleware pattern.
type RoundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls the function itself.
func (f RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

// Constructor for a piece of middleware.
// Some middleware uses this constructor out of the box,
// so in most cases you can just pass somepackage.New
type Constructor func(http.RoundTripper) http.RoundTripper

// Chain acts as a list of http.RoundTripper constructors.
// Chain is effectively immutable:
// once created, it will always hold
// the same set of constructors in the same order.
type Chain struct {
constructors []Constructor
}

// New creates a new chain,
// memorizing the given list of middleware constructors.
// New serves no other function,
// constructors are only called upon a call to Then().
func New(constructors ...Constructor) Chain {
return Chain{append(([]Constructor)(nil), constructors...)}
}

// Then chains the middleware and returns the final http.RoundTripper.
// New(m1, m2, m3).Then(rt)
// is equivalent to:
// m1(m2(m3(rt)))
// When the request goes out, it will be passed to m1, then m2, then m3
// and finally, the given round tripper
// (assuming every middleware calls the following one).
//
// A chain can be safely reused by calling Then() several times.
// stdStack := chain.New(ratelimitHandler, csrfHandler)
// indexPipe = stdStack.Then(indexHandler)
// authPipe = stdStack.Then(authHandler)
// Note that constructors are called on every call to Then()
// and thus several instances of the same middleware will be created
// when a chain is reused in this way.
// For proper middleware, this should cause no problems.
//
// Then() treats nil as http.DefaultTransport.
func (c Chain) Then(rt http.RoundTripper) http.RoundTripper {
if rt == nil {
rt = http.DefaultTransport
}

for i := range c.constructors {
rt = c.constructors[len(c.constructors)-1-i](rt)
}

return rt
}

// ThenFunc works identically to Then, but takes
// a RoundTripperFunc instead of a RoundTripper.
//
// The following two statements are equivalent:
// c.Then(http.RoundTripperFunc(fn))
// c.ThenFunc(fn)
//
// RoundTripperFunc provides all the guarantees of Then.
func (c Chain) ThenFunc(fn RoundTripperFunc) http.RoundTripper {
if fn == nil {
return c.Then(nil)
}
return c.Then(fn)
}

// Append extends a chain, adding the specified constructors
// as the last ones in the request flow.
//
// Append returns a new chain, leaving the original one untouched.
//
// stdChain := chain.New(m1, m2)
// extChain := stdChain.Append(m3, m4)
// // requests in stdChain go m1 -> m2
// // requests in extChain go m1 -> m2 -> m3 -> m4
func (c Chain) Append(constructors ...Constructor) Chain {
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
newCons = append(newCons, c.constructors...)
newCons = append(newCons, constructors...)

return Chain{newCons}
}

// Extend extends a chain by adding the specified chain
// as the last one in the request flow.
//
// Extend returns a new chain, leaving the original one untouched.
//
// stdChain := chain.New(m1, m2)
// ext1Chain := chain.New(m3, m4)
// ext2Chain := stdChain.Extend(ext1Chain)
// // requests in stdChain go m1 -> m2
// // requests in ext1Chain go m3 -> m4
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
//
// Another example:
// aHtmlAfterNosurf := chain.New(m2)
// aHtml := chain.New(m1, func(rt http.RoundTripper) http.RoundTripper {
// csrf := nosurf.New(rt)
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
// return csrf
// }).Extend(aHtmlAfterNosurf)
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-roundtripper
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.constructors...)
}
221 changes: 221 additions & 0 deletions bob/bob_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Package bob provides a convenient way to chain http round trippers.
package bob

import (
"bytes"
"io/ioutil"
"net/http"
"reflect"
"testing"
)

// A constructor for middleware
// that writes its own "tag" into the request body and does nothing else.
// Useful in checking if a chain is behaving in the right order.
func tagMiddleware(tag string) Constructor {
return func(rt http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
err := appendTag(tag, r)
if err != nil {
return nil, err
}
return rt.RoundTrip(r)
})
}
}

func appendTag(tag string, r *http.Request) error {
var newBody []byte
if r.Body == nil {
newBody = []byte(tag)
} else {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
newBody = append(body, []byte(tag)...)
}
r.Body = ioutil.NopCloser(bytes.NewBuffer(newBody))
return nil
}

// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
// but the best we can do.
func funcsEqual(f1, f2 interface{}) bool {
val1 := reflect.ValueOf(f1)
val2 := reflect.ValueOf(f2)
return val1.Pointer() == val2.Pointer()
}

var testApp = RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
appendTag("app\n", r)
return &http.Response{}, nil
})

func TestNew(t *testing.T) {
c1 := func(h http.RoundTripper) http.RoundTripper {
return nil
}

c2 := func(h http.RoundTripper) http.RoundTripper {
return http.DefaultTransport
}

slice := []Constructor{c1, c2}

chain := New(slice...)
for k := range slice {
if !funcsEqual(chain.constructors[k], slice[k]) {
t.Error("New does not add constructors correctly")
}
}
}

func TestThenWorksWithNoMiddleware(t *testing.T) {
if !funcsEqual(New().Then(testApp), testApp) {
t.Error("Then does not work with no middleware")
}
}

func TestThenTreatsNilAsDefaultTransport(t *testing.T) {
if New().Then(nil) != http.DefaultTransport {
t.Error("Then does not treat nil as DefaultTransport")
}
}

func TestThenFuncTreatsNilAsDefaultTransport(t *testing.T) {
if New().ThenFunc(nil) != http.DefaultTransport {
t.Error("ThenFunc does not treat nil as DefaultTransport")
}
}

func TestThenFuncConstructsRoundTripperFunc(t *testing.T) {
fn := RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{}, nil
})
chained := New().ThenFunc(fn)

r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.RoundTrip(r)

if reflect.TypeOf(chained) != reflect.TypeOf((RoundTripperFunc)(nil)) {
t.Error("ThenFunc does not construct RoundTripperFunc")
}
}

func bodyAsString(r *http.Request) (string, error) {
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return "", err
}
return string(body[:]), nil
}

func TestThenOrdersRoundTrippersCorrectly(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")

chained := New(t1, t2, t3).Then(testApp)

r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.RoundTrip(r)

body, err := bodyAsString(r)
if err != nil {
t.Fatal(err)
}
if body != "t1\nt2\nt3\napp\n" {
t.Error("Then does not order round trippers correctly")
}
}

func TestAppendAddsRoundTrippersCorrectly(t *testing.T) {
chain := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))

if len(chain.constructors) != 2 {
t.Error("chain should have 2 constructors")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should have 4 constructors")
}

chained := newChain.Then(testApp)

r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.RoundTrip(r)

body, err := bodyAsString(r)
if err != nil {
t.Fatal(err)
}
if body != "t1\nt2\nt3\nt4\napp\n" {
t.Error("Append does not add round trippers correctly")
}
}

func TestAppendRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware(""))
newChain := chain.Append(tagMiddleware(""))

if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Apppend does not respect immutability")
}
}

func TestExtendAddsRoundTrippersCorrectly(t *testing.T) {
chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
newChain := chain1.Extend(chain2)

if len(chain1.constructors) != 2 {
t.Error("chain1 should contain 2 constructors")
}
if len(chain2.constructors) != 2 {
t.Error("chain2 should contain 2 constructors")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should contain 4 constructors")
}

chained := newChain.Then(testApp)

r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.RoundTrip(r)

body, err := bodyAsString(r)
if err != nil {
t.Fatal(err)
}
if body != "t1\nt2\nt3\nt4\napp\n" {
t.Error("Extend does not add round trippers in correctly")
}
}

func TestExtendRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware(""))
newChain := chain.Extend(New(tagMiddleware("")))

if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Extend does not respect immutability")
}
}