diff --git a/README.md b/README.md index e4f9157..479c004 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,13 @@ This is intentional behavior. Alice works with Go 1.0 and higher. +### Bob +Just as Alice provides a convenient mechanism of chaining HTTP handlers, Bob provides the same capability for the client-side HTTP round trippers (`http.RoundTripper`). + +For all intents and purposes, Bob works exactly the same way as Alice does. Replace all references to `http.Handler` with `http.RoundTripper`, and that's how Bob works. + +The only difference is introducing `bob.RoundTripperFunc` (akin to `http.HandlerFunc`) as it is not provided by the Go framework. + ### Contributing 0. Find an issue that bugs you / open a new one. diff --git a/bob/bob.go b/bob/bob.go new file mode 100644 index 0000000..340b734 --- /dev/null +++ b/bob/bob.go @@ -0,0 +1,122 @@ +// Package bob provides a convenient way to chain http round trippers. +package bob + +import "net/http" + +// RoundTripperFunc is to RoundTripper what HandlerFunc is to Handler. +// It is a higher-order function that enables chaining of RoundTrippers +// with the middleware pattern. +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +// RoundTrip calls the function itself. +func (f RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +// Constructor for a piece of middleware. +// Some middleware uses this constructor out of the box, +// so in most cases you can just pass somepackage.New +type Constructor func(http.RoundTripper) http.RoundTripper + +// Chain acts as a list of http.RoundTripper constructors. +// Chain is effectively immutable: +// once created, it will always hold +// the same set of constructors in the same order. +type Chain struct { + constructors []Constructor +} + +// New creates a new chain, +// memorizing the given list of middleware constructors. +// New serves no other function, +// constructors are only called upon a call to Then(). +func New(constructors ...Constructor) Chain { + return Chain{append(([]Constructor)(nil), constructors...)} +} + +// Then chains the middleware and returns the final http.RoundTripper. +// New(m1, m2, m3).Then(rt) +// is equivalent to: +// m1(m2(m3(rt))) +// When the request goes out, it will be passed to m1, then m2, then m3 +// and finally, the given round tripper +// (assuming every middleware calls the following one). +// +// A chain can be safely reused by calling Then() several times. +// stdStack := chain.New(ratelimitHandler, csrfHandler) +// indexPipe = stdStack.Then(indexHandler) +// authPipe = stdStack.Then(authHandler) +// Note that constructors are called on every call to Then() +// and thus several instances of the same middleware will be created +// when a chain is reused in this way. +// For proper middleware, this should cause no problems. +// +// Then() treats nil as http.DefaultTransport. +func (c Chain) Then(rt http.RoundTripper) http.RoundTripper { + if rt == nil { + rt = http.DefaultTransport + } + + for i := range c.constructors { + rt = c.constructors[len(c.constructors)-1-i](rt) + } + + return rt +} + +// ThenFunc works identically to Then, but takes +// a RoundTripperFunc instead of a RoundTripper. +// +// The following two statements are equivalent: +// c.Then(http.RoundTripperFunc(fn)) +// c.ThenFunc(fn) +// +// RoundTripperFunc provides all the guarantees of Then. +func (c Chain) ThenFunc(fn RoundTripperFunc) http.RoundTripper { + if fn == nil { + return c.Then(nil) + } + return c.Then(fn) +} + +// Append extends a chain, adding the specified constructors +// as the last ones in the request flow. +// +// Append returns a new chain, leaving the original one untouched. +// +// stdChain := chain.New(m1, m2) +// extChain := stdChain.Append(m3, m4) +// // requests in stdChain go m1 -> m2 +// // requests in extChain go m1 -> m2 -> m3 -> m4 +func (c Chain) Append(constructors ...Constructor) Chain { + newCons := make([]Constructor, 0, len(c.constructors)+len(constructors)) + newCons = append(newCons, c.constructors...) + newCons = append(newCons, constructors...) + + return Chain{newCons} +} + +// Extend extends a chain by adding the specified chain +// as the last one in the request flow. +// +// Extend returns a new chain, leaving the original one untouched. +// +// stdChain := chain.New(m1, m2) +// ext1Chain := chain.New(m3, m4) +// ext2Chain := stdChain.Extend(ext1Chain) +// // requests in stdChain go m1 -> m2 +// // requests in ext1Chain go m3 -> m4 +// // requests in ext2Chain go m1 -> m2 -> m3 -> m4 +// +// Another example: +// aHtmlAfterNosurf := chain.New(m2) +// aHtml := chain.New(m1, func(rt http.RoundTripper) http.RoundTripper { +// csrf := nosurf.New(rt) +// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail)) +// return csrf +// }).Extend(aHtmlAfterNosurf) +// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-roundtripper +// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail +func (c Chain) Extend(chain Chain) Chain { + return c.Append(chain.constructors...) +} diff --git a/bob/bob_test.go b/bob/bob_test.go new file mode 100644 index 0000000..ce81ff5 --- /dev/null +++ b/bob/bob_test.go @@ -0,0 +1,221 @@ +// Package bob provides a convenient way to chain http round trippers. +package bob + +import ( + "bytes" + "io/ioutil" + "net/http" + "reflect" + "testing" +) + +// A constructor for middleware +// that writes its own "tag" into the request body and does nothing else. +// Useful in checking if a chain is behaving in the right order. +func tagMiddleware(tag string) Constructor { + return func(rt http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + err := appendTag(tag, r) + if err != nil { + return nil, err + } + return rt.RoundTrip(r) + }) + } +} + +func appendTag(tag string, r *http.Request) error { + var newBody []byte + if r.Body == nil { + newBody = []byte(tag) + } else { + body, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + return err + } + newBody = append(body, []byte(tag)...) + } + r.Body = ioutil.NopCloser(bytes.NewBuffer(newBody)) + return nil +} + +// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer), +// but the best we can do. +func funcsEqual(f1, f2 interface{}) bool { + val1 := reflect.ValueOf(f1) + val2 := reflect.ValueOf(f2) + return val1.Pointer() == val2.Pointer() +} + +var testApp = RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + appendTag("app\n", r) + return &http.Response{}, nil +}) + +func TestNew(t *testing.T) { + c1 := func(h http.RoundTripper) http.RoundTripper { + return nil + } + + c2 := func(h http.RoundTripper) http.RoundTripper { + return http.DefaultTransport + } + + slice := []Constructor{c1, c2} + + chain := New(slice...) + for k := range slice { + if !funcsEqual(chain.constructors[k], slice[k]) { + t.Error("New does not add constructors correctly") + } + } +} + +func TestThenWorksWithNoMiddleware(t *testing.T) { + if !funcsEqual(New().Then(testApp), testApp) { + t.Error("Then does not work with no middleware") + } +} + +func TestThenTreatsNilAsDefaultTransport(t *testing.T) { + if New().Then(nil) != http.DefaultTransport { + t.Error("Then does not treat nil as DefaultTransport") + } +} + +func TestThenFuncTreatsNilAsDefaultTransport(t *testing.T) { + if New().ThenFunc(nil) != http.DefaultTransport { + t.Error("ThenFunc does not treat nil as DefaultTransport") + } +} + +func TestThenFuncConstructsRoundTripperFunc(t *testing.T) { + fn := RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{}, nil + }) + chained := New().ThenFunc(fn) + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.RoundTrip(r) + + if reflect.TypeOf(chained) != reflect.TypeOf((RoundTripperFunc)(nil)) { + t.Error("ThenFunc does not construct RoundTripperFunc") + } +} + +func bodyAsString(r *http.Request) (string, error) { + body, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + return "", err + } + return string(body[:]), nil +} + +func TestThenOrdersRoundTrippersCorrectly(t *testing.T) { + t1 := tagMiddleware("t1\n") + t2 := tagMiddleware("t2\n") + t3 := tagMiddleware("t3\n") + + chained := New(t1, t2, t3).Then(testApp) + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.RoundTrip(r) + + body, err := bodyAsString(r) + if err != nil { + t.Fatal(err) + } + if body != "t1\nt2\nt3\napp\n" { + t.Error("Then does not order round trippers correctly") + } +} + +func TestAppendAddsRoundTrippersCorrectly(t *testing.T) { + chain := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) + newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n")) + + if len(chain.constructors) != 2 { + t.Error("chain should have 2 constructors") + } + if len(newChain.constructors) != 4 { + t.Error("newChain should have 4 constructors") + } + + chained := newChain.Then(testApp) + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.RoundTrip(r) + + body, err := bodyAsString(r) + if err != nil { + t.Fatal(err) + } + if body != "t1\nt2\nt3\nt4\napp\n" { + t.Error("Append does not add round trippers correctly") + } +} + +func TestAppendRespectsImmutability(t *testing.T) { + chain := New(tagMiddleware("")) + newChain := chain.Append(tagMiddleware("")) + + if &chain.constructors[0] == &newChain.constructors[0] { + t.Error("Apppend does not respect immutability") + } +} + +func TestExtendAddsRoundTrippersCorrectly(t *testing.T) { + chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) + chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")) + newChain := chain1.Extend(chain2) + + if len(chain1.constructors) != 2 { + t.Error("chain1 should contain 2 constructors") + } + if len(chain2.constructors) != 2 { + t.Error("chain2 should contain 2 constructors") + } + if len(newChain.constructors) != 4 { + t.Error("newChain should contain 4 constructors") + } + + chained := newChain.Then(testApp) + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.RoundTrip(r) + + body, err := bodyAsString(r) + if err != nil { + t.Fatal(err) + } + if body != "t1\nt2\nt3\nt4\napp\n" { + t.Error("Extend does not add round trippers in correctly") + } +} + +func TestExtendRespectsImmutability(t *testing.T) { + chain := New(tagMiddleware("")) + newChain := chain.Extend(New(tagMiddleware(""))) + + if &chain.constructors[0] == &newChain.constructors[0] { + t.Error("Extend does not respect immutability") + } +}