diff --git a/chain.go b/chain.go index da0e2b5..7d2393e 100644 --- a/chain.go +++ b/chain.go @@ -1,7 +1,9 @@ // Package alice provides a convenient way to chain http handlers. package alice -import "net/http" +import ( + "net/http" +) // A constructor for a piece of middleware. // Some middleware use this constructor out of the box, @@ -14,6 +16,7 @@ type Constructor func(http.Handler) http.Handler // the same set of constructors in the same order. type Chain struct { constructors []Constructor + endwares []Endware } // New creates a new chain, @@ -21,25 +24,46 @@ type Chain struct { // New serves no other function, // constructors are only called upon a call to Then(). func New(constructors ...Constructor) Chain { - return Chain{append(([]Constructor)(nil), constructors...)} + return Chain{append(([]Constructor)(nil), constructors...), ([]Endware)(nil)} } -// Then chains the middleware and returns the final http.Handler. -// New(m1, m2, m3).Then(h) +// endwareHandler represents a handler that has been modified +// to execute endwares afterwards. This is a helper for Then() +// because if we just wrap it in an anonymous +// http.HandlerFunc(func(w http.ResponseWriter, r *http.Request))) +// there is a stack overflow +type endwareHandler struct { + handler http.Handler + endwares []Endware +} + +// ServeHTTP serves the main endwareHandler's handler as well as +// calling all of the individual endwares afterwards. +func (eh endwareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + eh.handler.ServeHTTP(w, r) + for _, endware := range eh.endwares { + endware.ServeHTTP(w, r) + } +} + +// Then chains the middleware and endwares and returns the final http.Handler. +// New(m1, m2, m3).Finally(e1, e2, e3).Then(h) // is equivalent to: // m1(m2(m3(h))) -// When the request comes in, it will be passed to m1, then m2, then m3 -// and finally, the given handler -// (assuming every middleware calls the following one). +// followed by: +// e1(e2(e3())) +// When the request comes in, it will be passed to m1, then m2, then m3, +// then the given handler (who serves the response), then e1, e2, e3 +// (assuming every middleware/endwares calls the following one). // // A chain can be safely reused by calling Then() several times. -// stdStack := alice.New(ratelimitHandler, csrfHandler) +// stdStack := alice.New(ratelimitHandler, csrfHandler).Finally(loggingHandler) // indexPipe = stdStack.Then(indexHandler) // authPipe = stdStack.Then(authHandler) -// Note that constructors are called on every call to Then() -// and thus several instances of the same middleware will be created +// Note that constructors and endwares are called on every call to Then() +// and thus several instances of the same middleware/endwares will be created // when a chain is reused in this way. -// For proper middleware, this should cause no problems. +// For proper middleware/endwares, this should cause no problems. // // Then() treats nil as http.DefaultServeMux. func (c Chain) Then(h http.Handler) http.Handler { @@ -47,6 +71,10 @@ func (c Chain) Then(h http.Handler) http.Handler { h = http.DefaultServeMux } + if len(c.endwares) > 0 { + h = endwareHandler{h, c.endwares} + } + for i := range c.constructors { h = c.constructors[len(c.constructors)-1-i](h) } @@ -73,6 +101,7 @@ func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler { // as the last ones in the request flow. // // Append returns a new chain, leaving the original one untouched. +// The new chain will have the original chain's endwares. // // stdChain := alice.New(m1, m2) // extChain := stdChain.Append(m3, m4) @@ -83,7 +112,7 @@ func (c Chain) Append(constructors ...Constructor) Chain { newCons = append(newCons, c.constructors...) newCons = append(newCons, constructors...) - return Chain{newCons} + return New(newCons...).AppendEndware(c.endwares...) } // Extend extends a chain by adding the specified chain @@ -92,21 +121,98 @@ func (c Chain) Append(constructors ...Constructor) Chain { // Extend returns a new chain, leaving the original one untouched. // // stdChain := alice.New(m1, m2) -// ext1Chain := alice.New(m3, m4) +// ext1Chain := alice.New(m3, m4).Finally(e1, e2) // ext2Chain := stdChain.Extend(ext1Chain) -// // requests in stdChain go m1 -> m2 -// // requests in ext1Chain go m3 -> m4 -// // requests in ext2Chain go m1 -> m2 -> m3 -> m4 +// // requests in stdChain go m1 -> m2 -> handler +// // requests in ext1Chain go m3 -> m4 -> handler -> e1 -> e2 +// // requests in ext2Chain go m1 -> m2 -> m3 -> m4 -> handler -> e1 -> e2 // // Another example: // aHtmlAfterNosurf := alice.New(m2) +// logRequestChain := aHtmlAfterNosurf.Finally(e1) // aHtml := alice.New(m1, func(h http.Handler) http.Handler { // csrf := nosurf.New(h) -// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail)) +// csrf.SetFailureHandler(logRequestChain.ThenFunc(csrfFail)) // return csrf -// }).Extend(aHtmlAfterNosurf) -// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler -// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail +// }).Extend(logRequestChain) +// // requests to aHtml hitting nosurfs success handler go: +// m1 -> nosurf -> m2 -> target-handler -> e1 +// // requests to aHtml hitting nosurfs failure handler go: +// m1 -> nosurf -> m2 -> csrfFail -> e1 func (c Chain) Extend(chain Chain) Chain { - return c.Append(chain.constructors...) + return c. + Append(chain.constructors...). + AppendEndware(chain.endwares...) +} + +// Endware is functionality executed after a the main handler is called +// and response has been sent to the requester. Like middleware, +// values from the request or response can be accessed. This will not +// let you access values from the request or the response that can no longer be used. +// e.g. re-reading a request body, re-setting the response headers, etc. +type Endware http.Handler + +// Finally creates a new chain with the original chain's +// constructors and endwares, as well as the provided endwares. +// Endwares are executed after both the constructors and +// the Then() handler are called. +func (c Chain) Finally(endwares ...Endware) Chain { + newEnds := make([]Endware, 0, len(c.endwares)+len(endwares)) + newEnds = append(newEnds, c.endwares...) + newEnds = append(newEnds, endwares...) + + newC := New(c.constructors...) + newC.endwares = newEnds + return newC +} + +// FinallyFuncs works identically to Finally, but takes HandlerFuncs +// instead of Endwares. +// +// The following two statements are equivalent: +// c.Finally(http.HandlerFunc(fn1), http.HandlerFunc(fn2)) +// c.FinallyFuncs(fn1, fn2) +// +// FinallyFuncs provides all the guarantees of Finally. +func (c Chain) FinallyFuncs(fns ...func(w http.ResponseWriter, r *http.Request)) Chain { + // convert each http.HandlerFunc into an Endware + endwares := make([]Endware, len(fns)) + for i, fn := range fns { + endwares[i] = http.HandlerFunc(fn) + } + + return c.Finally(endwares...) +} + +// AppendEndware extends a chain, adding the specified endwares +// as the last ones in the request flow. +// +// AppendEndware returns a new chain, leaving the original one untouched. +// The new chain will have the original chain's constructors. +// +// stdChain := alice.New(m1).Finally(e1, e2) +// extChain := stdChain.AppendEndware(e3, e4) +// // requests in stdHandler go m1 -> handler -> e1 -> e2 +// // requests in extHandler go m1 -> handler -> e1 -> e2 -> e3 -> e4 +func (c Chain) AppendEndware(endwares ...Endware) Chain { + return New(c.constructors...).Finally(append(c.endwares, endwares...)...) +} + +// AppendEndwareFuncs works identically to AppendEndware, but takes HandlerFuncs +// instead of Endwares. +// +// The following two statements are equivalent: +// c.AppendEndware(http.HandlerFunc(fn1), http.HandlerFunc(fn2)) +// c.AppendEndwareFuncs(fn1, fn2) +// +// AppendEndwareFuncs provides all the guarantees of AppendEndware. +func (c Chain) AppendEndwareFuncs(fns ...func(w http.ResponseWriter, r *http.Request)) Chain { + // convert each http.HandlerFunc into an Endware + endwares := make([]Endware, len(fns)) + for i, fn := range fns { + endwares[i] = http.HandlerFunc(fn) + } + + return c.AppendEndware(endwares...) + } diff --git a/chain_test.go b/chain_test.go index 6f4316b..53af6b4 100644 --- a/chain_test.go +++ b/chain_test.go @@ -20,6 +20,12 @@ func tagMiddleware(tag string) Constructor { } } +func tagEndware(tag string) Endware { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(tag)) + }) +} + // Not recommended (https://golang.org/pkg/reflect/#Value.Pointer), // but the best we can do. func funcsEqual(f1, f2 interface{}) bool { @@ -51,20 +57,40 @@ func TestNew(t *testing.T) { } } +func TestFinally(t *testing.T) { + e1 := tagEndware("e1\n") + e2 := tagEndware("e2\n") + + slice := []Endware{e1, e2} + + chain := New().Finally(slice...) + for k := range slice { + if !funcsEqual(chain.endwares[k], slice[k]) { + t.Error("Finally does not add endwares correctly") + } + } +} + func TestThenWorksWithNoMiddleware(t *testing.T) { if !funcsEqual(New().Then(testApp), testApp) { t.Error("Then does not work with no middleware") } } +func TestThenWorksWithNoEndware(t *testing.T) { + if !funcsEqual(New().Finally().Then(testApp), testApp) { + t.Error("Then does not work with no endware") + } +} + func TestThenTreatsNilAsDefaultServeMux(t *testing.T) { - if New().Then(nil) != http.DefaultServeMux { + if New().Finally().Then(nil) != http.DefaultServeMux { t.Error("Then does not treat nil as DefaultServeMux") } } func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) { - if New().ThenFunc(nil) != http.DefaultServeMux { + if New().Finally().ThenFunc(nil) != http.DefaultServeMux { t.Error("ThenFunc does not treat nil as DefaultServeMux") } } @@ -73,7 +99,7 @@ func TestThenFuncConstructsHandlerFunc(t *testing.T) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) }) - chained := New().ThenFunc(fn) + chained := New().Finally().ThenFunc(fn) rec := httptest.NewRecorder() chained.ServeHTTP(rec, (*http.Request)(nil)) @@ -87,8 +113,11 @@ func TestThenOrdersHandlersCorrectly(t *testing.T) { t1 := tagMiddleware("t1\n") t2 := tagMiddleware("t2\n") t3 := tagMiddleware("t3\n") + e1 := tagEndware("e1\n") + e2 := tagEndware("e2\n") + e3 := tagEndware("e3\n") - chained := New(t1, t2, t3).Then(testApp) + chained := New(t1, t2, t3).Finally(e1, e2, e3).Then(testApp) w := httptest.NewRecorder() r, err := http.NewRequest("GET", "/", nil) @@ -98,7 +127,7 @@ func TestThenOrdersHandlersCorrectly(t *testing.T) { chained.ServeHTTP(w, r) - if w.Body.String() != "t1\nt2\nt3\napp\n" { + if w.Body.String() != "t1\nt2\nt3\napp\ne1\ne2\ne3\n" { t.Error("Then does not order handlers correctly") } } @@ -110,9 +139,15 @@ func TestAppendAddsHandlersCorrectly(t *testing.T) { if len(chain.constructors) != 2 { t.Error("chain should have 2 constructors") } + if len(chain.endwares) != 0 { + t.Error("chain should have 0 endwares") + } if len(newChain.constructors) != 4 { t.Error("newChain should have 4 constructors") } + if len(newChain.endwares) != 0 { + t.Error("newChain should have 0 endwares") + } chained := newChain.Then(testApp) @@ -129,29 +164,114 @@ func TestAppendAddsHandlersCorrectly(t *testing.T) { } } +func TestAppendEndwareAddsHandlersCorrectly(t *testing.T) { + chain := New(tagMiddleware("t1\n")).Finally(tagEndware("e1\n"), tagEndware("e2\n")) + newChain := chain.AppendEndware(tagEndware("e3\n"), tagEndware("e4\n")) + + if len(chain.constructors) != 1 { + t.Error("chain should have 1 constructor") + } + if len(chain.endwares) != 2 { + t.Error("chain should have 2 endwares") + } + if len(newChain.constructors) != 1 { + t.Error("newChain should have 1 constructor") + } + if len(newChain.endwares) != 4 { + t.Error("newChain should have 4 endwares") + } + + chained := newChain.Then(testApp) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.ServeHTTP(w, r) + + if w.Body.String() != "t1\napp\ne1\ne2\ne3\ne4\n" { + t.Error("AppendEndware does not add handlers correctly") + } +} + func TestAppendRespectsImmutability(t *testing.T) { - chain := New(tagMiddleware("")) + chain := New(tagMiddleware("")).Finally(tagEndware("")) newChain := chain.Append(tagMiddleware("")) if &chain.constructors[0] == &newChain.constructors[0] { - t.Error("Apppend does not respect immutability") + t.Error("Append does not respect constructor immutability") + } + + if &chain.endwares[0] == &newChain.endwares[0] { + t.Error("Append does not respect endware immutability") + } +} + +func TestAppendEndwareRespectsImmutability(t *testing.T) { + chain := New(tagMiddleware("")).Finally(tagEndware("")) + newChain := chain.AppendEndware(tagEndware("")) + + if &chain.constructors[0] == &newChain.constructors[0] { + t.Error("AppendEndware does not respect constructor immutability") + } + + if &chain.endwares[0] == &newChain.endwares[0] { + t.Error("AppendEndware does not respect endware immutability") + } +} + +func TestExtendsRespectsImmutability(t *testing.T) { + chain := New(tagMiddleware("")).Finally(tagEndware("")) + newChain := New(tagMiddleware("")).Finally(tagEndware("")) + newChain = chain.Extend(newChain) + + // chain.constructors[0] should have the same functionality as + // newChain.constructors[1], but check both anyways + if &chain.constructors[0] == &newChain.constructors[0] { + t.Error("Extends does not respect constructor immutability") + } + + if &chain.constructors[0] == &newChain.constructors[1] { + t.Error("Extends does not respect constructor immutability") + } + + if &chain.endwares[0] == &newChain.endwares[0] { + t.Error("Extends does not respect endware immutability") + } + + if &chain.endwares[0] == &newChain.endwares[1] { + t.Error("Extends does not respect endware immutability") } } func TestExtendAddsHandlersCorrectly(t *testing.T) { chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) - chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")) + chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")). + Finally(tagEndware("e1\n"), tagEndware("e2\n")) newChain := chain1.Extend(chain2) if len(chain1.constructors) != 2 { t.Error("chain1 should contain 2 constructors") } + if len(chain1.endwares) != 0 { + t.Error("chain1 should contain 0 endwares") + } + if len(chain2.constructors) != 2 { t.Error("chain2 should contain 2 constructors") } + if len(chain2.endwares) != 2 { + t.Error("chain2 should contain 2 endwares") + } + if len(newChain.constructors) != 4 { t.Error("newChain should contain 4 constructors") } + if len(newChain.endwares) != 2 { + t.Error("newChain should contain 2 endwares") + } chained := newChain.Then(testApp) @@ -163,16 +283,20 @@ func TestExtendAddsHandlersCorrectly(t *testing.T) { chained.ServeHTTP(w, r) - if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" { + if w.Body.String() != "t1\nt2\nt3\nt4\napp\ne1\ne2\n" { t.Error("Extend does not add handlers in correctly") } } func TestExtendRespectsImmutability(t *testing.T) { - chain := New(tagMiddleware("")) - newChain := chain.Extend(New(tagMiddleware(""))) + chain := New(tagMiddleware("")).Finally(tagEndware("")) + newChain := chain.Extend(New()) if &chain.constructors[0] == &newChain.constructors[0] { - t.Error("Extend does not respect immutability") + t.Error("Extend does not respect immutability for constructors") + } + + if &chain.endwares[0] == &newChain.endwares[0] { + t.Error("Extend does not respect immutability for endwares") } }