diff --git a/middleware/gapchi/README.md b/middleware/gapchi/README.md new file mode 100644 index 0000000..8ee986e --- /dev/null +++ b/middleware/gapchi/README.md @@ -0,0 +1,20 @@ +# gaphttp + +handler and client middleware, and other tools + +### CLIENT ROUND TRIPPER + +```go + httpClient := &http.Client{ + Transport: NewMiddlewareRoundTrip{http.DefaultTransport}, + } +``` + +### HTTP MIDDLEWARE + +```go + router := chi.NewRouter() + router.Use(NewMiddleware()) + + router.Post("/do", func(writer http.ResponseWriter, request *http.Request) {}) +``` \ No newline at end of file diff --git a/middleware/gapchi/chi_path_cleaner.go b/middleware/gapchi/chi_path_cleaner.go new file mode 100644 index 0000000..3129f0e --- /dev/null +++ b/middleware/gapchi/chi_path_cleaner.go @@ -0,0 +1,19 @@ +package gaphttp + +import ( + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +func RemoveChiPathParam(request *http.Request) string { + path := request.URL.Path + if rctx := chi.RouteContext(request.Context()); rctx != nil { + for i, v := range rctx.URLParams.Values { + path = strings.Replace(path, v, rctx.URLParams.Keys[i], 1) + } + } + + return path +} diff --git a/middleware/gapchi/go.mod b/middleware/gapchi/go.mod new file mode 100644 index 0000000..9b8ac8b --- /dev/null +++ b/middleware/gapchi/go.mod @@ -0,0 +1,20 @@ +module github.com/tel-io/instrumentation/middleware/gapchi + +go 1.19 + +require ( + github.com/go-chi/chi/v5 v5.0.7 + github.com/stretchr/testify v1.8.1 + go.opentelemetry.io/otel v1.11.1 + go.opentelemetry.io/otel/metric v0.33.0 + go.opentelemetry.io/otel/trace v1.11.1 + go.uber.org/zap v1.23.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.6.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/middleware/gapchi/go.sum b/middleware/gapchi/go.sum new file mode 100644 index 0000000..93e0b30 --- /dev/null +++ b/middleware/gapchi/go.sum @@ -0,0 +1,36 @@ +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= +github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4= +go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE= +go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E= +go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI= +go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ= +go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= +go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware/gapchi/handler_basic.go b/middleware/gapchi/handler_basic.go new file mode 100644 index 0000000..c6bf836 --- /dev/null +++ b/middleware/gapchi/handler_basic.go @@ -0,0 +1,56 @@ +package gaphttp + +import ( + "context" + "encoding/json" + "net/http" + + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +type HandlerBasic struct { + ServiceName string + Logger *zap.Logger +} + +func NewHandlerBasic(svcName string, log *zap.Logger) *HandlerBasic { + return &HandlerBasic{ + ServiceName: svcName, + Logger: log, + } +} + +func (h *HandlerBasic) OK(writer http.ResponseWriter, request *http.Request) { + h.Write(request.Context(), writer, http.StatusOK, h.ServiceName) +} + +func (h *HandlerBasic) NotFound(writer http.ResponseWriter, request *http.Request) { + h.Write(request.Context(), writer, http.StatusNotFound, "invalid path") +} + +func (h *HandlerBasic) Write(ctx context.Context, writer http.ResponseWriter, statusCode int, response interface{}) { + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(statusCode) + + if out, ok := response.([]byte); ok { + if _, err := writer.Write(out); err != nil { + h.Log(ctx).Error("write response", zap.Error(err), zap.ByteString("response", out)) + } + + return + } + + if err := json.NewEncoder(writer).Encode(response); err != nil { + h.Log(ctx).Error("json encoder, write response", zap.Error(err), zap.Any("response", response)) + } +} + +func (h *HandlerBasic) Log(ctx context.Context) *zap.Logger { + spCtx := trace.SpanContextFromContext(ctx) + + return h.Logger.With( + zap.String("traceID", spCtx.TraceID().String()), + zap.String("spanID", spCtx.SpanID().String()), + ) +} diff --git a/middleware/gapchi/middleware.go b/middleware/gapchi/middleware.go new file mode 100644 index 0000000..754f680 --- /dev/null +++ b/middleware/gapchi/middleware.go @@ -0,0 +1,154 @@ +package gaphttp + +import ( + "bytes" + "errors" + "net/http" + + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +type Middleware func(next http.Handler) http.Handler + +type MiddlewareOptions struct { + ServiceName string + EnabledLogger bool + EnabledTracer bool + EnabledRecover bool + EnabledMetrics bool +} + +const ( + builderNameMiddlewareTrace = "trace.middleware" + builderNameMiddlewareLogger = "logger.middleware" + builderNameMiddlewareRecovery = "recovery.middleware" + builderNameMiddlewareMetrics = "metrics.middleware" +) + +type MiddlewareBuilder struct { + opt *MiddlewareOptions + logger *zap.Logger + tracer trace.Tracer + metrics Metrics + middlewares map[string]func() func(next http.Handler) http.Handler +} + +func NewMiddlewareBuilder( + opt *MiddlewareOptions, + logger *zap.Logger, + tracer trace.Tracer, + metrics Metrics, +) *MiddlewareBuilder { + return &MiddlewareBuilder{ + opt: opt, + logger: logger, + tracer: tracer, + metrics: metrics, + middlewares: make(map[string]func() func(next http.Handler) http.Handler), + } + +} + +func (b *MiddlewareBuilder) AddTrace() *MiddlewareBuilder { + b.middlewares[builderNameMiddlewareTrace] = func() func(next http.Handler) http.Handler { + return NewMiddlewareTracer(b.tracer, b.opt) + } + + return b +} + +func (b *MiddlewareBuilder) AddLogger() *MiddlewareBuilder { + b.middlewares[builderNameMiddlewareLogger] = func() func(next http.Handler) http.Handler { + return NewMiddlewareLogger(b.logger, b.opt) + } + + return b +} + +func (b *MiddlewareBuilder) AddMiddlewareRecover() *MiddlewareBuilder { + b.middlewares[builderNameMiddlewareRecovery] = func() func(next http.Handler) http.Handler { + return NewMiddlewareRecovery(b.logger, b.opt) + } + + return b +} + +func (b *MiddlewareBuilder) AddMiddlewareMeter() *MiddlewareBuilder { + b.middlewares[builderNameMiddlewareMetrics] = func() func(next http.Handler) http.Handler { + return NewMiddlewareMetrics(b.metrics, b.opt) + } + + return b +} + +// Build - build middleware and create correct position in stack +func (b *MiddlewareBuilder) Build() ([]func(next http.Handler) http.Handler, error) { + if b.opt == nil { + return nil, errors.New("options cannot be blank") + } + + out := make([]func(next http.Handler) http.Handler, 0, len(b.middlewares)) + + recMidForMid, ok := b.middlewares[builderNameMiddlewareRecovery] + if ok { + out = append(out, recMidForMid()) + } + + metrMid, ok := b.middlewares[builderNameMiddlewareMetrics] + if ok { + out = append(out, metrMid()) + } + + traceMid, ok := b.middlewares[builderNameMiddlewareTrace] + if ok { + out = append(out, traceMid()) + } + + logMid, ok := b.middlewares[builderNameMiddlewareLogger] + if ok { + out = append(out, logMid()) + } + + recMidSrv, ok := b.middlewares[builderNameMiddlewareRecovery] + if ok { + out = append(out, recMidSrv()) + } + + return out, nil +} + +// WrapWriter supported butch loading +type WrapWriter struct { + status int + body *bytes.Buffer + inner http.ResponseWriter +} + +func NewWrapWriter(inner http.ResponseWriter) *WrapWriter { + return &WrapWriter{body: bytes.NewBuffer([]byte{}), inner: inner} +} + +func (w *WrapWriter) Header() http.Header { + return w.inner.Header() +} + +func (w *WrapWriter) Write(i []byte) (int, error) { + w.body.Write(i) + + return w.inner.Write(i) +} + +func (w *WrapWriter) WriteHeader(statusCode int) { + w.status = statusCode + + w.inner.WriteHeader(statusCode) +} + +func (w *WrapWriter) Status() int { + return w.status +} + +func (w *WrapWriter) Body() []byte { + return w.body.Bytes() +} diff --git a/middleware/gapchi/middleware_logger.go b/middleware/gapchi/middleware_logger.go new file mode 100644 index 0000000..b7f3c39 --- /dev/null +++ b/middleware/gapchi/middleware_logger.go @@ -0,0 +1,67 @@ +package gaphttp + +import ( + "bytes" + "io" + "net/http" + + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func NewMiddlewareLogger(logger *zap.Logger, option *MiddlewareOptions) Middleware { + // logger enabled + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + // set trace id and span id to logger + span := trace.SpanFromContext(request.Context()) + + log := logger.With( + zap.String("traceID", span.SpanContext().TraceID().String()), + zap.String("spanID", span.SpanContext().SpanID().String()), + ) + + log = log.With( + zap.String("http.request.method", request.Method), + zap.Any("http.request.header", request.Header), + zap.Any("http.request.ip", request.RemoteAddr), + zap.Any("http.request.path", request.URL.Path), + zap.Any("http.request.query", request.URL.Query()), + zap.Any("http.request.cookie", request.Cookies()), + ) + + // set request body to logger if exist + if request.Body != nil { + reqBody, err := io.ReadAll(request.Body) + if err != nil { + logger.Error("http.request.body.read", zap.Error(err)) + } + + log = log.With(zap.ByteString("http.request.body", reqBody)) + + request.Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset + } + + wrapWriter := NewWrapWriter(writer) + + log.Debug("handle started!") + + next.ServeHTTP(wrapWriter, request) + + if wrapWriter.status == http.StatusOK { + log.Debug( + "handle stopped!", + zap.Int("http.response.status", wrapWriter.Status()), + zap.ByteString("http.response.body", wrapWriter.Body()), + ) + } else { + log.Error("handle stopped!", + zap.Int("http.response.status", wrapWriter.Status()), + zap.ByteString("http.response.body", wrapWriter.Body()), + ) + } + }, + ) + } +} diff --git a/middleware/gapchi/middleware_logger_test.go b/middleware/gapchi/middleware_logger_test.go new file mode 100644 index 0000000..abd35fc --- /dev/null +++ b/middleware/gapchi/middleware_logger_test.go @@ -0,0 +1,82 @@ +package gaphttp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +const testURL = "/test" + +type testCase struct { + getReq func(url string) *http.Request + postReq func(url string) *http.Request + + get http.HandlerFunc + post http.HandlerFunc +} + +func TestMiddlewareLogger(t *testing.T) { + tc := testCase{ + getReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s?%s=%s", url, "name", "test"), nil) + require.Nil(t, err) + + return req + }, + postReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer([]byte(`{"id":"11"}`))) + require.Nil(t, err) + + return req + }, + get: func(writer http.ResponseWriter, request *http.Request) { + require.EqualValues(t, "test", request.URL.Query().Get("name")) + writer.WriteHeader(http.StatusOK) + writer.Write([]byte(`"test": "test"`)) + }, + post: func(writer http.ResponseWriter, request *http.Request) { + var i interface{} + err := json.NewDecoder(request.Body).Decode(&i) + require.Nil(t, err) + + writer.WriteHeader(http.StatusOK) + writer.Write([]byte(`"test": "test"`)) + }, + } + + router := chi.NewRouter() + router.Use(middleware.Recoverer) + router.Use(NewMiddlewareLogger(zap.NewExample(), &MiddlewareOptions{EnabledLogger: true})) + router.Get(testURL, tc.get) + router.Post(testURL, tc.post) + + server := httptest.NewServer(router) + + cli := http.Client{Transport: NewMiddlewareRoundTrip(http.DefaultTransport, true, zap.NewExample())} + go func() { + reqs := []*http.Request{ + tc.getReq(fmt.Sprintf("%s%s", server.URL, testURL)), + tc.postReq(fmt.Sprintf("%s%s", server.URL, testURL)), + } + + for _, req := range reqs { + resp, err := cli.Do(req) + require.Nil(t, err) + require.EqualValues(t, resp.StatusCode, http.StatusOK) + b, err := io.ReadAll(resp.Body) + require.Nil(t, err) + require.EqualValues(t, string(b), `"test": "test"`) + } + }() + +} diff --git a/middleware/gapchi/middleware_metric.go b/middleware/gapchi/middleware_metric.go new file mode 100644 index 0000000..5ccae46 --- /dev/null +++ b/middleware/gapchi/middleware_metric.go @@ -0,0 +1,146 @@ +package gaphttp + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/instrument/syncfloat64" + "go.opentelemetry.io/otel/metric/instrument/syncint64" + semconv "go.opentelemetry.io/otel/semconv/v1.12.0" +) + +const ( + RequestCount = "http.server.request_count" // Incoming request count total + RequestContentLength = "http.server.request_content_length" // Incoming request bytes total + ResponseContentLength = "http.server.response_content_length" // Incoming response bytes total + ServerLatency = "http.server.duration" // Incoming end to end duration, microseconds +) + +type Metrics interface { + IncRequestCount(ctx context.Context, attrs ...attribute.KeyValue) + AddRequestContentLength(ctx context.Context, len int64, attrs ...attribute.KeyValue) + AddResponseContentLength(ctx context.Context, len int64, attrs ...attribute.KeyValue) + AddHandleDuration(ctx context.Context, dur float64, attrs ...attribute.KeyValue) +} + +func NewMiddlewareMetrics(metrics Metrics, option *MiddlewareOptions) Middleware { + // metrics disabled + if !option.EnabledMetrics { + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + next.ServeHTTP(writer, request) + }, + ) + } + } + + // metrics enabled + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + requestStartTime := time.Now().UTC() + + wrappedWriter := NewWrapWriter(writer) + + next.ServeHTTP(wrappedWriter, request) + + labeler := new(Labeler) + labeler.Add(attribute.String("http.url.path", RemoveChiPathParam(request))) + labeler.Add(attribute.Int("http.status", wrappedWriter.Status())) + + attributes := append( + labeler.Get(), + semconv.HTTPServerMetricAttributesFromHTTPRequest(option.ServiceName, request)..., + ) + + elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond) + + metrics.AddHandleDuration(request.Context(), elapsedTime, attributes...) + metrics.IncRequestCount(request.Context(), attributes...) + metrics.AddRequestContentLength(request.Context(), request.ContentLength, attributes...) + metrics.AddResponseContentLength(request.Context(), int64(len(wrappedWriter.Body())), attributes...) + }, + ) + } +} + +func RegisterBasicMetrics(meter metric.Meter) (*BasicMetrics, error) { + rc, err := meter.SyncInt64().Counter(RequestCount) + if err != nil { + return nil, fmt.Errorf("meter, syncInt64, counter: %w", err) + } + + reqCl, err := meter.SyncInt64().Counter(RequestContentLength) + if err != nil { + return nil, fmt.Errorf("meter, syncInt64, counter: %w", err) + } + + resCl, err := meter.SyncInt64().Counter(ResponseContentLength) + if err != nil { + return nil, fmt.Errorf("meter, syncInt64, counter: %w", err) + } + + sl, err := meter.SyncFloat64().Histogram(ServerLatency) + if err != nil { + return nil, fmt.Errorf("meter, syncFloat64, histogram: %w", err) + } + + return &BasicMetrics{ + requestCount: rc, + requestContentLength: reqCl, + responseContentLength: resCl, + handleDuration: sl, + }, nil +} + +type BasicMetrics struct { + requestCount syncint64.Counter + requestContentLength syncint64.Counter + responseContentLength syncint64.Counter + handleDuration syncfloat64.Histogram +} + +func (m *BasicMetrics) IncRequestCount(ctx context.Context, attrs ...attribute.KeyValue) { + m.requestCount.Add(ctx, 1, attrs...) +} + +func (m *BasicMetrics) AddRequestContentLength(ctx context.Context, len int64, attrs ...attribute.KeyValue) { + m.requestContentLength.Add(ctx, len, attrs...) +} + +func (m *BasicMetrics) AddResponseContentLength(ctx context.Context, len int64, attrs ...attribute.KeyValue) { + m.responseContentLength.Add(ctx, len, attrs...) +} + +func (m *BasicMetrics) AddHandleDuration(ctx context.Context, dur float64, attrs ...attribute.KeyValue) { + m.handleDuration.Record(ctx, dur, attrs...) +} + +// Labeler is used to allow instrumented HTTP handlers to add custom attributes to +// the metrics recorded by the net/http instrumentation. +type Labeler struct { + mu sync.Mutex + attributes []attribute.KeyValue +} + +// Add attributes to a Labeler. +func (l *Labeler) Add(ls ...attribute.KeyValue) { + l.mu.Lock() + defer l.mu.Unlock() + l.attributes = append(l.attributes, ls...) +} + +// Get returns a copy of the attributes added to the Labeler. +func (l *Labeler) Get() []attribute.KeyValue { + l.mu.Lock() + defer l.mu.Unlock() + ret := make([]attribute.KeyValue, len(l.attributes)) + copy(ret, l.attributes) + return ret +} diff --git a/middleware/gapchi/middleware_metrics_test.go b/middleware/gapchi/middleware_metrics_test.go new file mode 100644 index 0000000..94a43b4 --- /dev/null +++ b/middleware/gapchi/middleware_metrics_test.go @@ -0,0 +1,106 @@ +package gaphttp + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" +) + +func TestMiddlewareMetrics(t *testing.T) { + tc := testCase{ + postReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer([]byte(`{"id":"11"}`))) + require.Nil(t, err) + + return req + }, + post: func(writer http.ResponseWriter, request *http.Request) {}, + } + + router := chi.NewRouter() + metrics := newMetricsTest() + router.Use(NewMiddlewareMetrics(metrics, &MiddlewareOptions{ServiceName: "test", EnabledMetrics: true})) + router.Post(testURL, tc.post) + + server := httptest.NewServer(router) + + reqs := []*http.Request{ + tc.postReq(fmt.Sprintf("%s%s", server.URL, testURL)), + } + + // Attention! NewMiddlewareRoundTrip + cli := http.Client{Transport: NewMiddlewareRoundTrip(http.DefaultTransport, true, zap.NewExample())} + for _, req := range reqs { + _, err := cli.Do(req) + require.Nil(t, err) + } + + require.Len(t, metrics.data, 4) +} + +type metricsTest struct { + locker *sync.Mutex + data map[string]struct { + Value interface{} + Attrs []attribute.KeyValue + } +} + +func newMetricsTest() *metricsTest { + return &metricsTest{ + locker: &sync.Mutex{}, + data: make(map[string]struct { + Value interface{} + Attrs []attribute.KeyValue + }), + } +} + +func (m *metricsTest) IncRequestCount(ctx context.Context, attrs ...attribute.KeyValue) { + m.locker.Lock() + defer m.locker.Unlock() + + m.data[RequestCount] = struct { + Value interface{} + Attrs []attribute.KeyValue + }{Value: 1, Attrs: attrs} +} + +func (m *metricsTest) AddRequestContentLength(ctx context.Context, len int64, attrs ...attribute.KeyValue) { + m.locker.Lock() + defer m.locker.Unlock() + + m.data[RequestContentLength] = struct { + Value interface{} + Attrs []attribute.KeyValue + }{Value: len, Attrs: attrs} +} + +func (m *metricsTest) AddResponseContentLength(ctx context.Context, len int64, attrs ...attribute.KeyValue) { + m.locker.Lock() + defer m.locker.Unlock() + + m.data[ResponseContentLength] = struct { + Value interface{} + Attrs []attribute.KeyValue + }{Value: len, Attrs: attrs} +} + +func (m *metricsTest) AddHandleDuration(ctx context.Context, dur float64, attrs ...attribute.KeyValue) { + m.locker.Lock() + defer m.locker.Unlock() + + m.data[ServerLatency] = struct { + Value interface{} + Attrs []attribute.KeyValue + }{Value: dur, Attrs: attrs} +} diff --git a/middleware/gapchi/middleware_recovery.go b/middleware/gapchi/middleware_recovery.go new file mode 100644 index 0000000..8585f23 --- /dev/null +++ b/middleware/gapchi/middleware_recovery.go @@ -0,0 +1,37 @@ +package gaphttp + +import ( + "net/http" + + "go.uber.org/zap" +) + +func NewMiddlewareRecovery(logger *zap.Logger, option *MiddlewareOptions) Middleware { + // recover disabled + if !option.EnabledRecover { + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + next.ServeHTTP(writer, request) + }, + ) + } + } + + // recover enabled + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + defer func() { + if rvr := recover(); rvr != nil && rvr != http.ErrAbortHandler { + logger.Error("middleware catch panic", zap.Any("panic.body", rvr)) + + writer.WriteHeader(http.StatusInternalServerError) + } + }() + + next.ServeHTTP(writer, request) + }, + ) + } +} diff --git a/middleware/gapchi/middleware_recovery_test.go b/middleware/gapchi/middleware_recovery_test.go new file mode 100644 index 0000000..f783621 --- /dev/null +++ b/middleware/gapchi/middleware_recovery_test.go @@ -0,0 +1,70 @@ +package gaphttp + +import ( + "bytes" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func TestMiddlewareRecover(t *testing.T) { + tc := testCase{ + getReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s?%s=%s", url, "name", "test"), nil) + require.Nil(t, err) + + req = req.WithContext(trace.ContextWithRemoteSpanContext( + req.Context(), trace.NewSpanContext( + trace.SpanContextConfig{ + TraceID: trace.TraceID{0x1}, + SpanID: trace.SpanID{}, + TraceFlags: 0, + TraceState: trace.TraceState{}, + Remote: false, + }, + ), + )) + + return req + }, + postReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer([]byte(`{"id":"11"}`))) + require.Nil(t, err) + + return req + }, + + get: func(writer http.ResponseWriter, request *http.Request) { + panic("test0") + }, + post: func(writer http.ResponseWriter, request *http.Request) { + panic(errors.New("test")) + }, + } + + router := chi.NewRouter() + router.Use(NewMiddlewareRecovery(zap.NewExample(), &MiddlewareOptions{EnabledRecover: true})) + router.Get(testURL, tc.get) + router.Post(testURL, tc.post) + + server := httptest.NewServer(router) + + reqs := []*http.Request{ + tc.getReq(fmt.Sprintf("%s%s", server.URL, testURL)), + tc.postReq(fmt.Sprintf("%s%s", server.URL, testURL)), + } + + // Attention! NewMiddlewareRoundTrip + cli := http.DefaultClient + for _, req := range reqs { + _, err := cli.Do(req) + require.Nil(t, err) + } +} diff --git a/middleware/gapchi/middleware_round_trip.go b/middleware/gapchi/middleware_round_trip.go new file mode 100644 index 0000000..0d89ca2 --- /dev/null +++ b/middleware/gapchi/middleware_round_trip.go @@ -0,0 +1,110 @@ +package gaphttp + +import ( + "bytes" + "fmt" + "io" + "net/http" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +type MiddlewareRoundTrip struct { + inner http.RoundTripper + debug bool + log *zap.Logger +} + +func NewMiddlewareRoundTrip( + inner http.RoundTripper, + debug bool, + log *zap.Logger, +) *MiddlewareRoundTrip { + return &MiddlewareRoundTrip{ + inner: inner, + debug: debug, + log: log, + } +} + +func (rt *MiddlewareRoundTrip) RoundTrip(req *http.Request) (*http.Response, error) { + span := trace.SpanFromContext(req.Context()) + span.AddEvent( + fmt.Sprintf("http.client.request.started: %s", req.URL.Path), + trace.WithAttributes(attribute.String("http.client.method", req.Method)), + ) + + req.Header.Add(RequestHeaderTraceID, span.SpanContext().TraceID().String()) + + rt.LogRequest(span, req) + + res, err := rt.inner.RoundTrip(req) + if err != nil { + rt.log.Error("round trip", zap.Error(err)) + + return res, err + } + + span.AddEvent( + fmt.Sprintf("http.client.request.ended: %s", req.URL.Path), + trace.WithAttributes(attribute.Int("http.client.status", res.StatusCode)), + ) + + rt.LogResponse(span, res) + + return res, err +} + +func (rt *MiddlewareRoundTrip) LogResponse(span trace.Span, response *http.Response) { + if !rt.debug { + return + } + + log := rt.log.With( + zap.String("traceID", span.SpanContext().TraceID().String()), + zap.String("spanID", span.SpanContext().SpanID().String()), + zap.Int("http.client.response.status", response.StatusCode), + ) + + if response.Body != nil { + respBody, err := io.ReadAll(response.Body) + if err != nil { + log.Error("http.client.response.body.read", zap.Error(err)) + } + + log = log.With(zap.ByteString("http.client.response.body", respBody)) + + response.Body = io.NopCloser(bytes.NewBuffer(respBody)) // Reset + } + + log.Debug("http.client.request.stopped!") +} + +func (rt *MiddlewareRoundTrip) LogRequest(span trace.Span, request *http.Request) { + if !rt.debug { + return + } + + log := rt.log.With( + zap.String("traceID", span.SpanContext().TraceID().String()), + zap.String("spanID", span.SpanContext().SpanID().String()), + zap.String("http.client.request.method", request.Method), + zap.Any("http.client.request.header", request.Header), + zap.String("http.client.request.url", request.URL.String()), + ) + + if request.Body != nil { + reqBody, err := io.ReadAll(request.Body) + if err != nil { + log.Error("http.client.request.body.read", zap.Error(err)) + } + + log = log.With(zap.ByteString("http.request.body", reqBody)) + + request.Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset + } + + log.Debug("http.client.request.started!") +} diff --git a/middleware/gapchi/middleware_trace.go b/middleware/gapchi/middleware_trace.go new file mode 100644 index 0000000..67a19cb --- /dev/null +++ b/middleware/gapchi/middleware_trace.go @@ -0,0 +1,94 @@ +package gaphttp + +import ( + "fmt" + "net/http" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +const RequestHeaderTraceID = "x-trace-id" + +type Header interface { + Set(key string, value string) + Get(key string) string +} + +func ExtractRemoteTraceID(header Header) (trace.TraceID, bool) { + traceIdAsString := header.Get(RequestHeaderTraceID) + if traceIdAsString == "" { + return trace.TraceID{}, false + } + + traceId, err := trace.TraceIDFromHex(traceIdAsString) + if err != nil { + return trace.TraceID{}, false + } + + return traceId, true +} + +func NewMiddlewareTracer(tracer trace.Tracer, option *MiddlewareOptions) Middleware { + // tracer disabled + if !option.EnabledTracer { + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + next.ServeHTTP(writer, request) + }, + ) + } + } + + // tracer enabled + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(writer http.ResponseWriter, request *http.Request) { + reqCtx := request.Context() + rTraceID, isRemote := ExtractRemoteTraceID(request.Header) + if isRemote { + remoteSpanCtx := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: rTraceID, + SpanID: trace.SpanID{}, + TraceFlags: 0, + TraceState: trace.TraceState{}, + Remote: isRemote, + }) + + reqCtx = trace.ContextWithRemoteSpanContext(request.Context(), remoteSpanCtx) + } + + path := RemoveChiPathParam(request) + ctxWithSpan, span := tracer.Start(reqCtx, makeSpanName(path)) + + span.SetAttributes( + attribute.Bool("trace.id.remoted", isRemote), + attribute.String("http.method", request.Method), + attribute.String("http.user_agent", request.UserAgent()), + attribute.String("http.client_ip", request.RemoteAddr), + attribute.Int64("http.request_content_length", request.ContentLength), + attribute.String("http.target", request.URL.Path), + ) + + request = request.WithContext(ctxWithSpan) + wrappedWriter := NewWrapWriter(writer) + + next.ServeHTTP(wrappedWriter, request) + + status := codes.Ok + if wrappedWriter.status != http.StatusOK { + status = codes.Error + } + + span.SetStatus(status, fmt.Sprintf("http.status: %d", wrappedWriter.Status())) + span.End() + }, + ) + } +} + +func makeSpanName(path string) string { + return fmt.Sprintf("http.handler.%s", path) +} diff --git a/middleware/gapchi/middleware_trace_test.go b/middleware/gapchi/middleware_trace_test.go new file mode 100644 index 0000000..2e6045e --- /dev/null +++ b/middleware/gapchi/middleware_trace_test.go @@ -0,0 +1,79 @@ +package gaphttp + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func TestMiddlewareTrace(t *testing.T) { + tc := testCase{ + getReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s?%s=%s", url, "name", "test"), nil) + require.Nil(t, err) + + req = req.WithContext(trace.ContextWithRemoteSpanContext( + req.Context(), trace.NewSpanContext( + trace.SpanContextConfig{ + TraceID: trace.TraceID{0x1}, + SpanID: trace.SpanID{}, + TraceFlags: 0, + TraceState: trace.TraceState{}, + Remote: false, + }, + ), + )) + + return req + }, + postReq: func(url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer([]byte(`{"id":"11"}`))) + require.Nil(t, err) + + return req + }, + + get: func(writer http.ResponseWriter, request *http.Request) { + fmt.Println( + trace.SpanFromContext(request.Context()).SpanContext().TraceID().String(), + trace.SpanFromContext(request.Context()).SpanContext().SpanID().String(), + ) + }, + post: func(writer http.ResponseWriter, request *http.Request) { + fmt.Println( + trace.SpanFromContext(request.Context()).SpanContext().TraceID().String(), + trace.SpanFromContext(request.Context()).SpanContext().SpanID().String(), + ) + }, + } + + router := chi.NewRouter() + router.Use(middleware.Recoverer) + tracer := trace.NewNoopTracerProvider().Tracer("test") + router.Use(NewMiddlewareTracer(tracer, &MiddlewareOptions{EnabledTracer: true})) + router.Get(testURL, tc.get) + router.Post(testURL, tc.post) + + server := httptest.NewServer(router) + + reqs := []*http.Request{ + tc.getReq(fmt.Sprintf("%s%s", server.URL, testURL)), + tc.postReq(fmt.Sprintf("%s%s", server.URL, testURL)), + } + + // Attention! NewMiddlewareRoundTrip + cli := http.Client{Transport: NewMiddlewareRoundTrip(http.DefaultTransport, true, zap.NewExample())} + for _, req := range reqs { + _, err := cli.Do(req) + require.Nil(t, err) + } + +}