diff --git a/chain.go b/chain.go index f00f12a..77adac0 100644 --- a/chain.go +++ b/chain.go @@ -69,13 +69,12 @@ func (c Chain) Then(h http.Handler) http.Handler { // c.ThenFunc(fn) // // ThenFunc provides all the guarantees of Then. -func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler { - // This nil check cannot be removed due to the "nil is not nil" common mistake in Go. - // Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil +func (c Chain) ThenFunc(fn func(http.ResponseWriter, *http.Request)) http.Handler { + // We want to preserve the special behavior similar to Then(nil) if fn == nil { return c.Then(nil) } - return c.Then(fn) + return c.Then(http.HandlerFunc(fn)) } // Append extends a chain, adding the specified constructors diff --git a/chain_test.go b/chain_test.go index c486553..392e0da 100644 --- a/chain_test.go +++ b/chain_test.go @@ -67,6 +67,9 @@ func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) { if New().ThenFunc(nil) != http.DefaultServeMux { t.Error("ThenFunc does not treat nil as DefaultServeMux") } + if New().ThenFunc(http.HandlerFunc(nil)) != http.DefaultServeMux { + t.Error("ThenFunc does not treat nil as DefaultServeMux") + } } func TestThenFuncConstructsHandlerFunc(t *testing.T) {