From a70f1aeb7bc542fb105d5aff49740694ce6a1cb2 Mon Sep 17 00:00:00 2001 From: vshulcz Date: Fri, 16 Jan 2026 14:12:23 +0300 Subject: [PATCH 1/2] feat: support for asymmetric encryption to the agent and server --- README.md | 2 + cmd/agent/main.go | 28 +- cmd/server/init.go | 12 +- cmd/server/main.go | 160 ++++++++---- .../{services => application}/agent/agent.go | 5 +- .../agent/agent_benchmark_test.go | 3 +- .../agent/agent_test.go | 9 +- internal/application/agent/config.go | 10 + .../{services => application}/agent/sender.go | 0 .../agent/sender_test.go | 0 .../metrics/metrics.go | 31 +-- .../metrics/metrics_benchmark_test.go | 4 +- .../metrics/metrics_test.go | 30 +-- internal/domain/audit_event.go | 8 + internal/infra/audit/fanout.go | 69 +++++ internal/infra/audit/fanout_test.go | 105 ++++++++ .../{adapters => infra}/audit/file/file.go | 0 .../{adapters => infra}/audit/file/writer.go | 6 +- .../audit/file/writer_test.go | 12 +- .../audit/remote/client.go | 6 +- .../audit/remote/client_test.go | 16 +- .../audit/remote/remote.go | 0 .../collector/runtime/keys.go | 0 .../collector/runtime/runtime.go | 0 .../collector/runtime/runtime_test.go | 0 .../collector/runtime/state.go | 0 internal/{ => infra}/config/agent.go | 5 + internal/{ => infra}/config/agent_test.go | 2 +- internal/{ => infra}/config/config.go | 0 internal/{ => infra}/config/helpers.go | 41 ++- internal/{ => infra}/config/helpers_test.go | 2 +- internal/{ => infra}/config/server.go | 5 + internal/{ => infra}/config/server_test.go | 2 +- internal/infra/crypto/rsaenvelope/envelope.go | 218 ++++++++++++++++ .../infra/crypto/rsaenvelope/envelope_test.go | 80 ++++++ .../http/ginserver/example_test.go | 6 +- .../http/ginserver/ginserver.go | 0 .../http/ginserver/handler.go | 15 +- .../http/ginserver/handler_benchmark_test.go | 2 +- .../http/ginserver/handler_test.go | 6 +- .../ginserver/middlewares/crypto_decrypt.go | 44 ++++ .../middlewares/crypto_decrypt_test.go | 54 ++++ .../http/ginserver/middlewares/gin_gzip.go | 0 .../http/ginserver/middlewares/hashsha256.go | 12 +- .../http/ginserver/middlewares/middlewares.go | 0 .../ginserver/middlewares/middlewares_test.go | 244 ++++++++++++++++++ .../http/ginserver/middlewares/zap_logger.go | 0 .../http/ginserver/router.go | 0 .../persistence/file/file.go | 0 .../persistence/file/file_test.go | 2 +- .../publisher/httpjson/client.go | 35 ++- .../httpjson/client_benchmark_test.go | 2 +- .../publisher/httpjson/client_test.go | 164 ++++++++---- internal/infra/publisher/httpjson/hash.go | 11 + .../publisher/httpjson/httpjson.go | 0 .../repository/memory/memory.go | 0 .../repository/memory/memory_test.go | 0 .../repository/postgres/migrate.go | 0 .../repository/postgres/migrate_test.go | 0 .../postgres/migrations/0001_init.sql | 0 .../repository/postgres/postgres.go | 16 +- .../repository/postgres/postgres_test.go | 50 ++-- internal/{misc => infra/retry}/retry.go | 2 +- internal/infra/retry/retry_test.go | 88 +++++++ internal/misc/env.go | 53 ---- internal/misc/env_test.go | 107 -------- internal/misc/hash.go | 12 - internal/misc/hash_test.go | 53 ---- internal/misc/misc.go | 2 - internal/misc/pool.go | 42 --- internal/misc/pool_test.go | 70 ----- internal/misc/retry_test.go | 130 ---------- internal/ports/audit.go | 12 + internal/ports/crypto.go | 15 ++ internal/services/audit/audit.go | 2 - internal/services/audit/context.go | 21 -- internal/services/audit/context_test.go | 18 -- internal/services/audit/event.go | 8 - internal/services/audit/subject.go | 20 -- internal/services/audit/subject_test.go | 57 ---- profiles/README.md | 14 +- 81 files changed, 1402 insertions(+), 858 deletions(-) rename internal/{services => application}/agent/agent.go (94%) rename internal/{services => application}/agent/agent_benchmark_test.go (94%) rename internal/{services => application}/agent/agent_test.go (96%) create mode 100644 internal/application/agent/config.go rename internal/{services => application}/agent/sender.go (100%) rename internal/{services => application}/agent/sender_test.go (100%) rename internal/{services => application}/metrics/metrics.go (87%) rename internal/{services => application}/metrics/metrics_benchmark_test.go (85%) rename internal/{services => application}/metrics/metrics_test.go (95%) create mode 100644 internal/domain/audit_event.go create mode 100644 internal/infra/audit/fanout.go create mode 100644 internal/infra/audit/fanout_test.go rename internal/{adapters => infra}/audit/file/file.go (100%) rename internal/{adapters => infra}/audit/file/writer.go (81%) rename internal/{adapters => infra}/audit/file/writer_test.go (58%) rename internal/{adapters => infra}/audit/remote/client.go (87%) rename internal/{adapters => infra}/audit/remote/client_test.go (70%) rename internal/{adapters => infra}/audit/remote/remote.go (100%) rename internal/{adapters => infra}/collector/runtime/keys.go (100%) rename internal/{adapters => infra}/collector/runtime/runtime.go (100%) rename internal/{adapters => infra}/collector/runtime/runtime_test.go (100%) rename internal/{adapters => infra}/collector/runtime/state.go (100%) rename internal/{ => infra}/config/agent.go (91%) rename internal/{ => infra}/config/agent_test.go (99%) rename internal/{ => infra}/config/config.go (100%) rename internal/{ => infra}/config/helpers.go (68%) rename internal/{ => infra}/config/helpers_test.go (98%) rename internal/{ => infra}/config/server.go (93%) rename internal/{ => infra}/config/server_test.go (98%) create mode 100644 internal/infra/crypto/rsaenvelope/envelope.go create mode 100644 internal/infra/crypto/rsaenvelope/envelope_test.go rename internal/{adapters => infra}/http/ginserver/example_test.go (90%) rename internal/{adapters => infra}/http/ginserver/ginserver.go (100%) rename internal/{adapters => infra}/http/ginserver/handler.go (93%) rename internal/{adapters => infra}/http/ginserver/handler_benchmark_test.go (98%) rename internal/{adapters => infra}/http/ginserver/handler_test.go (99%) create mode 100644 internal/infra/http/ginserver/middlewares/crypto_decrypt.go create mode 100644 internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go rename internal/{adapters => infra}/http/ginserver/middlewares/gin_gzip.go (100%) rename internal/{adapters => infra}/http/ginserver/middlewares/hashsha256.go (86%) rename internal/{adapters => infra}/http/ginserver/middlewares/middlewares.go (100%) create mode 100644 internal/infra/http/ginserver/middlewares/middlewares_test.go rename internal/{adapters => infra}/http/ginserver/middlewares/zap_logger.go (100%) rename internal/{adapters => infra}/http/ginserver/router.go (100%) rename internal/{adapters => infra}/persistence/file/file.go (100%) rename internal/{adapters => infra}/persistence/file/file_test.go (97%) rename internal/{adapters => infra}/publisher/httpjson/client.go (86%) rename internal/{adapters => infra}/publisher/httpjson/client_benchmark_test.go (95%) rename internal/{adapters => infra}/publisher/httpjson/client_test.go (82%) create mode 100644 internal/infra/publisher/httpjson/hash.go rename internal/{adapters => infra}/publisher/httpjson/httpjson.go (100%) rename internal/{adapters => infra}/repository/memory/memory.go (100%) rename internal/{adapters => infra}/repository/memory/memory_test.go (100%) rename internal/{adapters => infra}/repository/postgres/migrate.go (100%) rename internal/{adapters => infra}/repository/postgres/migrate_test.go (100%) rename internal/{adapters => infra}/repository/postgres/migrations/0001_init.sql (100%) rename internal/{adapters => infra}/repository/postgres/postgres.go (92%) rename internal/{adapters => infra}/repository/postgres/postgres_test.go (93%) rename internal/{misc => infra/retry}/retry.go (98%) create mode 100644 internal/infra/retry/retry_test.go delete mode 100644 internal/misc/env.go delete mode 100644 internal/misc/env_test.go delete mode 100644 internal/misc/hash.go delete mode 100644 internal/misc/hash_test.go delete mode 100644 internal/misc/misc.go delete mode 100644 internal/misc/pool.go delete mode 100644 internal/misc/pool_test.go delete mode 100644 internal/misc/retry_test.go create mode 100644 internal/ports/audit.go create mode 100644 internal/ports/crypto.go delete mode 100644 internal/services/audit/audit.go delete mode 100644 internal/services/audit/context.go delete mode 100644 internal/services/audit/context_test.go delete mode 100644 internal/services/audit/event.go delete mode 100644 internal/services/audit/subject.go delete mode 100644 internal/services/audit/subject_test.go diff --git a/README.md b/README.md index c558ba8..41e6ae2 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ You can use ENV, CLI flags, or defaults (ENV > CLI > defaults). | File storage | `FILE_STORAGE_PATH` | `-f` | `metrics-db.json` | JSON snapshot file | | Postgres DSN | `DATABASE_DSN` | `-d` | *empty* | e.g. `postgres://user:pass@localhost:5432/db?sslmode=disable` | | Secret key | `KEY` | `-k` | *empty* | enables `HashSHA256` | +| Crypto key | `CRYPTO_KEY` | `-crypto-key` | *empty* | path to RSA private key for decrypting agent payloads | | Store interval | `STORE_INTERVAL` | `-i` | `300s` | `0` = sync writes | | Restore on start | `RESTORE` | `-r` | `false` | load from file at boot | | Audit file | `AUDIT_FILE` | `--audit-file` | *empty* | newline-delimited JSON audit log fan-out target (disabled when empty) | @@ -139,6 +140,7 @@ You can use ENV, CLI flags, or defaults (ENV > CLI > defaults). | --------------- | ----------------- | ---- | ----------------------- | ----------------------- | | Server address | `ADDRESS` | `-a` | `http://localhost:8080` | URL or `host:port` | | Secret key | `KEY` | `-k` | *empty* | adds `HashSHA256` | +| Crypto key | `CRYPTO_KEY` | `-crypto-key` | *empty* | path to RSA public key for encrypting requests | | Report interval | `REPORT_INTERVAL` | `-r` | `10s` | send frequency | | Poll interval | `POLL_INTERVAL` | `-p` | `2s` | sample frequency | | Rate limit | `RATE_LIMIT` | `-l` | `1` | concurrent send workers | diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 861c232..fd3f0aa 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -8,10 +8,12 @@ import ( "os/signal" "syscall" - "github.com/vshulcz/Golectra/internal/adapters/collector/runtime" - "github.com/vshulcz/Golectra/internal/adapters/publisher/httpjson" - "github.com/vshulcz/Golectra/internal/config" - agentsvc "github.com/vshulcz/Golectra/internal/services/agent" + agentsvc "github.com/vshulcz/Golectra/internal/application/agent" + "github.com/vshulcz/Golectra/internal/infra/collector/runtime" + "github.com/vshulcz/Golectra/internal/infra/config" + "github.com/vshulcz/Golectra/internal/infra/crypto/rsaenvelope" + "github.com/vshulcz/Golectra/internal/infra/publisher/httpjson" + "github.com/vshulcz/Golectra/internal/ports" "github.com/vshulcz/Golectra/pkg/util" ) @@ -29,12 +31,26 @@ func main() { log.Fatalf("failed to parse flags: %v", err) } - pub, err := httpjson.New(cfg.Address, &http.Client{}, cfg.Key) + var encrypter ports.PayloadEncrypter + if cfg.CryptoKey != "" { + key, err := rsaenvelope.LoadPublicKey(cfg.CryptoKey) + if err != nil { + log.Fatalf("failed to load crypto key: %v", err) + } + encrypter = rsaenvelope.NewEncrypter(key) + } + + pub, err := httpjson.New(cfg.Address, &http.Client{}, cfg.Key, encrypter) if err != nil { log.Fatalf("failed to init publisher: %v", err) } collector := runtime.New() - runner := agentsvc.New(cfg, collector, pub) + appCfg := agentsvc.Config{ + PollInterval: cfg.PollInterval, + ReportInterval: cfg.ReportInterval, + RateLimit: cfg.RateLimit, + } + runner := agentsvc.New(appCfg, collector, pub) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/cmd/server/init.go b/cmd/server/init.go index b572bfc..2b83014 100644 --- a/cmd/server/init.go +++ b/cmd/server/init.go @@ -7,11 +7,11 @@ import ( _ "github.com/lib/pq" "go.uber.org/zap" - "github.com/vshulcz/Golectra/internal/adapters/persistence/file" - memrepo "github.com/vshulcz/Golectra/internal/adapters/repository/memory" - pgrepo "github.com/vshulcz/Golectra/internal/adapters/repository/postgres" - "github.com/vshulcz/Golectra/internal/config" - "github.com/vshulcz/Golectra/internal/misc" + "github.com/vshulcz/Golectra/internal/infra/config" + "github.com/vshulcz/Golectra/internal/infra/persistence/file" + memrepo "github.com/vshulcz/Golectra/internal/infra/repository/memory" + pgrepo "github.com/vshulcz/Golectra/internal/infra/repository/postgres" + "github.com/vshulcz/Golectra/internal/infra/retry" "github.com/vshulcz/Golectra/internal/ports" ) @@ -26,7 +26,7 @@ func buildRepoAndPersister(cfg config.ServerConfig, logger *zap.Logger) (ports.M } return pgrepo.Migrate(db) } - if err = misc.Retry(ctx, misc.DefaultBackoff, pgrepo.IsRetryable, op); err == nil { + if err = retry.Retry(ctx, retry.DefaultBackoff, pgrepo.IsRetryable, op); err == nil { logger.Info("db connected & migrated") return pgrepo.New(db), nil } diff --git a/cmd/server/main.go b/cmd/server/main.go index 738eef9..d62b97a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -8,14 +8,16 @@ import ( "os" "time" - auditfile "github.com/vshulcz/Golectra/internal/adapters/audit/file" - auditremote "github.com/vshulcz/Golectra/internal/adapters/audit/remote" - "github.com/vshulcz/Golectra/internal/adapters/http/ginserver" - "github.com/vshulcz/Golectra/internal/adapters/http/ginserver/middlewares" - "github.com/vshulcz/Golectra/internal/config" + "github.com/vshulcz/Golectra/internal/application/metrics" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/services/audit" - "github.com/vshulcz/Golectra/internal/services/metrics" + auditinfra "github.com/vshulcz/Golectra/internal/infra/audit" + auditfile "github.com/vshulcz/Golectra/internal/infra/audit/file" + auditremote "github.com/vshulcz/Golectra/internal/infra/audit/remote" + "github.com/vshulcz/Golectra/internal/infra/config" + "github.com/vshulcz/Golectra/internal/infra/crypto/rsaenvelope" + "github.com/vshulcz/Golectra/internal/infra/http/ginserver" + "github.com/vshulcz/Golectra/internal/infra/http/ginserver/middlewares" + "github.com/vshulcz/Golectra/internal/ports" "github.com/vshulcz/Golectra/pkg/util" "go.uber.org/zap" ) @@ -40,89 +42,135 @@ func run(args []string) error { return err } - logger, err := zap.NewProduction() + logger, cleanup, err := initLogger() if err != nil { return err } - defer func() { - if cerr := logger.Sync(); cerr != nil { - log.Printf("logger sync: %v", cerr) - } - }() + defer cleanup() repo, persister := buildRepoAndPersister(cfg, logger) - onChanged := func(ctx context.Context, s domain.Snapshot) { - if persister != nil { - if err := persister.Save(ctx, s); err != nil { - logger.Warn("save failed", zap.Error(err)) - } - } - } + onChanged := buildSnapshotHook(persister, logger) auditor := buildAuditor(cfg, logger) svc := metrics.New(repo, onChanged, auditor) defer svc.Close() h := ginserver.NewHandler(svc) + decrypter, err := loadDecrypter(cfg) + if err != nil { + return err + } + r := ginserver.NewRouter(h, logger, middlewares.ZapLogger(logger), + middlewares.DecryptPayload(decrypter), middlewares.GzipRequest(), middlewares.GzipResponse(), middlewares.HashSHA256(cfg.Key), ) - log.Printf("cfg: addr=%s file=%s interval=%v restore=%v dsn=%q audit_file=%q audit_url=%q", - cfg.Address, cfg.File, cfg.Interval, cfg.Restore, cfg.DSN, cfg.AuditFile, cfg.AuditURL) - - if cfg.DSN == "" && cfg.Interval > 0 { - if cfg.Interval < 0 { - cfg.Interval = 300 * time.Second - } - ticker := time.NewTicker(cfg.Interval) - go func() { - for range ticker.C { - if s, err := repo.Snapshot(context.Background()); err == nil && persister != nil { - if err := persister.Save(context.Background(), s); err != nil { - logger.Warn("periodic save failed", zap.Error(err)) - } - } - } - }() - } + logConfig(cfg) + startPeriodicSave(cfg, repo, persister, logger) - srv := &http.Server{ - Addr: cfg.Address, - Handler: r, - ReadHeaderTimeout: 5 * time.Second, - ReadTimeout: 15 * time.Second, - WriteTimeout: 15 * time.Second, - IdleTimeout: 60 * time.Second, - } - if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - return err - } - return nil + srv := newHTTPServer(cfg, r) + return serve(srv) } -func buildAuditor(cfg config.ServerConfig, logger *zap.Logger) audit.Publisher { +func buildAuditor(cfg config.ServerConfig, logger *zap.Logger) ports.AuditPublisher { if cfg.AuditFile == "" && cfg.AuditURL == "" { return nil } - subject := audit.NewSubject() - subject.SetErrorHandler(func(err error) { + fanout := auditinfra.NewFanout() + fanout.SetErrorHandler(func(err error) { logger.Warn("audit delivery failed", zap.Error(err)) }) if cfg.AuditFile != "" { - subject.Attach(auditfile.New(cfg.AuditFile)) + fanout.Attach(auditfile.New(cfg.AuditFile)) } if cfg.AuditURL != "" { client, err := auditremote.New(cfg.AuditURL, nil) if err != nil { logger.Fatal("invalid audit url", zap.Error(err)) } - subject.Attach(client) + fanout.Attach(client) + } + return fanout +} + +func initLogger() (*zap.Logger, func(), error) { + logger, err := zap.NewProduction() + if err != nil { + return nil, nil, err + } + cleanup := func() { + if cerr := logger.Sync(); cerr != nil { + log.Printf("logger sync: %v", cerr) + } + } + return logger, cleanup, nil +} + +func buildSnapshotHook(persister ports.Persister, logger *zap.Logger) func(context.Context, domain.Snapshot) { + return func(ctx context.Context, s domain.Snapshot) { + if persister == nil { + return + } + if err := persister.Save(ctx, s); err != nil { + logger.Warn("save failed", zap.Error(err)) + } + } +} + +func loadDecrypter(cfg config.ServerConfig) (ports.PayloadDecrypter, error) { + if cfg.CryptoKey == "" { + return nil, nil + } + key, err := rsaenvelope.LoadPrivateKey(cfg.CryptoKey) + if err != nil { + return nil, err + } + return rsaenvelope.NewDecrypter(key), nil +} + +func logConfig(cfg config.ServerConfig) { + log.Printf("cfg: addr=%s file=%s interval=%v restore=%v dsn=%q audit_file=%q audit_url=%q", + cfg.Address, cfg.File, cfg.Interval, cfg.Restore, cfg.DSN, cfg.AuditFile, cfg.AuditURL) +} + +func startPeriodicSave(cfg config.ServerConfig, repo ports.MetricsRepo, persister ports.Persister, logger *zap.Logger) { + if cfg.DSN != "" || cfg.Interval <= 0 { + return + } + ticker := time.NewTicker(cfg.Interval) + go func() { + for range ticker.C { + snap, err := repo.Snapshot(context.Background()) + if err != nil || persister == nil { + continue + } + if err := persister.Save(context.Background(), snap); err != nil { + logger.Warn("periodic save failed", zap.Error(err)) + } + } + }() +} + +func newHTTPServer(cfg config.ServerConfig, handler http.Handler) *http.Server { + return &http.Server{ + Addr: cfg.Address, + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, } - return subject +} + +func serve(srv *http.Server) error { + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil } func printBuildInfo() { diff --git a/internal/services/agent/agent.go b/internal/application/agent/agent.go similarity index 94% rename from internal/services/agent/agent.go rename to internal/application/agent/agent.go index a514b08..76ac398 100644 --- a/internal/services/agent/agent.go +++ b/internal/application/agent/agent.go @@ -6,7 +6,6 @@ import ( "log" "time" - "github.com/vshulcz/Golectra/internal/config" "github.com/vshulcz/Golectra/internal/domain" "github.com/vshulcz/Golectra/internal/ports" ) @@ -15,14 +14,14 @@ import ( type Service struct { collector ports.MetricsCollector pub ports.Publisher - cfg config.AgentConfig + cfg Config sender *BatchPublisher batchBuf []domain.Metrics } // New wires together the agent configuration, collector, and publisher. -func New(cfg config.AgentConfig, c ports.MetricsCollector, p ports.Publisher) *Service { +func New(cfg Config, c ports.MetricsCollector, p ports.Publisher) *Service { return &Service{cfg: cfg, collector: c, pub: p} } diff --git a/internal/services/agent/agent_benchmark_test.go b/internal/application/agent/agent_benchmark_test.go similarity index 94% rename from internal/services/agent/agent_benchmark_test.go rename to internal/application/agent/agent_benchmark_test.go index 1056065..f0fea40 100644 --- a/internal/services/agent/agent_benchmark_test.go +++ b/internal/application/agent/agent_benchmark_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/vshulcz/Golectra/internal/config" "github.com/vshulcz/Golectra/internal/domain" ) @@ -43,7 +42,7 @@ func BenchmarkAgentReportOnce(b *testing.B) { svc := &Service{ collector: &benchCollector{gauges: gauges, counters: counters}, pub: benchPublisher{}, - cfg: config.AgentConfig{ + cfg: Config{ RateLimit: 4, ReportInterval: time.Second, PollInterval: 200 * time.Millisecond, diff --git a/internal/services/agent/agent_test.go b/internal/application/agent/agent_test.go similarity index 96% rename from internal/services/agent/agent_test.go rename to internal/application/agent/agent_test.go index ef352ca..71fb646 100644 --- a/internal/services/agent/agent_test.go +++ b/internal/application/agent/agent_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/vshulcz/Golectra/internal/config" "github.com/vshulcz/Golectra/internal/domain" ) @@ -197,7 +196,7 @@ func indexByID(items []domain.Metrics) map[string]domain.Metrics { func TestService_Run_Error(t *testing.T) { coll := &fakeCollector{startErr: errors.New("nope")} pub := &fakePublisher{} - cfg := config.AgentConfig{PollInterval: 1 * time.Millisecond, ReportInterval: 2 * time.Millisecond} + cfg := Config{PollInterval: 1 * time.Millisecond, ReportInterval: 2 * time.Millisecond} svc := New(cfg, coll, pub) @@ -219,7 +218,7 @@ func TestService_Run(t *testing.T) { } pub := &fakePublisher{} - cfg := config.AgentConfig{ + cfg := Config{ PollInterval: 1 * time.Millisecond, ReportInterval: 5 * time.Millisecond, } @@ -258,7 +257,7 @@ func TestService_Run_EmptySnapshot(t *testing.T) { counters: map[string]int64{}, } pub := &fakePublisher{} - cfg := config.AgentConfig{PollInterval: 1 * time.Millisecond, ReportInterval: 5 * time.Millisecond} + cfg := Config{PollInterval: 1 * time.Millisecond, ReportInterval: 5 * time.Millisecond} svc := New(cfg, coll, pub) ctx, cancel := context.WithCancel(context.Background()) @@ -276,7 +275,7 @@ func TestService_Run_EmptySnapshot(t *testing.T) { } func TestService_Run_RespectsRateLimit(t *testing.T) { - cfg := config.AgentConfig{ + cfg := Config{ PollInterval: 1 * time.Millisecond, ReportInterval: 1 * time.Millisecond, RateLimit: 2, diff --git a/internal/application/agent/config.go b/internal/application/agent/config.go new file mode 100644 index 0000000..db68bdf --- /dev/null +++ b/internal/application/agent/config.go @@ -0,0 +1,10 @@ +package agent + +import "time" + +// Config holds runtime parameters needed by the agent application service. +type Config struct { + PollInterval time.Duration + ReportInterval time.Duration + RateLimit int +} diff --git a/internal/services/agent/sender.go b/internal/application/agent/sender.go similarity index 100% rename from internal/services/agent/sender.go rename to internal/application/agent/sender.go diff --git a/internal/services/agent/sender_test.go b/internal/application/agent/sender_test.go similarity index 100% rename from internal/services/agent/sender_test.go rename to internal/application/agent/sender_test.go diff --git a/internal/services/metrics/metrics.go b/internal/application/metrics/metrics.go similarity index 87% rename from internal/services/metrics/metrics.go rename to internal/application/metrics/metrics.go index 9589f43..80095ff 100644 --- a/internal/services/metrics/metrics.go +++ b/internal/application/metrics/metrics.go @@ -11,14 +11,13 @@ import ( "github.com/vshulcz/Golectra/internal/domain" "github.com/vshulcz/Golectra/internal/ports" - "github.com/vshulcz/Golectra/internal/services/audit" ) // Service exposes business operations for querying and mutating metrics. type Service struct { repo ports.MetricsRepo onChanged func(context.Context, domain.Snapshot) - auditor audit.Publisher + auditor ports.AuditPublisher now func() time.Time auditQueue chan auditEvent @@ -27,7 +26,7 @@ type Service struct { } // New builds a metrics Service with repository, snapshot hook, and optional auditor. -func New(repo ports.MetricsRepo, onChanged func(context.Context, domain.Snapshot), auditor audit.Publisher) *Service { +func New(repo ports.MetricsRepo, onChanged func(context.Context, domain.Snapshot), auditor ports.AuditPublisher) *Service { s := &Service{repo: repo, onChanged: onChanged, auditor: auditor, now: time.Now} s.initAuditDispatcher() return s @@ -35,7 +34,7 @@ func New(repo ports.MetricsRepo, onChanged func(context.Context, domain.Snapshot type auditEvent struct { ctx context.Context - evt audit.Event + evt domain.AuditEvent } const auditQueueSize = 128 @@ -55,7 +54,7 @@ func (s *Service) initAuditDispatcher() { case <-ctx.Done(): return case msg := <-s.auditQueue: - s.auditor.Publish(msg.ctx, msg.evt) + _ = s.auditor.Publish(msg.ctx, msg.evt) } } }() @@ -102,7 +101,7 @@ func (s *Service) Get(ctx context.Context, mType, id string) (domain.Metrics, er } // Upsert validates and stores one gauge or counter value. -func (s *Service) Upsert(ctx context.Context, m domain.Metrics) (domain.Metrics, error) { +func (s *Service) Upsert(ctx context.Context, m domain.Metrics, clientIP string) (domain.Metrics, error) { m.ID = strings.TrimSpace(m.ID) if m.ID == "" { return domain.Metrics{}, domain.ErrNotFound @@ -117,7 +116,7 @@ func (s *Service) Upsert(ctx context.Context, m domain.Metrics) (domain.Metrics, } res, err := s.Get(ctx, m.MType, m.ID) if err == nil { - s.notifyAudit(ctx, []string{m.ID}) + s.notifyAudit(ctx, []string{m.ID}, clientIP) } return res, err case string(domain.Counter): @@ -129,7 +128,7 @@ func (s *Service) Upsert(ctx context.Context, m domain.Metrics) (domain.Metrics, } res, err := s.Get(ctx, m.MType, m.ID) if err == nil { - s.notifyAudit(ctx, []string{m.ID}) + s.notifyAudit(ctx, []string{m.ID}, clientIP) } return res, err default: @@ -138,7 +137,7 @@ func (s *Service) Upsert(ctx context.Context, m domain.Metrics) (domain.Metrics, } // UpsertBatch applies many metrics in a single repository call and triggers snapshot callbacks. -func (s *Service) UpsertBatch(ctx context.Context, items []domain.Metrics) (int, error) { +func (s *Service) UpsertBatch(ctx context.Context, items []domain.Metrics, clientIP string) (int, error) { valid := make([]domain.Metrics, 0, len(items)) names := make([]string, 0, len(items)) for _, it := range items { @@ -168,7 +167,7 @@ func (s *Service) UpsertBatch(ctx context.Context, items []domain.Metrics) (int, if err := s.repo.UpdateMany(ctx, valid); err != nil { return 0, err } - s.notifyAudit(ctx, names) + s.notifyAudit(ctx, names, clientIP) if s.onChanged != nil { if snap, err := s.repo.Snapshot(ctx); err == nil { s.onChanged(ctx, snap) @@ -182,7 +181,7 @@ func (s *Service) Snapshot(ctx context.Context) (domain.Snapshot, error) { return s.repo.Snapshot(ctx) } -func (s *Service) notifyAudit(ctx context.Context, names []string) { +func (s *Service) notifyAudit(ctx context.Context, names []string, clientIP string) { if s == nil || s.auditor == nil { return } @@ -194,17 +193,19 @@ func (s *Service) notifyAudit(ctx context.Context, names []string) { if s.now != nil { ts = s.now().Unix() } - evt := audit.Event{ + evt := domain.AuditEvent{ Timestamp: ts, Metrics: uniq, - IPAddress: audit.ClientIPFromContext(ctx), + IPAddress: clientIP, } s.enqueueAudit(ctx, evt) } -func (s *Service) enqueueAudit(ctx context.Context, evt audit.Event) { +func (s *Service) enqueueAudit(ctx context.Context, evt domain.AuditEvent) { if s.auditQueue == nil { - s.auditor.Publish(ctx, evt) + if err := s.auditor.Publish(ctx, evt); err != nil { + log.Printf("metrics: audit publish failed (%v)", err) + } return } select { diff --git a/internal/services/metrics/metrics_benchmark_test.go b/internal/application/metrics/metrics_benchmark_test.go similarity index 85% rename from internal/services/metrics/metrics_benchmark_test.go rename to internal/application/metrics/metrics_benchmark_test.go index d6d6dcd..317575d 100644 --- a/internal/services/metrics/metrics_benchmark_test.go +++ b/internal/application/metrics/metrics_benchmark_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - memrepo "github.com/vshulcz/Golectra/internal/adapters/repository/memory" "github.com/vshulcz/Golectra/internal/domain" + memrepo "github.com/vshulcz/Golectra/internal/infra/repository/memory" ) func BenchmarkServiceUpsertBatch(b *testing.B) { @@ -27,7 +27,7 @@ func BenchmarkServiceUpsertBatch(b *testing.B) { b.ReportAllocs() for b.Loop() { - if _, err := svc.UpsertBatch(ctx, items); err != nil { + if _, err := svc.UpsertBatch(ctx, items, ""); err != nil { b.Fatalf("UpsertBatch: %v", err) } } diff --git a/internal/services/metrics/metrics_test.go b/internal/application/metrics/metrics_test.go similarity index 95% rename from internal/services/metrics/metrics_test.go rename to internal/application/metrics/metrics_test.go index ff55fd6..0c727f2 100644 --- a/internal/services/metrics/metrics_test.go +++ b/internal/application/metrics/metrics_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/services/audit" ) type fakeRepo struct { @@ -41,24 +40,25 @@ type fakeRepo struct { type fakeAuditor struct { mu sync.Mutex - events []audit.Event + events []domain.AuditEvent } -func (f *fakeAuditor) Publish(_ context.Context, evt audit.Event) { +func (f *fakeAuditor) Publish(_ context.Context, evt domain.AuditEvent) error { f.mu.Lock() defer f.mu.Unlock() f.events = append(f.events, evt) + return nil } -func (f *fakeAuditor) Events() []audit.Event { +func (f *fakeAuditor) Events() []domain.AuditEvent { f.mu.Lock() defer f.mu.Unlock() - out := make([]audit.Event, len(f.events)) + out := make([]domain.AuditEvent, len(f.events)) copy(out, f.events) return out } -func (f *fakeAuditor) WaitForEvents(n int, timeout time.Duration) []audit.Event { +func (f *fakeAuditor) WaitForEvents(n int, timeout time.Duration) []domain.AuditEvent { deadline := time.Now().Add(timeout) for { if events := f.Events(); len(events) >= n || time.Now().After(deadline) { @@ -279,7 +279,7 @@ func TestService_Upsert(t *testing.T) { if tc.setup != nil { tc.setup() } - got, err := svc.Upsert(context.Background(), tc.m) + got, err := svc.Upsert(context.Background(), tc.m, "") if (err == nil) != tc.wantOK { t.Fatalf("err=%v wantOK=%v", err, tc.wantOK) } @@ -341,7 +341,7 @@ func TestService_UpsertBatch(t *testing.T) { cbCalls = 0 cbMu.Unlock() - n, err := svc.UpsertBatch(context.Background(), invalids) + n, err := svc.UpsertBatch(context.Background(), invalids, "") if n != 0 || !errors.Is(err, domain.ErrInvalidType) { t.Fatalf("n=%d err=%v want 0, ErrInvalidType", n, err) } @@ -361,7 +361,7 @@ func TestService_UpsertBatch(t *testing.T) { repo.nextSnapshotErr = nil in := append([]domain.Metrics{validGauge, validCounter}, invalids...) - n, err := svc.UpsertBatch(context.Background(), in) + n, err := svc.UpsertBatch(context.Background(), in, "") if err != nil || n != 2 { t.Fatalf("n=%d err=%v want 2, nil", n, err) } @@ -384,7 +384,7 @@ func TestService_UpsertBatch(t *testing.T) { cbCalls = 0 cbMu.Unlock() - n, err := svc.UpsertBatch(context.Background(), []domain.Metrics{validGauge}) + n, err := svc.UpsertBatch(context.Background(), []domain.Metrics{validGauge}, "") if n != 0 || err == nil || err.Error() != "fail" { t.Fatalf("n=%d err=%v want 0, fail", n, err) } @@ -400,7 +400,7 @@ func TestService_UpsertBatch(t *testing.T) { cbCalls = 0 cbMu.Unlock() - n, err := svc.UpsertBatch(context.Background(), []domain.Metrics{validCounter}) + n, err := svc.UpsertBatch(context.Background(), []domain.Metrics{validCounter}, "") if err != nil || n != 1 { t.Fatalf("n=%d err=%v want 1, nil", n, err) } @@ -439,8 +439,8 @@ func TestService_Upsert_EmitsAudit(t *testing.T) { t.Cleanup(svc.Close) svc.now = func() time.Time { return time.Unix(99, 0) } - ctx := audit.WithClientIP(context.Background(), "10.0.0.1") - _, err := svc.Upsert(ctx, domain.Metrics{ID: "Alloc", MType: string(domain.Gauge), Value: ptrFloat64(1.0)}) + ctx := context.Background() + _, err := svc.Upsert(ctx, domain.Metrics{ID: "Alloc", MType: string(domain.Gauge), Value: ptrFloat64(1.0)}, "10.0.0.1") if err != nil { t.Fatalf("Upsert err: %v", err) } @@ -474,8 +474,8 @@ func TestService_UpsertBatch_EmitsAuditDedup(t *testing.T) { {ID: "A", MType: string(domain.Counter), Delta: ptrInt(3)}, {ID: " ", MType: string(domain.Gauge), Value: ptrFloat64(1)}, } - ctx := audit.WithClientIP(context.Background(), "192.0.2.1") - updated, err := svc.UpsertBatch(ctx, items) + ctx := context.Background() + updated, err := svc.UpsertBatch(ctx, items, "192.0.2.1") if err != nil { t.Fatalf("UpsertBatch err: %v", err) } diff --git a/internal/domain/audit_event.go b/internal/domain/audit_event.go new file mode 100644 index 0000000..b6eafed --- /dev/null +++ b/internal/domain/audit_event.go @@ -0,0 +1,8 @@ +package domain + +// AuditEvent describes which metrics changed, when, and from which IP address. +type AuditEvent struct { + Timestamp int64 `json:"ts"` + Metrics []string `json:"metrics"` + IPAddress string `json:"ip_address"` +} diff --git a/internal/infra/audit/fanout.go b/internal/infra/audit/fanout.go new file mode 100644 index 0000000..bc48dae --- /dev/null +++ b/internal/infra/audit/fanout.go @@ -0,0 +1,69 @@ +package audit + +import ( + "context" + "sync" + + "github.com/vshulcz/Golectra/internal/domain" + "github.com/vshulcz/Golectra/internal/ports" +) + +// Fanout publishes audit events to multiple sinks. +type Fanout struct { + mu sync.RWMutex + sinks []ports.AuditPublisher + onError func(error) +} + +// NewFanout creates a fanout publisher with optional initial sinks. +func NewFanout(sinks ...ports.AuditPublisher) *Fanout { + cp := append([]ports.AuditPublisher(nil), sinks...) + return &Fanout{sinks: cp} +} + +// Publish forwards the event to all registered sinks. +func (f *Fanout) Publish(ctx context.Context, evt domain.AuditEvent) error { + if f == nil { + return nil + } + f.mu.RLock() + sinks := append([]ports.AuditPublisher(nil), f.sinks...) + errHandler := f.onError + f.mu.RUnlock() + + var firstErr error + for _, sink := range sinks { + if sink == nil { + continue + } + if err := sink.Publish(ctx, evt); err != nil { + if firstErr == nil { + firstErr = err + } + if errHandler != nil { + errHandler(err) + } + } + } + return firstErr +} + +// Attach registers additional audit sinks. +func (f *Fanout) Attach(sinks ...ports.AuditPublisher) { + if f == nil || len(sinks) == 0 { + return + } + f.mu.Lock() + f.sinks = append(f.sinks, sinks...) + f.mu.Unlock() +} + +// SetErrorHandler configures a callback for sink failures. +func (f *Fanout) SetErrorHandler(fn func(error)) { + if f == nil { + return + } + f.mu.Lock() + f.onError = fn + f.mu.Unlock() +} diff --git a/internal/infra/audit/fanout_test.go b/internal/infra/audit/fanout_test.go new file mode 100644 index 0000000..19a9878 --- /dev/null +++ b/internal/infra/audit/fanout_test.go @@ -0,0 +1,105 @@ +package audit + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/vshulcz/Golectra/internal/domain" + "github.com/vshulcz/Golectra/internal/ports" +) + +type fakeSink struct { + count int32 + err error +} + +func (s *fakeSink) Publish(_ context.Context, _ domain.AuditEvent) error { + atomic.AddInt32(&s.count, 1) + return s.err +} + +func TestFanout_Publish_Success(t *testing.T) { + s1 := &fakeSink{} + s2 := &fakeSink{} + f := NewFanout(s1, s2) + + evt := domain.AuditEvent{Timestamp: 1, Metrics: []string{"A"}, IPAddress: "1.1.1.1"} + if err := f.Publish(context.Background(), evt); err != nil { + t.Fatalf("Publish error: %v", err) + } + if got := atomic.LoadInt32(&s1.count); got != 1 { + t.Fatalf("sink1 count=%d want 1", got) + } + if got := atomic.LoadInt32(&s2.count); got != 1 { + t.Fatalf("sink2 count=%d want 1", got) + } +} + +func TestFanout_Publish_ErrorHandlerAndReturn(t *testing.T) { + wantErr := errors.New("boom") + s1 := &fakeSink{err: wantErr} + s2 := &fakeSink{} + f := NewFanout(s1, s2) + + var errCalls int32 + f.SetErrorHandler(func(err error) { + if !errors.Is(err, wantErr) { + t.Fatalf("unexpected error: %v", err) + } + atomic.AddInt32(&errCalls, 1) + }) + + err := f.Publish(context.Background(), domain.AuditEvent{}) + if !errors.Is(err, wantErr) { + t.Fatalf("Publish error=%v want %v", err, wantErr) + } + if got := atomic.LoadInt32(&errCalls); got != 1 { + t.Fatalf("errCalls=%d want 1", got) + } + if got := atomic.LoadInt32(&s2.count); got != 1 { + t.Fatalf("sink2 count=%d want 1", got) + } +} + +func TestFanout_Attach_IgnoresNil(t *testing.T) { + var nilSink ports.AuditPublisher + s1 := &fakeSink{} + f := NewFanout() + f.Attach(nilSink, s1) + + if err := f.Publish(context.Background(), domain.AuditEvent{}); err != nil { + t.Fatalf("Publish error: %v", err) + } + if got := atomic.LoadInt32(&s1.count); got != 1 { + t.Fatalf("sink count=%d want 1", got) + } +} + +func TestFanout_NilReceiver(t *testing.T) { + var f *Fanout + if err := f.Publish(context.Background(), domain.AuditEvent{}); err != nil { + t.Fatalf("Publish error: %v", err) + } + f.Attach(nil) + f.SetErrorHandler(nil) +} + +func TestFanout_MultipleErrors(t *testing.T) { + wantErr := errors.New("boom") + s1 := &fakeSink{err: wantErr} + s2 := &fakeSink{err: wantErr} + f := NewFanout(s1, s2) + + err := f.Publish(context.Background(), domain.AuditEvent{}) + if !errors.Is(err, wantErr) { + t.Fatalf("Publish error=%v want %v", err, wantErr) + } + if got := atomic.LoadInt32(&s1.count); got != 1 { + t.Fatalf("sink1 count=%d want 1", got) + } + if got := atomic.LoadInt32(&s2.count); got != 1 { + t.Fatalf("sink2 count=%d want 1", got) + } +} diff --git a/internal/adapters/audit/file/file.go b/internal/infra/audit/file/file.go similarity index 100% rename from internal/adapters/audit/file/file.go rename to internal/infra/audit/file/file.go diff --git a/internal/adapters/audit/file/writer.go b/internal/infra/audit/file/writer.go similarity index 81% rename from internal/adapters/audit/file/writer.go rename to internal/infra/audit/file/writer.go index 1a6caae..fb16899 100644 --- a/internal/adapters/audit/file/writer.go +++ b/internal/infra/audit/file/writer.go @@ -7,7 +7,7 @@ import ( "os" "sync" - "github.com/vshulcz/Golectra/internal/services/audit" + "github.com/vshulcz/Golectra/internal/domain" ) // Writer appends audit events to a local newline-delimited JSON file. @@ -21,8 +21,8 @@ func New(path string) *Writer { return &Writer{path: path} } -// Notify marshals the audit event and atomically appends it to the writer's file. -func (w *Writer) Notify(_ context.Context, evt audit.Event) (retErr error) { +// Publish marshals the audit event and atomically appends it to the writer's file. +func (w *Writer) Publish(_ context.Context, evt domain.AuditEvent) (retErr error) { if w == nil || w.path == "" { return nil } diff --git a/internal/adapters/audit/file/writer_test.go b/internal/infra/audit/file/writer_test.go similarity index 58% rename from internal/adapters/audit/file/writer_test.go rename to internal/infra/audit/file/writer_test.go index 5233cdc..a163e05 100644 --- a/internal/adapters/audit/file/writer_test.go +++ b/internal/infra/audit/file/writer_test.go @@ -6,16 +6,16 @@ import ( "os" "testing" - "github.com/vshulcz/Golectra/internal/services/audit" + "github.com/vshulcz/Golectra/internal/domain" ) -func TestWriter_Notify_AppendsJSONLine(t *testing.T) { +func TestWriter_Publish_AppendsJSONLine(t *testing.T) { tmp := t.TempDir() path := tmp + "/audit.log" w := New(path) - evt := audit.Event{Timestamp: 1, Metrics: []string{"Alloc"}, IPAddress: "127.0.0.1"} - if err := w.Notify(context.Background(), evt); err != nil { - t.Fatalf("Notify error: %v", err) + evt := domain.AuditEvent{Timestamp: 1, Metrics: []string{"Alloc"}, IPAddress: "127.0.0.1"} + if err := w.Publish(context.Background(), evt); err != nil { + t.Fatalf("Publish error: %v", err) } data, err := os.ReadFile(path) @@ -23,7 +23,7 @@ func TestWriter_Notify_AppendsJSONLine(t *testing.T) { t.Fatalf("read file: %v", err) } - var decoded audit.Event + var decoded domain.AuditEvent if err := json.Unmarshal(data[:len(data)-1], &decoded); err != nil { t.Fatalf("unmarshal: %v", err) } diff --git a/internal/adapters/audit/remote/client.go b/internal/infra/audit/remote/client.go similarity index 87% rename from internal/adapters/audit/remote/client.go rename to internal/infra/audit/remote/client.go index 7987b64..5e1df47 100644 --- a/internal/adapters/audit/remote/client.go +++ b/internal/infra/audit/remote/client.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/vshulcz/Golectra/internal/services/audit" + "github.com/vshulcz/Golectra/internal/domain" ) // Client sends audit events to a remote HTTP endpoint. @@ -34,8 +34,8 @@ func New(rawURL string, hc *http.Client) (*Client, error) { return &Client{endpoint: rawURL, hc: hc}, nil } -// Notify serializes the audit event and issues an HTTP POST to the configured endpoint. -func (c *Client) Notify(ctx context.Context, evt audit.Event) (retErr error) { +// Publish serializes the audit event and issues an HTTP POST to the configured endpoint. +func (c *Client) Publish(ctx context.Context, evt domain.AuditEvent) (retErr error) { if c == nil { return nil } diff --git a/internal/adapters/audit/remote/client_test.go b/internal/infra/audit/remote/client_test.go similarity index 70% rename from internal/adapters/audit/remote/client_test.go rename to internal/infra/audit/remote/client_test.go index 7f9da19..c995a30 100644 --- a/internal/adapters/audit/remote/client_test.go +++ b/internal/infra/audit/remote/client_test.go @@ -7,11 +7,11 @@ import ( "net/http/httptest" "testing" - "github.com/vshulcz/Golectra/internal/services/audit" + "github.com/vshulcz/Golectra/internal/domain" ) -func TestClient_Notify_OK(t *testing.T) { - var received audit.Event +func TestClient_Publish_OK(t *testing.T) { + var received domain.AuditEvent ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := r.Body.Close(); err != nil { @@ -35,9 +35,9 @@ func TestClient_Notify_OK(t *testing.T) { t.Fatalf("New error: %v", err) } - evt := audit.Event{Timestamp: 1, Metrics: []string{"Alloc"}, IPAddress: "1.1.1.1"} - if err := cli.Notify(context.Background(), evt); err != nil { - t.Fatalf("Notify error: %v", err) + evt := domain.AuditEvent{Timestamp: 1, Metrics: []string{"Alloc"}, IPAddress: "1.1.1.1"} + if err := cli.Publish(context.Background(), evt); err != nil { + t.Fatalf("Publish error: %v", err) } if received.IPAddress != evt.IPAddress { @@ -45,7 +45,7 @@ func TestClient_Notify_OK(t *testing.T) { } } -func TestClient_Notify_StatusError(t *testing.T) { +func TestClient_Publish_StatusError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadGateway) })) @@ -56,7 +56,7 @@ func TestClient_Notify_StatusError(t *testing.T) { t.Fatalf("New error: %v", err) } - if err := cli.Notify(context.Background(), audit.Event{}); err == nil { + if err := cli.Publish(context.Background(), domain.AuditEvent{}); err == nil { t.Fatal("expected error") } } diff --git a/internal/adapters/audit/remote/remote.go b/internal/infra/audit/remote/remote.go similarity index 100% rename from internal/adapters/audit/remote/remote.go rename to internal/infra/audit/remote/remote.go diff --git a/internal/adapters/collector/runtime/keys.go b/internal/infra/collector/runtime/keys.go similarity index 100% rename from internal/adapters/collector/runtime/keys.go rename to internal/infra/collector/runtime/keys.go diff --git a/internal/adapters/collector/runtime/runtime.go b/internal/infra/collector/runtime/runtime.go similarity index 100% rename from internal/adapters/collector/runtime/runtime.go rename to internal/infra/collector/runtime/runtime.go diff --git a/internal/adapters/collector/runtime/runtime_test.go b/internal/infra/collector/runtime/runtime_test.go similarity index 100% rename from internal/adapters/collector/runtime/runtime_test.go rename to internal/infra/collector/runtime/runtime_test.go diff --git a/internal/adapters/collector/runtime/state.go b/internal/infra/collector/runtime/state.go similarity index 100% rename from internal/adapters/collector/runtime/state.go rename to internal/infra/collector/runtime/state.go diff --git a/internal/config/agent.go b/internal/infra/config/agent.go similarity index 91% rename from internal/config/agent.go rename to internal/infra/config/agent.go index 6160f7b..3c62eb8 100644 --- a/internal/config/agent.go +++ b/internal/infra/config/agent.go @@ -20,6 +20,7 @@ const ( type AgentConfig struct { Address string Key string + CryptoKey string PollInterval time.Duration ReportInterval time.Duration RateLimit int @@ -36,12 +37,14 @@ func LoadAgentConfig(args []string, out io.Writer) (AgentConfig, error) { var addrOpt string var keyOpt string + var cryptoKeyOpt string var reportOpt int var pollOpt int var limitOpt int fs.StringVar(&addrOpt, "a", "", fmt.Sprintf("server address (host:port or URL), default: %s", defaultServerAddr)) fs.StringVar(&keyOpt, "k", "", "secret key for HashSHA256 header") + fs.StringVar(&cryptoKeyOpt, "crypto-key", "", "path to RSA public key for request encryption") fs.IntVar(&reportOpt, "r", 0, fmt.Sprintf("report interval in seconds, default: %d", defaultReportInterval)) fs.IntVar(&pollOpt, "p", 0, fmt.Sprintf("poll interval in seconds, default: %d", defaultPollInterval)) fs.IntVar(&limitOpt, "l", 0, "rate limit (max concurrent outgoing requests), default: 1") @@ -57,6 +60,7 @@ func LoadAgentConfig(args []string, out io.Writer) (AgentConfig, error) { } key := FromEnvOrFlag("KEY", keyOpt, "") + cryptoKey := FromEnvOrFlag("CRYPTO_KEY", cryptoKeyOpt, "") report, _ := FromEnvOrFlagDuration("REPORT_INTERVAL", reportOpt, 0, defaultReportInterval) if report <= 0 { @@ -73,6 +77,7 @@ func LoadAgentConfig(args []string, out io.Writer) (AgentConfig, error) { return AgentConfig{ Address: addr, Key: key, + CryptoKey: cryptoKey, PollInterval: poll, ReportInterval: report, RateLimit: limit, diff --git a/internal/config/agent_test.go b/internal/infra/config/agent_test.go similarity index 99% rename from internal/config/agent_test.go rename to internal/infra/config/agent_test.go index a492a36..9625173 100644 --- a/internal/config/agent_test.go +++ b/internal/infra/config/agent_test.go @@ -100,7 +100,7 @@ func TestLoadAgentConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - for _, k := range []string{"ADDRESS", "REPORT_INTERVAL", "POLL_INTERVAL"} { + for _, k := range []string{"ADDRESS", "REPORT_INTERVAL", "POLL_INTERVAL", "CRYPTO_KEY"} { t.Setenv(k, "") } for k, v := range tt.env { diff --git a/internal/config/config.go b/internal/infra/config/config.go similarity index 100% rename from internal/config/config.go rename to internal/infra/config/config.go diff --git a/internal/config/helpers.go b/internal/infra/config/helpers.go similarity index 68% rename from internal/config/helpers.go rename to internal/infra/config/helpers.go index 01a7d6e..1d9cdea 100644 --- a/internal/config/helpers.go +++ b/internal/infra/config/helpers.go @@ -5,8 +5,6 @@ import ( "strconv" "strings" "time" - - "github.com/vshulcz/Golectra/internal/misc" ) // FromEnvOrFlag returns the environment value when present, otherwise falls back to a CLI flag then default. @@ -23,7 +21,7 @@ func FromEnvOrFlag(envKey, flagVal, def string) string { // FromEnvOrFlagBool merges boolean values from ENV and flags (defaulting to def). func FromEnvOrFlagBool(envKey string, flagVal, def bool) bool { if ev := strings.TrimSpace(os.Getenv(envKey)); ev != "" { - return misc.GetBool(envKey, def) + return envBool(envKey, def) } if flagVal { return true @@ -53,10 +51,45 @@ func FromEnvOrFlagDuration(envKey string, flagSeconds, flagSentinel, defSeconds if d, err := time.ParseDuration(ev); err == nil { return d, true } - return misc.GetDuration(envKey, time.Duration(defSeconds)*time.Second), true + return envDuration(envKey, time.Duration(defSeconds)*time.Second), true } if flagSeconds != flagSentinel { return time.Duration(flagSeconds) * time.Second, true } return time.Duration(defSeconds) * time.Second, false } + +func envDuration(key string, def time.Duration) time.Duration { + v := os.Getenv(key) + if v == "" { + return def + } + if n, err := strconv.ParseInt(v, 10, 64); err == nil { + if n <= 0 { + return 0 + } + return time.Duration(n) * time.Second + } + if d, err := time.ParseDuration(v); err == nil { + if d <= 0 { + return 0 + } + return d + } + return def +} + +func envBool(key string, def bool) bool { + v := strings.TrimSpace(strings.ToLower(os.Getenv(key))) + if v == "" { + return def + } + switch v { + case "1", "true", "t", "yes", "y": + return true + case "0", "false", "f", "no", "n": + return false + default: + return def + } +} diff --git a/internal/config/helpers_test.go b/internal/infra/config/helpers_test.go similarity index 98% rename from internal/config/helpers_test.go rename to internal/infra/config/helpers_test.go index bbf7628..80b11bd 100644 --- a/internal/config/helpers_test.go +++ b/internal/infra/config/helpers_test.go @@ -213,7 +213,7 @@ func TestHelpers_FromEnvOrFlagDuration(t *testing.T) { expectCustom: true, }, { - name: "env invalid -> fallback via misc.GetDuration -> def", + name: "env invalid -> fallback via envDuration -> def", env: "not-a-duration", flagSeconds: 10, sentinel: 0, diff --git a/internal/config/server.go b/internal/infra/config/server.go similarity index 93% rename from internal/config/server.go rename to internal/infra/config/server.go index 6b91521..403f382 100644 --- a/internal/config/server.go +++ b/internal/infra/config/server.go @@ -24,6 +24,7 @@ type ServerConfig struct { File string DSN string Key string + CryptoKey string Interval time.Duration Restore bool AuditFile string @@ -43,6 +44,7 @@ func LoadServerConfig(args []string, out io.Writer) (ServerConfig, error) { var fileOpt string var dsnOpt string var keyOpt string + var cryptoKeyOpt string var ivalOpt int var restoreOpt bool var auditFileOpt string @@ -52,6 +54,7 @@ func LoadServerConfig(args []string, out io.Writer) (ServerConfig, error) { fs.StringVar(&fileOpt, "f", "", fmt.Sprintf("FILE_STORAGE_PATH, default: %s", defaultFilePath)) fs.StringVar(&dsnOpt, "d", "", fmt.Sprintf("DATABASE_DSN for Postgres, default: %s", defaultDSN)) fs.StringVar(&keyOpt, "k", "", "secret key for HashSHA256") + fs.StringVar(&cryptoKeyOpt, "crypto-key", "", "path to RSA private key for request decryption") fs.IntVar(&ivalOpt, "i", -1, fmt.Sprintf("STORE_INTERVAL seconds (0 - sync), default: %d", defaultStoreInterval)) fs.BoolVar(&restoreOpt, "r", false, fmt.Sprintf("RESTORE on start (true/false), default: %t", defaultRestore)) fs.StringVar(&auditFileOpt, "audit-file", "", "path to audit log file (disabled if empty)") @@ -70,6 +73,7 @@ func LoadServerConfig(args []string, out io.Writer) (ServerConfig, error) { file := FromEnvOrFlag("FILE_STORAGE_PATH", fileOpt, defaultFilePath) dsn := FromEnvOrFlag("DATABASE_DSN", dsnOpt, "") key := FromEnvOrFlag("KEY", keyOpt, "") + cryptoKey := FromEnvOrFlag("CRYPTO_KEY", cryptoKeyOpt, "") auditFile := FromEnvOrFlag("AUDIT_FILE", auditFileOpt, "") auditURL := FromEnvOrFlag("AUDIT_URL", auditURLOpt, "") @@ -85,6 +89,7 @@ func LoadServerConfig(args []string, out io.Writer) (ServerConfig, error) { File: file, DSN: dsn, Key: key, + CryptoKey: cryptoKey, Interval: interval, Restore: restore, AuditFile: auditFile, diff --git a/internal/config/server_test.go b/internal/infra/config/server_test.go similarity index 98% rename from internal/config/server_test.go rename to internal/infra/config/server_test.go index e3d6c67..8a13691 100644 --- a/internal/config/server_test.go +++ b/internal/infra/config/server_test.go @@ -120,7 +120,7 @@ func TestLoadServerConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - for _, k := range []string{"ADDRESS", "STORE_INTERVAL", "FILE_STORAGE_PATH", "RESTORE", "AUDIT_FILE", "AUDIT_URL"} { + for _, k := range []string{"ADDRESS", "STORE_INTERVAL", "FILE_STORAGE_PATH", "RESTORE", "AUDIT_FILE", "AUDIT_URL", "CRYPTO_KEY"} { t.Setenv(k, "") } for k, v := range tt.env { diff --git a/internal/infra/crypto/rsaenvelope/envelope.go b/internal/infra/crypto/rsaenvelope/envelope.go new file mode 100644 index 0000000..88d2407 --- /dev/null +++ b/internal/infra/crypto/rsaenvelope/envelope.go @@ -0,0 +1,218 @@ +package rsaenvelope + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "io/fs" + "math" + "os" + "path/filepath" +) + +const ( + headerKey = "X-Encrypted" + headerValue = "rsa" + + aesKeySize = 32 + nonceSize = 12 + keyLenFieldSz = 2 +) + +type cipherEnvelope struct { + pub *rsa.PublicKey + priv *rsa.PrivateKey +} + +// NewEncrypter returns an encrypter for RSA-OAEP + AES-GCM envelopes. +func NewEncrypter(pub *rsa.PublicKey) *cipherEnvelope { + if pub == nil { + return nil + } + return &cipherEnvelope{pub: pub} +} + +// NewDecrypter returns a decrypter for RSA-OAEP + AES-GCM envelopes. +func NewDecrypter(priv *rsa.PrivateKey) *cipherEnvelope { + if priv == nil { + return nil + } + return &cipherEnvelope{priv: priv} +} + +func (c *cipherEnvelope) HeaderKey() string { + return headerKey +} + +func (c *cipherEnvelope) HeaderValue() string { + return headerValue +} + +func (c *cipherEnvelope) Encrypt(plain []byte) ([]byte, error) { + if c == nil || c.pub == nil { + return nil, errors.New("encrypt: nil public key") + } + + key := make([]byte, aesKeySize) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("rand key: %w", err) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("cipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("gcm: %w", err) + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("rand nonce: %w", err) + } + ciphertext := gcm.Seal(nil, nonce, plain, nil) + + encKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, c.pub, key, nil) + if err != nil { + return nil, fmt.Errorf("encrypt key: %w", err) + } + keyLen, err := safeUint16(len(encKey)) + if err != nil { + return nil, errors.New("encrypt key: envelope too large") + } + out := make([]byte, keyLenFieldSz+len(encKey)+len(nonce)+len(ciphertext)) + binary.BigEndian.PutUint16(out, keyLen) + offset := keyLenFieldSz + copy(out[offset:], encKey) + offset += len(encKey) + copy(out[offset:], nonce) + offset += len(nonce) + copy(out[offset:], ciphertext) + return out, nil +} + +func safeUint16(n int) (uint16, error) { + if n < 0 || n > math.MaxUint16 { + return 0, fmt.Errorf("out of uint16 range: %d", n) + } + return uint16(n), nil +} + +func readKeyFile(path string) ([]byte, error) { + if path == "" { + return nil, errors.New("empty path") + } + abs, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(abs) + name := filepath.Base(abs) + fsys := os.DirFS(dir) + if !fs.ValidPath(name) { + return nil, fmt.Errorf("invalid key filename: %q", name) + } + return fs.ReadFile(fsys, name) +} + +func (c *cipherEnvelope) Decrypt(envelope []byte) ([]byte, error) { + if c == nil || c.priv == nil { + return nil, errors.New("decrypt: nil private key") + } + if len(envelope) < keyLenFieldSz+nonceSize { + return nil, errors.New("decrypt: envelope too short") + } + + keyLen := int(binary.BigEndian.Uint16(envelope[:keyLenFieldSz])) + offset := keyLenFieldSz + if keyLen == 0 || len(envelope) < offset+keyLen+nonceSize { + return nil, errors.New("decrypt: invalid key length") + } + encKey := envelope[offset : offset+keyLen] + offset += keyLen + nonce := envelope[offset : offset+nonceSize] + offset += nonceSize + ciphertext := envelope[offset:] + if len(ciphertext) == 0 { + return nil, errors.New("decrypt: empty ciphertext") + } + + key, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, c.priv, encKey, nil) + if err != nil { + return nil, fmt.Errorf("decrypt key: %w", err) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("cipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("gcm: %w", err) + } + if gcm.NonceSize() != nonceSize { + return nil, errors.New("decrypt: unexpected nonce size") + } + plain, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("decrypt payload: %w", err) + } + return plain, nil +} + +// LoadPublicKey loads an RSA public key from a PEM file (PKIX or PKCS1). +func LoadPublicKey(path string) (*rsa.PublicKey, error) { + data, err := readKeyFile(path) + if err != nil { + return nil, fmt.Errorf("read public key: %w", err) + } + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("decode public key PEM: no block found") + } + + ifc, err := x509.ParsePKIXPublicKey(block.Bytes) + if err == nil { + key, ok := ifc.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key: not RSA") + } + return key, nil + } + key, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse public key: %w", err) + } + return key, nil +} + +// LoadPrivateKey loads an RSA private key from a PEM file (PKCS1 or PKCS8). +func LoadPrivateKey(path string) (*rsa.PrivateKey, error) { + data, err := readKeyFile(path) + if err != nil { + return nil, fmt.Errorf("read private key: %w", err) + } + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("decode private key PEM: no block found") + } + + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err == nil { + return key, nil + } + ifc, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + priv, ok := ifc.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("private key: not RSA") + } + return priv, nil +} diff --git a/internal/infra/crypto/rsaenvelope/envelope_test.go b/internal/infra/crypto/rsaenvelope/envelope_test.go new file mode 100644 index 0000000..985a8b6 --- /dev/null +++ b/internal/infra/crypto/rsaenvelope/envelope_test.go @@ -0,0 +1,80 @@ +package rsaenvelope + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "testing" +) + +func TestEncryptDecryptRoundTrip(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + encrypter := NewEncrypter(&priv.PublicKey) + decrypter := NewDecrypter(priv) + + plain := []byte("hello encrypted world") + enc, err := encrypter.Encrypt(plain) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + dec, err := decrypter.Decrypt(enc) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("decrypt mismatch: got %q want %q", dec, plain) + } + + if encrypter.HeaderKey() != headerKey || encrypter.HeaderValue() != headerValue { + t.Fatalf("header mismatch: %s=%s", encrypter.HeaderKey(), encrypter.HeaderValue()) + } +} + +func TestLoadKeys(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + pubDER, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + if err != nil { + t.Fatalf("marshal pub: %v", err) + } + pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}) + privPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + dir := t.TempDir() + pubPath := filepath.Join(dir, "pub.pem") + privPath := filepath.Join(dir, "priv.pem") + + if err := os.WriteFile(pubPath, pubPEM, 0o600); err != nil { + t.Fatalf("write pub: %v", err) + } + if err := os.WriteFile(privPath, privPEM, 0o600); err != nil { + t.Fatalf("write priv: %v", err) + } + + gotPub, err := LoadPublicKey(pubPath) + if err != nil { + t.Fatalf("load pub: %v", err) + } + gotPriv, err := LoadPrivateKey(privPath) + if err != nil { + t.Fatalf("load priv: %v", err) + } + + if gotPub.N.Cmp(priv.N) != 0 { + t.Fatal("public key mismatch") + } + if gotPriv.N.Cmp(priv.N) != 0 { + t.Fatal("private key mismatch") + } +} diff --git a/internal/adapters/http/ginserver/example_test.go b/internal/infra/http/ginserver/example_test.go similarity index 90% rename from internal/adapters/http/ginserver/example_test.go rename to internal/infra/http/ginserver/example_test.go index fc0e127..6f9b0ef 100644 --- a/internal/adapters/http/ginserver/example_test.go +++ b/internal/infra/http/ginserver/example_test.go @@ -11,10 +11,10 @@ import ( "github.com/gin-gonic/gin" "go.uber.org/zap" - "github.com/vshulcz/Golectra/internal/adapters/http/ginserver" - "github.com/vshulcz/Golectra/internal/adapters/repository/memory" + "github.com/vshulcz/Golectra/internal/application/metrics" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/services/metrics" + "github.com/vshulcz/Golectra/internal/infra/http/ginserver" + "github.com/vshulcz/Golectra/internal/infra/repository/memory" ) func newExampleRouter() *gin.Engine { diff --git a/internal/adapters/http/ginserver/ginserver.go b/internal/infra/http/ginserver/ginserver.go similarity index 100% rename from internal/adapters/http/ginserver/ginserver.go rename to internal/infra/http/ginserver/ginserver.go diff --git a/internal/adapters/http/ginserver/handler.go b/internal/infra/http/ginserver/handler.go similarity index 93% rename from internal/adapters/http/ginserver/handler.go rename to internal/infra/http/ginserver/handler.go index 4a44290..bcd9fe8 100644 --- a/internal/adapters/http/ginserver/handler.go +++ b/internal/infra/http/ginserver/handler.go @@ -11,9 +11,8 @@ import ( "sync" "github.com/gin-gonic/gin" + "github.com/vshulcz/Golectra/internal/application/metrics" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/services/audit" - "github.com/vshulcz/Golectra/internal/services/metrics" ) // Handler exposes HTTP endpoints for metric collection and inspection. @@ -91,8 +90,8 @@ func (h *Handler) UpdateMetric(c *gin.Context) { c.String(http.StatusBadRequest, "bad request") return } - ctx := audit.WithClientIP(c.Request.Context(), c.ClientIP()) - if _, err := h.svc.Upsert(ctx, m); err != nil { + ctx := c.Request.Context() + if _, err := h.svc.Upsert(ctx, m, c.ClientIP()); err != nil { httpError(c, err) return } @@ -166,8 +165,8 @@ func (h *Handler) UpdateMetricJSON(c *gin.Context) { return } - ctx := audit.WithClientIP(c.Request.Context(), c.ClientIP()) - res, err := h.svc.Upsert(ctx, m) + ctx := c.Request.Context() + res, err := h.svc.Upsert(ctx, m, c.ClientIP()) if err != nil { httpError(c, err) return @@ -200,8 +199,8 @@ func (h *Handler) UpdateMetricsBatchJSON(c *gin.Context) { } defer release() items = cloneMetrics(items) - ctx := audit.WithClientIP(c.Request.Context(), c.ClientIP()) - updated, err := h.svc.UpsertBatch(ctx, items) + ctx := c.Request.Context() + updated, err := h.svc.UpsertBatch(ctx, items, c.ClientIP()) if err != nil { httpError(c, err) return diff --git a/internal/adapters/http/ginserver/handler_benchmark_test.go b/internal/infra/http/ginserver/handler_benchmark_test.go similarity index 98% rename from internal/adapters/http/ginserver/handler_benchmark_test.go rename to internal/infra/http/ginserver/handler_benchmark_test.go index 2cdaff8..7e9305b 100644 --- a/internal/adapters/http/ginserver/handler_benchmark_test.go +++ b/internal/infra/http/ginserver/handler_benchmark_test.go @@ -12,8 +12,8 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/vshulcz/Golectra/internal/application/metrics" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/services/metrics" ) type benchRepo struct { diff --git a/internal/adapters/http/ginserver/handler_test.go b/internal/infra/http/ginserver/handler_test.go similarity index 99% rename from internal/adapters/http/ginserver/handler_test.go rename to internal/infra/http/ginserver/handler_test.go index 97a8a2d..0df1556 100644 --- a/internal/adapters/http/ginserver/handler_test.go +++ b/internal/infra/http/ginserver/handler_test.go @@ -13,13 +13,13 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/vshulcz/Golectra/internal/application/metrics" "github.com/vshulcz/Golectra/internal/domain" "github.com/vshulcz/Golectra/internal/ports" - "github.com/vshulcz/Golectra/internal/services/metrics" "go.uber.org/zap" - "github.com/vshulcz/Golectra/internal/adapters/http/ginserver/middlewares" - memrepo "github.com/vshulcz/Golectra/internal/adapters/repository/memory" + "github.com/vshulcz/Golectra/internal/infra/http/ginserver/middlewares" + memrepo "github.com/vshulcz/Golectra/internal/infra/repository/memory" ) func newServer(t *testing.T, repo ports.MetricsRepo, onChanged ...func(context.Context, domain.Snapshot)) *httptest.Server { diff --git a/internal/infra/http/ginserver/middlewares/crypto_decrypt.go b/internal/infra/http/ginserver/middlewares/crypto_decrypt.go new file mode 100644 index 0000000..3ec1b98 --- /dev/null +++ b/internal/infra/http/ginserver/middlewares/crypto_decrypt.go @@ -0,0 +1,44 @@ +package middlewares + +import ( + "bytes" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/vshulcz/Golectra/internal/ports" +) + +// DecryptPayload decrypts request bodies that carry the encryption header. +func DecryptPayload(dec ports.PayloadDecrypter) gin.HandlerFunc { + if dec == nil { + return func(c *gin.Context) { + c.Next() + } + } + + return func(c *gin.Context) { + if !strings.EqualFold(c.GetHeader(dec.HeaderKey()), dec.HeaderValue()) { + c.Next() + return + } + + raw, err := io.ReadAll(c.Request.Body) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "read body failed"}) + return + } + if err := c.Request.Body.Close(); err != nil { + _ = c.Error(err) + } + plain, err := dec.Decrypt(raw) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "decrypt failed"}) + return + } + c.Request.Body = io.NopCloser(bytes.NewReader(plain)) + c.Request.ContentLength = int64(len(plain)) + c.Next() + } +} diff --git a/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go b/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go new file mode 100644 index 0000000..d420de5 --- /dev/null +++ b/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go @@ -0,0 +1,54 @@ +package middlewares + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/vshulcz/Golectra/internal/infra/crypto/rsaenvelope" +) + +func TestDecryptPayload(t *testing.T) { + gin.SetMode(gin.TestMode) + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + encrypter := rsaenvelope.NewEncrypter(&priv.PublicKey) + decrypter := rsaenvelope.NewDecrypter(priv) + + r := gin.New() + r.Use(DecryptPayload(decrypter)) + r.POST("/decrypt", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + c.Data(http.StatusOK, "text/plain", body) + }) + + plain := []byte("payload") + enc, err := encrypter.Encrypt(plain) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/decrypt", bytes.NewReader(enc)) + req.Header.Set(encrypter.HeaderKey(), encrypter.HeaderValue()) + rec := httptest.NewRecorder() + + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if !bytes.Equal(rec.Body.Bytes(), plain) { + t.Fatalf("body=%q want %q", rec.Body.Bytes(), plain) + } +} diff --git a/internal/adapters/http/ginserver/middlewares/gin_gzip.go b/internal/infra/http/ginserver/middlewares/gin_gzip.go similarity index 100% rename from internal/adapters/http/ginserver/middlewares/gin_gzip.go rename to internal/infra/http/ginserver/middlewares/gin_gzip.go diff --git a/internal/adapters/http/ginserver/middlewares/hashsha256.go b/internal/infra/http/ginserver/middlewares/hashsha256.go similarity index 86% rename from internal/adapters/http/ginserver/middlewares/hashsha256.go rename to internal/infra/http/ginserver/middlewares/hashsha256.go index 8b4f90d..4c50cc8 100644 --- a/internal/adapters/http/ginserver/middlewares/hashsha256.go +++ b/internal/infra/http/ginserver/middlewares/hashsha256.go @@ -2,12 +2,13 @@ package middlewares import ( "bytes" + "crypto/sha256" + "encoding/hex" "io" "net/http" "strings" "github.com/gin-gonic/gin" - "github.com/vshulcz/Golectra/internal/misc" ) type bodyBufferWriter struct { @@ -50,7 +51,7 @@ func HashSHA256(key string) gin.HandlerFunc { } c.Request.Body = io.NopCloser(bytes.NewReader(reqBody)) if len(reqBody) > 0 { - want := misc.SumSHA256(reqBody, key) + want := sumSHA256(reqBody, key) if !strings.EqualFold(got, want) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid hash"}) } @@ -63,7 +64,7 @@ func HashSHA256(key string) gin.HandlerFunc { } if bw.body.Len() > 0 { - sum := misc.SumSHA256(bw.body.Bytes(), key) + sum := sumSHA256(bw.body.Bytes(), key) c.Header("HashSHA256", sum) } @@ -79,3 +80,8 @@ func HashSHA256(key string) gin.HandlerFunc { } } } + +func sumSHA256(value []byte, key string) string { + sum := sha256.Sum256(append(value, []byte(key)...)) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/adapters/http/ginserver/middlewares/middlewares.go b/internal/infra/http/ginserver/middlewares/middlewares.go similarity index 100% rename from internal/adapters/http/ginserver/middlewares/middlewares.go rename to internal/infra/http/ginserver/middlewares/middlewares.go diff --git a/internal/infra/http/ginserver/middlewares/middlewares_test.go b/internal/infra/http/ginserver/middlewares/middlewares_test.go new file mode 100644 index 0000000..cf77dec --- /dev/null +++ b/internal/infra/http/ginserver/middlewares/middlewares_test.go @@ -0,0 +1,244 @@ +package middlewares + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestHashSHA256_ValidatesRequestAndSetsResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(HashSHA256("secret")) + router.POST("/hash", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + c.Data(http.StatusOK, "text/plain", body) + }) + + plain := []byte("payload") + req := httptest.NewRequest(http.MethodPost, "/hash", bytes.NewReader(plain)) + req.Header.Set("HashSHA256", sumSHA256(plain, "secret")) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if got := rec.Header().Get("HashSHA256"); got == "" { + t.Fatal("expected response HashSHA256 header") + } + if !bytes.Equal(rec.Body.Bytes(), plain) { + t.Fatalf("body=%q want %q", rec.Body.Bytes(), plain) + } +} + +func TestHashSHA256_InvalidRejects(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(HashSHA256("secret")) + router.POST("/hash", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodPost, "/hash", strings.NewReader("payload")) + req.Header.Set("HashSHA256", "bad") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status=%d want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestHashSHA256_EmptyKey_IsNoop(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(HashSHA256(" ")) + router.POST("/noop", func(c *gin.Context) { + c.Data(http.StatusOK, "text/plain", []byte("ok")) + }) + + req := httptest.NewRequest(http.MethodPost, "/noop", strings.NewReader("payload")) + req.Header.Set("HashSHA256", "bad") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if got := rec.Header().Get("HashSHA256"); got != "" { + t.Fatalf("unexpected HashSHA256 header: %q", got) + } +} + +func TestHashSHA256_EmptyBody_SkipsValidation(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(HashSHA256("secret")) + router.POST("/empty", func(c *gin.Context) { + c.Data(http.StatusOK, "text/plain", []byte("ok")) + }) + + req := httptest.NewRequest(http.MethodPost, "/empty", http.NoBody) + req.Header.Set("HashSHA256", "bad") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } +} + +func TestGzipRequest_Decompresses(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(GzipRequest()) + router.POST("/gzip", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + c.Data(http.StatusOK, "text/plain", body) + }) + + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + if _, err := zw.Write([]byte("hello")); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/gzip", bytes.NewReader(buf.Bytes())) + req.Header.Set("Content-Encoding", "gzip") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if got := rec.Body.String(); got != "hello" { + t.Fatalf("body=%q want %q", got, "hello") + } +} + +func TestGzipRequest_BadGzip(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(GzipRequest()) + router.POST("/gzip", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodPost, "/gzip", strings.NewReader("not-gzip")) + req.Header.Set("Content-Encoding", "gzip") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status=%d want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestGzipResponse_CompressesJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(GzipResponse()) + router.GET("/gzip", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(`{"ok":true}`)) + }) + + req := httptest.NewRequest(http.MethodGet, "/gzip", nil) + req.Header.Set("Accept-Encoding", "gzip") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if enc := rec.Header().Get("Content-Encoding"); enc != "gzip" { + t.Fatalf("encoding=%q want gzip", enc) + } + + gr, err := gzip.NewReader(bytes.NewReader(rec.Body.Bytes())) + if err != nil { + t.Fatalf("gzip reader: %v", err) + } + defer func() { _ = gr.Close() }() + out, err := io.ReadAll(gr) + if err != nil { + t.Fatalf("read gz body: %v", err) + } + if string(out) != `{"ok":true}` { + t.Fatalf("body=%q want %q", out, `{"ok":true}`) + } +} + +func TestGzipResponse_SkipsWhenNoGzipAccept(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(GzipResponse()) + router.GET("/plain", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(`{"ok":true}`)) + }) + + req := httptest.NewRequest(http.MethodGet, "/plain", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if enc := rec.Header().Get("Content-Encoding"); enc != "" { + t.Fatalf("expected no Content-Encoding, got %q", enc) + } +} + +func TestGzipResponse_SkipsNonCompressibleType(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(GzipResponse()) + router.GET("/bin", func(c *gin.Context) { + c.Data(http.StatusOK, "application/octet-stream", []byte{0x01, 0x02}) + }) + + req := httptest.NewRequest(http.MethodGet, "/bin", nil) + req.Header.Set("Accept-Encoding", "gzip") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if enc := rec.Header().Get("Content-Encoding"); enc != "" { + t.Fatalf("expected no Content-Encoding, got %q", enc) + } +} diff --git a/internal/adapters/http/ginserver/middlewares/zap_logger.go b/internal/infra/http/ginserver/middlewares/zap_logger.go similarity index 100% rename from internal/adapters/http/ginserver/middlewares/zap_logger.go rename to internal/infra/http/ginserver/middlewares/zap_logger.go diff --git a/internal/adapters/http/ginserver/router.go b/internal/infra/http/ginserver/router.go similarity index 100% rename from internal/adapters/http/ginserver/router.go rename to internal/infra/http/ginserver/router.go diff --git a/internal/adapters/persistence/file/file.go b/internal/infra/persistence/file/file.go similarity index 100% rename from internal/adapters/persistence/file/file.go rename to internal/infra/persistence/file/file.go diff --git a/internal/adapters/persistence/file/file_test.go b/internal/infra/persistence/file/file_test.go similarity index 97% rename from internal/adapters/persistence/file/file_test.go rename to internal/infra/persistence/file/file_test.go index 41ffb35..c987644 100644 --- a/internal/adapters/persistence/file/file_test.go +++ b/internal/infra/persistence/file/file_test.go @@ -6,7 +6,7 @@ import ( "path/filepath" "testing" - "github.com/vshulcz/Golectra/internal/adapters/repository/memory" + "github.com/vshulcz/Golectra/internal/infra/repository/memory" ) func mustSetGauge(t *testing.T, repo *memory.Repo, name string, value float64) { diff --git a/internal/adapters/publisher/httpjson/client.go b/internal/infra/publisher/httpjson/client.go similarity index 86% rename from internal/adapters/publisher/httpjson/client.go rename to internal/infra/publisher/httpjson/client.go index 9bd1635..dcea20e 100644 --- a/internal/adapters/publisher/httpjson/client.go +++ b/internal/infra/publisher/httpjson/client.go @@ -17,15 +17,16 @@ import ( "time" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/misc" + "github.com/vshulcz/Golectra/internal/infra/retry" "github.com/vshulcz/Golectra/internal/ports" ) // Client publishes metrics to the server using gzipped JSON requests. type Client struct { - key string - base *url.URL - hc *http.Client + key string + base *url.URL + hc *http.Client + encrypter ports.PayloadEncrypter } var _ ports.Publisher = (*Client)(nil) @@ -44,7 +45,7 @@ var ( ) // New normalizes the base address, configures the HTTP client, and returns a Client instance. -func New(serverAddr string, hc *http.Client, key string) (*Client, error) { +func New(serverAddr string, hc *http.Client, key string, encrypter ports.PayloadEncrypter) (*Client, error) { if hc == nil { hc = &http.Client{Timeout: 10 * time.Second} } @@ -52,7 +53,7 @@ func New(serverAddr string, hc *http.Client, key string) (*Client, error) { if err != nil { return nil, err } - return &Client{base: u, hc: hc, key: strings.TrimSpace(key)}, nil + return &Client{base: u, hc: hc, key: strings.TrimSpace(key), encrypter: encrypter}, nil } func normalizeBase(s string) string { @@ -89,7 +90,7 @@ func (c *Client) doGzJSON(ctx context.Context, path string, payload any) (retErr var hashHeader string if c.key != "" { - hashHeader = misc.SumSHA256(plain, c.key) + hashHeader = sumSHA256(plain, c.key) } gzPayload, err := gzipBytes(plain) @@ -97,10 +98,19 @@ func (c *Client) doGzJSON(ctx context.Context, path string, payload any) (retErr return err } defer gzPayload.Release() - gzBody := gzPayload.Bytes() + body := gzPayload.Bytes() + encrypted := false + if c.encrypter != nil { + enc, err := c.encrypter.Encrypt(body) + if err != nil { + return err + } + body = enc + encrypted = true + } resp, err := c.sendWithRetry(ctx, func() (*http.Request, error) { - return c.newGzJSONRequest(ctx, path, gzBody, hashHeader) + return c.newGzJSONRequest(ctx, path, body, hashHeader, encrypted) }) if err != nil { return err @@ -212,7 +222,7 @@ func gzipBytes(src []byte) (*compressedPayload, error) { return &compressedPayload{buf: buf}, nil } -func (c *Client) newGzJSONRequest(ctx context.Context, path string, body []byte, hashHeader string) (*http.Request, error) { +func (c *Client) newGzJSONRequest(ctx context.Context, path string, body []byte, hashHeader string, encrypted bool) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint(path), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("new request: %w", err) @@ -224,6 +234,9 @@ func (c *Client) newGzJSONRequest(ctx context.Context, path string, body []byte, if hashHeader != "" { req.Header.Set("HashSHA256", hashHeader) } + if encrypted { + req.Header.Set(c.encrypter.HeaderKey(), c.encrypter.HeaderValue()) + } return req, nil } @@ -239,7 +252,7 @@ func (c *Client) sendWithRetry(ctx context.Context, mkReq func() (*http.Request, resp = r return err } - if err := misc.Retry(ctx, misc.DefaultBackoff, isRetryableHTTP, op); err != nil { + if err := retry.Retry(ctx, retry.DefaultBackoff, isRetryableHTTP, op); err != nil { return nil, fmt.Errorf("http do: %w", err) } return resp, nil diff --git a/internal/adapters/publisher/httpjson/client_benchmark_test.go b/internal/infra/publisher/httpjson/client_benchmark_test.go similarity index 95% rename from internal/adapters/publisher/httpjson/client_benchmark_test.go rename to internal/infra/publisher/httpjson/client_benchmark_test.go index c04a474..b7f4c87 100644 --- a/internal/adapters/publisher/httpjson/client_benchmark_test.go +++ b/internal/infra/publisher/httpjson/client_benchmark_test.go @@ -34,7 +34,7 @@ func BenchmarkClientSendBatch(b *testing.B) { })) b.Cleanup(srv.Close) - client, err := New(srv.URL, srv.Client(), "bench-secret") + client, err := New(srv.URL, srv.Client(), "bench-secret", nil) if err != nil { b.Fatalf("new client: %v", err) } diff --git a/internal/adapters/publisher/httpjson/client_test.go b/internal/infra/publisher/httpjson/client_test.go similarity index 82% rename from internal/adapters/publisher/httpjson/client_test.go rename to internal/infra/publisher/httpjson/client_test.go index 54d649c..8fbc77a 100644 --- a/internal/adapters/publisher/httpjson/client_test.go +++ b/internal/infra/publisher/httpjson/client_test.go @@ -4,6 +4,8 @@ import ( "bytes" "compress/gzip" "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "errors" "fmt" @@ -19,12 +21,15 @@ import ( "time" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/misc" + "github.com/vshulcz/Golectra/internal/infra/crypto/rsaenvelope" + "github.com/vshulcz/Golectra/internal/infra/retry" ) const ( - updatePath = "/update" - updatesPath = "/updates" + updatePath = "/update" + updatesPath = "/updates" + allocMetricID = "Alloc" + gaugeMetricType = "gauge" ) func mustWrite(t *testing.T, w io.Writer, data []byte) { @@ -70,7 +75,7 @@ func TestNew_NormalizeBaseAndTimeout(t *testing.T) { if !tc.nilHC { hc = &http.Client{} } - c, err := New(tc.addr, hc, "") + c, err := New(tc.addr, hc, "", nil) if err != nil { t.Fatalf("New error: %v", err) } @@ -104,14 +109,14 @@ func Test_normalizeBase(t *testing.T) { } func TestNew_InvalidURL(t *testing.T) { - _, err := New("http://%zz", nil, "") + _, err := New("http://%zz", nil, "", nil) if err == nil { t.Fatal("expected error for invalid URL") } } func TestClient_JoinPath(t *testing.T) { - c, err := New("http://x:1/base", nil, "") + c, err := New("http://x:1/base", nil, "", nil) if err != nil { t.Fatal(err) } @@ -119,7 +124,7 @@ func TestClient_JoinPath(t *testing.T) { t.Fatalf("endpoint=%q want %q", got, "http://x:1/base/update") } - c2, _ := New("http://x:1/base/", nil, "") + c2, _ := New("http://x:1/base/", nil, "", nil) if got := c2.endpoint(updatePath); got != "http://x:1/base/update" { t.Fatalf("endpoint=%q want %q", got, "http://x:1/base/update") } @@ -215,13 +220,13 @@ func TestSendOne_VariousResponses(t *testing.T) { })) defer srv.Close() - c, err := New(srv.URL, &http.Client{Timeout: 2 * time.Second}, "") + c, err := New(srv.URL, &http.Client{Timeout: 2 * time.Second}, "", nil) if err != nil { t.Fatal(err) } val := 123.45 - err = c.SendOne(context.TODO(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}) + err = c.SendOne(context.TODO(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}) if tt.wantErr == "" && err != nil { t.Fatalf("SendOne error: %v", err) } @@ -239,7 +244,7 @@ func TestSendOne_VariousResponses(t *testing.T) { t.Fatalf("path=%q want %s", got.path, updatePath) } - if got.metric.ID != "Alloc" || got.metric.MType != "gauge" { + if got.metric.ID != allocMetricID || got.metric.MType != gaugeMetricType { t.Fatalf("metric=%+v want id=Alloc type=gauge", got.metric) } if got.metric.Value == nil || *got.metric.Value != 123.45 { @@ -289,7 +294,7 @@ func TestSendOne_CounterPayloadAndHeaders(t *testing.T) { })) defer srv.Close() - c, err := New(srv.URL, nil, "") + c, err := New(srv.URL, nil, "", nil) if err != nil { t.Fatal(err) } @@ -400,7 +405,7 @@ func TestSendBatch_VariousResponses(t *testing.T) { })) defer srv.Close() - c, err := New(srv.URL, &http.Client{Timeout: 2 * time.Second}, "") + c, err := New(srv.URL, &http.Client{Timeout: 2 * time.Second}, "", nil) if err != nil { t.Fatal(err) } @@ -408,7 +413,7 @@ func TestSendBatch_VariousResponses(t *testing.T) { val := 1.23 delta := int64(7) err = c.SendBatch(context.TODO(), []domain.Metrics{ - {ID: "Alloc", MType: "gauge", Value: &val}, + {ID: allocMetricID, MType: gaugeMetricType, Value: &val}, {ID: "PollCount", MType: "counter", Delta: &delta}, }) @@ -435,8 +440,8 @@ func TestSendBatch_VariousResponses(t *testing.T) { var seenGauge, seenCounter bool for _, m := range got.metrics { switch m.MType { - case "gauge": - if m.ID != "Alloc" { + case gaugeMetricType: + if m.ID != allocMetricID { t.Fatalf("gauge id=%q want Alloc", m.ID) } if m.Value == nil || *m.Value != 1.23 { @@ -476,7 +481,7 @@ func (panicRT) RoundTrip(*http.Request) (*http.Response, error) { func TestSendBatch_EmptyBatchIsNoop(t *testing.T) { hc := &http.Client{Transport: panicRT{}, Timeout: time.Second} - c, err := New("http://example", hc, "") + c, err := New("http://example", hc, "", nil) if err != nil { t.Fatal(err) } @@ -562,9 +567,9 @@ func Test_isRetryableHTTP(t *testing.T) { } func TestSendOne_RetryOnNetworkErrors(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() rt := &scriptedRT{ steps: []func(*http.Request) (*http.Response, error){ @@ -580,13 +585,13 @@ func TestSendOne_RetryOnNetworkErrors(t *testing.T) { }, } hc := &http.Client{Transport: rt, Timeout: 2 * time.Second} - c, err := New("http://example", hc, "") + c, err := New("http://example", hc, "", nil) if err != nil { t.Fatal(err) } val := 42.0 - if err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}); err != nil { + if err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}); err != nil { t.Fatalf("SendOne error: %v", err) } if got := rt.Calls(); got != 3 { @@ -595,9 +600,9 @@ func TestSendOne_RetryOnNetworkErrors(t *testing.T) { } func TestSendOne_RetryExhausted(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() rt := &scriptedRT{ steps: []func(*http.Request) (*http.Response, error){ @@ -607,10 +612,10 @@ func TestSendOne_RetryExhausted(t *testing.T) { }, } hc := &http.Client{Transport: rt, Timeout: 2 * time.Second} - c, _ := New("http://example", hc, "") + c, _ := New("http://example", hc, "", nil) val := 1.0 - err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}) + err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}) if err == nil || !strings.Contains(err.Error(), "http do:") { t.Fatalf("want http do error, got: %v", err) } @@ -620,9 +625,9 @@ func TestSendOne_RetryExhausted(t *testing.T) { } func TestSendOne_NoRetry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() rt := &scriptedRT{ steps: []func(*http.Request) (*http.Response, error){ @@ -630,10 +635,10 @@ func TestSendOne_NoRetry(t *testing.T) { }, } hc := &http.Client{Transport: rt} - c, _ := New("http://example", hc, "") + c, _ := New("http://example", hc, "", nil) val := 7.0 - err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}) + err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}) if err == nil || !strings.Contains(err.Error(), "http do:") { t.Fatalf("want http do error, got: %v", err) } @@ -643,9 +648,9 @@ func TestSendOne_NoRetry(t *testing.T) { } func TestSendOne_NoRetryOn400(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() rt := &scriptedRT{ steps: []func(*http.Request) (*http.Response, error){ @@ -655,10 +660,10 @@ func TestSendOne_NoRetryOn400(t *testing.T) { }, } hc := &http.Client{Transport: rt} - c, _ := New("http://example", hc, "") + c, _ := New("http://example", hc, "", nil) val := 3.14 - err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}) + err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}) if err == nil || !strings.Contains(err.Error(), "400") { t.Fatalf("want 400 error, got: %v", err) } @@ -668,9 +673,9 @@ func TestSendOne_NoRetryOn400(t *testing.T) { } func TestSendBatch_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() rt := &scriptedRT{ steps: []func(*http.Request) (*http.Response, error){ @@ -679,12 +684,12 @@ func TestSendBatch_Retry(t *testing.T) { }, } hc := &http.Client{Transport: rt} - c, _ := New("http://example", hc, "") + c, _ := New("http://example", hc, "", nil) val := 1.23 delta := int64(7) err := c.SendBatch(context.Background(), []domain.Metrics{ - {ID: "Alloc", MType: "gauge", Value: &val}, + {ID: allocMetricID, MType: gaugeMetricType, Value: &val}, {ID: "PollCount", MType: "counter", Delta: &delta}, }) if err != nil { @@ -696,9 +701,9 @@ func TestSendBatch_Retry(t *testing.T) { } func TestSendOne_ContextCancel(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{50 * time.Millisecond, 50 * time.Millisecond, 50 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{50 * time.Millisecond, 50 * time.Millisecond, 50 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() rt := &scriptedRT{ steps: []func(*http.Request) (*http.Response, error){ @@ -708,13 +713,13 @@ func TestSendOne_ContextCancel(t *testing.T) { }, } hc := &http.Client{Transport: rt} - c, _ := New("http://example", hc, "") + c, _ := New("http://example", hc, "", nil) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) defer cancel() val := 10.0 - err := c.SendOne(ctx, domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}) + err := c.SendOne(ctx, domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}) if err == nil || (!strings.Contains(err.Error(), "http do:") && !errors.Is(err, context.DeadlineExceeded)) { t.Fatalf("want context-related error, got: %v", err) } @@ -738,10 +743,10 @@ func TestSendOne_ServerGzipResponse(t *testing.T) { }, } hc := &http.Client{Transport: rt} - c, _ := New("http://example", hc, "") + c, _ := New("http://example", hc, "", nil) val := 1.0 - if err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}); err != nil { + if err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}); err != nil { t.Fatalf("SendOne error: %v", err) } } @@ -764,13 +769,13 @@ func TestSendOne_NoHashHeader(t *testing.T) { })) defer srv.Close() - c, err := New(srv.URL, nil, "") + c, err := New(srv.URL, nil, "", nil) if err != nil { t.Fatal(err) } val := 3.14 - if err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}); err != nil { + if err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}); err != nil { t.Fatalf("SendOne error: %v", err) } } @@ -799,7 +804,7 @@ func TestSendOne_HashHeader_Present(t *testing.T) { }() raw := mustReadAll(t, gr) - expected := misc.SumSHA256(raw, key) + expected := sumSHA256(raw, key) if h != expected { t.Fatalf("HashSHA256=%q want %q", h, expected) } @@ -808,13 +813,13 @@ func TestSendOne_HashHeader_Present(t *testing.T) { })) defer srv.Close() - c, err := New(srv.URL, nil, key) + c, err := New(srv.URL, nil, key, nil) if err != nil { t.Fatal(err) } val := 2.71 - if err := c.SendOne(context.Background(), domain.Metrics{ID: "Alloc", MType: "gauge", Value: &val}); err != nil { + if err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}); err != nil { t.Fatalf("SendOne error: %v", err) } } @@ -843,7 +848,7 @@ func TestSendBatch_HashHeader_Present(t *testing.T) { }() raw := mustReadAll(t, gr) - expected := misc.SumSHA256(raw, key) + expected := sumSHA256(raw, key) if h != expected { t.Fatalf("HashSHA256=%q want %q", h, expected) } @@ -861,6 +866,57 @@ func TestSendBatch_HashHeader_Present(t *testing.T) { defer srv.Close() } +func TestSendOne_EncryptedPayload(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + encrypter := rsaenvelope.NewEncrypter(&priv.PublicKey) + decrypter := rsaenvelope.NewDecrypter(priv) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get(encrypter.HeaderKey()); got != encrypter.HeaderValue() { + t.Fatalf("header=%q want %q", got, encrypter.HeaderValue()) + } + + raw := mustReadAll(t, r.Body) + plain, err := decrypter.Decrypt(raw) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + + gr, err := gzip.NewReader(bytes.NewReader(plain)) + if err != nil { + t.Fatalf("bad gzip: %v", err) + } + defer func() { + mustClose(t, gr) + }() + payload := mustReadAll(t, gr) + + var m domain.Metrics + if err := json.Unmarshal(payload, &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if m.ID != allocMetricID || m.MType != gaugeMetricType || m.Value == nil { + t.Fatalf("unexpected payload: %+v", m) + } + + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c, err := New(srv.URL, nil, "", encrypter) + if err != nil { + t.Fatal(err) + } + + val := 1.23 + if err := c.SendOne(context.Background(), domain.Metrics{ID: allocMetricID, MType: gaugeMetricType, Value: &val}); err != nil { + t.Fatalf("SendOne error: %v", err) + } +} + type nopWriteCloser struct{ *strings.Builder } func (n *nopWriteCloser) Write(p []byte) (int, error) { return n.Builder.Write(p) } diff --git a/internal/infra/publisher/httpjson/hash.go b/internal/infra/publisher/httpjson/hash.go new file mode 100644 index 0000000..223fbba --- /dev/null +++ b/internal/infra/publisher/httpjson/hash.go @@ -0,0 +1,11 @@ +package httpjson + +import ( + "crypto/sha256" + "encoding/hex" +) + +func sumSHA256(value []byte, key string) string { + sum := sha256.Sum256(append(value, []byte(key)...)) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/adapters/publisher/httpjson/httpjson.go b/internal/infra/publisher/httpjson/httpjson.go similarity index 100% rename from internal/adapters/publisher/httpjson/httpjson.go rename to internal/infra/publisher/httpjson/httpjson.go diff --git a/internal/adapters/repository/memory/memory.go b/internal/infra/repository/memory/memory.go similarity index 100% rename from internal/adapters/repository/memory/memory.go rename to internal/infra/repository/memory/memory.go diff --git a/internal/adapters/repository/memory/memory_test.go b/internal/infra/repository/memory/memory_test.go similarity index 100% rename from internal/adapters/repository/memory/memory_test.go rename to internal/infra/repository/memory/memory_test.go diff --git a/internal/adapters/repository/postgres/migrate.go b/internal/infra/repository/postgres/migrate.go similarity index 100% rename from internal/adapters/repository/postgres/migrate.go rename to internal/infra/repository/postgres/migrate.go diff --git a/internal/adapters/repository/postgres/migrate_test.go b/internal/infra/repository/postgres/migrate_test.go similarity index 100% rename from internal/adapters/repository/postgres/migrate_test.go rename to internal/infra/repository/postgres/migrate_test.go diff --git a/internal/adapters/repository/postgres/migrations/0001_init.sql b/internal/infra/repository/postgres/migrations/0001_init.sql similarity index 100% rename from internal/adapters/repository/postgres/migrations/0001_init.sql rename to internal/infra/repository/postgres/migrations/0001_init.sql diff --git a/internal/adapters/repository/postgres/postgres.go b/internal/infra/repository/postgres/postgres.go similarity index 92% rename from internal/adapters/repository/postgres/postgres.go rename to internal/infra/repository/postgres/postgres.go index 193bd61..6b06754 100644 --- a/internal/adapters/repository/postgres/postgres.go +++ b/internal/infra/repository/postgres/postgres.go @@ -13,7 +13,7 @@ import ( "github.com/jackc/pgerrcode" "github.com/lib/pq" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/misc" + "github.com/vshulcz/Golectra/internal/infra/retry" "github.com/vshulcz/Golectra/internal/ports" ) @@ -55,7 +55,7 @@ func (r *Repo) GetGauge(ctx context.Context, n string) (float64, error) { v = sql.NullFloat64{} return r.db.QueryRowContext(ctx, q, n, string(domain.Gauge)).Scan(&v) } - if err := misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, op); err != nil { + if err := retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, op); err != nil { if errors.Is(err, sql.ErrNoRows) { return 0, domain.ErrNotFound } @@ -75,7 +75,7 @@ func (r *Repo) GetCounter(ctx context.Context, n string) (int64, error) { d = sql.NullInt64{} return r.db.QueryRowContext(ctx, q, n, string(domain.Counter)).Scan(&d) } - if err := misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, op); err != nil { + if err := retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, op); err != nil { if errors.Is(err, sql.ErrNoRows) { return 0, domain.ErrNotFound } @@ -98,7 +98,7 @@ DO UPDATE SET mtype=$2, value=EXCLUDED.value, delta=NULL, updated_at=now();` _, err := r.db.ExecContext(ctx, q, n, string(domain.Gauge), v) return err } - return misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, op) + return retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, op) } // AddCounter increments (or creates) the named counter. @@ -112,7 +112,7 @@ DO UPDATE SET mtype=$2, value=NULL, delta=COALESCE(metrics.delta,0)+EXCLUDED.del _, err := r.db.ExecContext(ctx, q, n, string(domain.Counter), d) return err } - return misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, op) + return retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, op) } // UpdateMany atomically applies a batch of metrics inside a transaction. @@ -165,7 +165,7 @@ DO UPDATE SET mtype=$2, value=NULL, delta=COALESCE(metrics.delta,0)+EXCLUDED.del } return nil } - return misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, attempt) + return retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, attempt) } // Snapshot loads all stored metrics and returns them grouped by type. @@ -211,7 +211,7 @@ func (r *Repo) Snapshot(ctx context.Context) (domain.Snapshot, error) { resultC = c return nil } - if err := misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, op); err != nil { + if err := retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, op); err != nil { return domain.Snapshot{Gauges: resultG, Counters: resultC}, err } return domain.Snapshot{Gauges: resultG, Counters: resultC}, nil @@ -227,7 +227,7 @@ func (r *Repo) Ping(ctx context.Context) error { op := func() error { return r.db.PingContext(ctx) } - return misc.Retry(ctx, misc.DefaultBackoff, isRetryablePG, op) + return retry.Retry(ctx, retry.DefaultBackoff, isRetryablePG, op) } // IsRetryable reports whether the error should trigger a retry according to Postgres semantics. diff --git a/internal/adapters/repository/postgres/postgres_test.go b/internal/infra/repository/postgres/postgres_test.go similarity index 93% rename from internal/adapters/repository/postgres/postgres_test.go rename to internal/infra/repository/postgres/postgres_test.go index c69fe48..142bc23 100644 --- a/internal/adapters/repository/postgres/postgres_test.go +++ b/internal/infra/repository/postgres/postgres_test.go @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgerrcode" "github.com/lib/pq" "github.com/vshulcz/Golectra/internal/domain" - "github.com/vshulcz/Golectra/internal/misc" + "github.com/vshulcz/Golectra/internal/infra/retry" ) func TestRepo_GetGauge(t *testing.T) { @@ -429,9 +429,9 @@ func Test_isRetryablePG(t *testing.T) { } func TestRepo_GetGauge_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() @@ -453,9 +453,9 @@ func TestRepo_GetGauge_Retry(t *testing.T) { } func TestRepo_GetCounter_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() @@ -491,9 +491,9 @@ func TestRepo_GetGauge_NoRetry(t *testing.T) { } func TestRepo_SetGauge_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() @@ -517,9 +517,9 @@ DO UPDATE SET mtype=$2, value=EXCLUDED.value, delta=NULL, updated_at=now();` } func TestRepo_AddCounter_Retr(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() @@ -539,9 +539,9 @@ DO UPDATE SET mtype=$2, value=NULL, delta=COALESCE(metrics.delta,0)+EXCLUDED.del } func TestRepo_UpdateMany_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() @@ -577,9 +577,9 @@ DO UPDATE SET mtype=$2, value=NULL, delta=COALESCE(metrics.delta,0)+EXCLUDED.del } func TestRepo_Snapshot_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() @@ -602,9 +602,9 @@ func TestRepo_Snapshot_Retry(t *testing.T) { } func TestRepo_Ping_Retry(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{1 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{1 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMockWithPing(t) defer done() @@ -618,9 +618,9 @@ func TestRepo_Ping_Retry(t *testing.T) { } func TestRepo_GetGauge_ContextCancel(t *testing.T) { - orig := misc.DefaultBackoff - misc.DefaultBackoff = []time.Duration{50 * time.Millisecond, 50 * time.Millisecond} - defer func() { misc.DefaultBackoff = orig }() + orig := retry.DefaultBackoff + retry.DefaultBackoff = []time.Duration{50 * time.Millisecond, 50 * time.Millisecond} + defer func() { retry.DefaultBackoff = orig }() _, mock, st, done := newMock(t) defer done() diff --git a/internal/misc/retry.go b/internal/infra/retry/retry.go similarity index 98% rename from internal/misc/retry.go rename to internal/infra/retry/retry.go index 16509ed..75b5dfd 100644 --- a/internal/misc/retry.go +++ b/internal/infra/retry/retry.go @@ -1,4 +1,4 @@ -package misc +package retry import ( "context" diff --git a/internal/infra/retry/retry_test.go b/internal/infra/retry/retry_test.go new file mode 100644 index 0000000..993fa64 --- /dev/null +++ b/internal/infra/retry/retry_test.go @@ -0,0 +1,88 @@ +package retry + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" +) + +func TestRetry_SuccessAfterRetry(t *testing.T) { + var calls int32 + wantErr := errors.New("retry") + op := func() error { + if atomic.AddInt32(&calls, 1) == 1 { + return wantErr + } + return nil + } + + err := Retry(context.Background(), []time.Duration{1 * time.Millisecond}, func(err error) bool { + return errors.Is(err, wantErr) + }, op) + if err != nil { + t.Fatalf("Retry error: %v", err) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Fatalf("calls=%d want 2", got) + } +} + +func TestRetry_NonRetryableStops(t *testing.T) { + var calls int32 + wantErr := errors.New("nope") + op := func() error { + atomic.AddInt32(&calls, 1) + return wantErr + } + + err := Retry(context.Background(), []time.Duration{1 * time.Millisecond}, func(error) bool { return false }, op) + if !errors.Is(err, wantErr) { + t.Fatalf("Retry error=%v want %v", err, wantErr) + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("calls=%d want 1", got) + } +} + +func TestRetry_ContextCancel(t *testing.T) { + var calls int32 + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := Retry(ctx, []time.Duration{1 * time.Millisecond}, func(error) bool { return true }, func() error { + atomic.AddInt32(&calls, 1) + return errors.New("retry") + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("Retry error=%v want context canceled", err) + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("calls=%d want 1", got) + } +} + +func TestRetry_NoDelaysStops(t *testing.T) { + wantErr := errors.New("no retry") + err := Retry(context.Background(), nil, func(error) bool { return true }, func() error { + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("Retry error=%v want %v", err, wantErr) + } +} + +func TestRetry_ContextCancelsDuringDelay(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + + err := Retry(ctx, []time.Duration{50 * time.Millisecond}, func(error) bool { return true }, func() error { + cancel() + close(done) + return errors.New("retry") + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("Retry error=%v want context canceled", err) + } +} diff --git a/internal/misc/env.go b/internal/misc/env.go deleted file mode 100644 index aa6d618..0000000 --- a/internal/misc/env.go +++ /dev/null @@ -1,53 +0,0 @@ -package misc - -import ( - "os" - "strconv" - "strings" - "time" -) - -// Getenv returns the environment value for key or the default when empty. -func Getenv(key, def string) string { - if v := os.Getenv(key); v != "" { - return v - } - return def -} - -// GetDuration parses an environment variable as duration in seconds or Go syntax. -func GetDuration(key string, def time.Duration) time.Duration { - v := os.Getenv(key) - if v == "" { - return def - } - if n, err := strconv.ParseInt(v, 10, 64); err == nil { - if n <= 0 { - return 0 - } - return time.Duration(n) * time.Second - } - if d, err := time.ParseDuration(v); err == nil { - if d <= 0 { - return 0 - } - return d - } - return def -} - -// GetBool parses common boolean values (true/false, yes/no, 1/0) from the environment. -func GetBool(key string, def bool) bool { - v := strings.TrimSpace(strings.ToLower(os.Getenv(key))) - if v == "" { - return def - } - switch v { - case "1", "true", "t", "yes", "y": - return true - case "0", "false", "f", "no", "n": - return false - default: - return def - } -} diff --git a/internal/misc/env_test.go b/internal/misc/env_test.go deleted file mode 100644 index c6d7c61..0000000 --- a/internal/misc/env_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package misc - -import ( - "testing" - "time" -) - -func TestGetenv(t *testing.T) { - tests := []struct { - name string - key string - val string - def string - expect string - }{ - {"value present", "X_FOO", "bar", "zzz", "bar"}, - {"value empty -> default", "X_EMPTY", "", "defv", "defv"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.val != "" { - t.Setenv(tt.key, tt.val) - } else { - t.Setenv(tt.key, "") - } - got := Getenv(tt.key, tt.def) - if got != tt.expect { - t.Errorf("Getenv(%s) = %q, want %q", tt.key, got, tt.expect) - } - }) - } -} - -func TestGetDuration(t *testing.T) { - tests := []struct { - name string - val string - def time.Duration - expect time.Duration - }{ - {"valid duration", "5s", 0, 5 * time.Second}, - {"valid ms duration", "250ms", 0, 250 * time.Millisecond}, - {"negative duration", "-5s", 5 * time.Second, 0}, - - {"numeric string", "10", 0, 10 * time.Second}, - {"zero numeric", "0", 5 * time.Second, 0}, - {"negative numeric", "-3", 5 * time.Second, 0}, - - {"bad format -> default", "oops", 3 * time.Second, 3 * time.Second}, - {"empty -> default", "", 7 * time.Second, 7 * time.Second}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Setenv("X_DUR", tt.val) - got := GetDuration("X_DUR", tt.def) - if got != tt.expect { - t.Errorf("val=%q def=%v -> got=%v, want=%v", tt.val, tt.def, got, tt.expect) - } - }) - } -} - -func TestGetBool(t *testing.T) { - trueVals := []string{"1", "true", "t", "yes", "y"} - falseVals := []string{"0", "false", "f", "no", "n"} - - for _, v := range trueVals { - t.Run("true_"+v, func(t *testing.T) { - t.Setenv("X_BOOL", v) - if !GetBool("X_BOOL", false) { - t.Errorf("GetBool(%q) = false, want true", v) - } - }) - } - - for _, v := range falseVals { - t.Run("false_"+v, func(t *testing.T) { - t.Setenv("X_BOOL", v) - if GetBool("X_BOOL", true) { - t.Errorf("GetBool(%q) = true, want false", v) - } - }) - } - - t.Run("empty -> default true", func(t *testing.T) { - t.Setenv("X_BOOL", "") - if !GetBool("X_BOOL", true) { - t.Error("expected default true") - } - }) - - t.Run("empty -> default false", func(t *testing.T) { - t.Setenv("X_BOOL", "") - if GetBool("X_BOOL", false) { - t.Error("expected default false") - } - }) - - t.Run("unknown string -> default", func(t *testing.T) { - t.Setenv("X_BOOL", "maybe") - if !GetBool("X_BOOL", true) { - t.Error("expected fallback to default true") - } - }) -} diff --git a/internal/misc/hash.go b/internal/misc/hash.go deleted file mode 100644 index b1210fd..0000000 --- a/internal/misc/hash.go +++ /dev/null @@ -1,12 +0,0 @@ -package misc - -import ( - "crypto/sha256" - "encoding/hex" -) - -// SumSHA256 returns a hex-encoded SHA256 checksum of value concatenated with key. -func SumSHA256(value []byte, key string) string { - sum := sha256.Sum256(append(value, []byte(key)...)) - return hex.EncodeToString(sum[:]) -} diff --git a/internal/misc/hash_test.go b/internal/misc/hash_test.go deleted file mode 100644 index f310429..0000000 --- a/internal/misc/hash_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package misc - -import ( - "encoding/hex" - "testing" -) - -func TestSumSHA256(t *testing.T) { - tests := []struct { - name string - value []byte - key string - want string - }{ - {"empty both", []byte{}, "", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"}, - {"hello/world", []byte("hello"), "world", "936a185caaa266bb9cbe981e9e05cb78cd732b0b3280eb944412bb6f8f8f07af"}, - {"bytes/key", []byte{0x00, 0x01, 0x02}, "key", "acc3dc23298dcb1aec9b764fbc38f8eaea64040c6cd2c0a7bc958be0cf06b292"}, - {"unicode", []byte("Привет"), "ключ", "68dbb03b4b69fd44385c59eb1b5386ec0b91e225f3fc1e03f1ec6841c53490ec"}, - {"nil value", nil, "", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got := SumSHA256(tc.value, tc.key) - if got != tc.want { - t.Fatalf("SumSHA256(%v, %q) = %s; want %s", tc.value, tc.key, got, tc.want) - } - }) - } -} - -func TestSumSHA256_Prop(t *testing.T) { - value := []byte("samevalue") - key := "k1" - got1 := SumSHA256(value, key) - got2 := SumSHA256(value, key) - if got1 != got2 { - t.Fatalf("SumSHA256 not deterministic: %s != %s", got1, got2) - } - - other := SumSHA256(value, "k2") - if got1 == other { - t.Fatalf("different keys produced same sum: %s == %s", got1, other) - } - - decoded, err := hex.DecodeString(got1) - if err != nil { - t.Fatalf("result is not valid hex: %v", err) - } - if len(decoded) != 32 { - t.Fatalf("decoded length = %d, want 32", len(decoded)) - } -} diff --git a/internal/misc/misc.go b/internal/misc/misc.go deleted file mode 100644 index e371c26..0000000 --- a/internal/misc/misc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package misc provides miscellaneous utility functions. -package misc diff --git a/internal/misc/pool.go b/internal/misc/pool.go deleted file mode 100644 index f17347e..0000000 --- a/internal/misc/pool.go +++ /dev/null @@ -1,42 +0,0 @@ -package misc - -import "sync" - -// Resetter is an interface for types that can reset their state. -type Resetter interface { - Reset() -} - -// Pool is a generic object pool for types that implement the Resetter interface. -type Pool[T Resetter] struct { - p sync.Pool -} - -// NewPool creates a new Pool for the specified type T. -func NewPool[T Resetter](newFn func() T) *Pool[T] { - pl := &Pool[T]{} - pl.p.New = func() any { - if newFn != nil { - return newFn() - } - var zero T - return zero - } - return pl -} - -// Get retrieves an object from the pool. -func (pl *Pool[T]) Get() T { - obj := pl.p.Get() - if value, ok := obj.(T); ok { - return value - } - var zero T - return zero -} - -// Put returns an object to the pool after resetting it. -func (pl *Pool[T]) Put(v T) { - v.Reset() - pl.p.Put(v) -} diff --git a/internal/misc/pool_test.go b/internal/misc/pool_test.go deleted file mode 100644 index 1a83bb0..0000000 --- a/internal/misc/pool_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package misc - -import ( - "sync" - "testing" -) - -type mockResetter struct { - resetCalled bool -} - -func (m *mockResetter) Reset() { - m.resetCalled = true -} - -func TestNewPool(t *testing.T) { - pool := NewPool(func() *mockResetter { - return &mockResetter{} - }) - - if pool == nil { - t.Fatal("expected pool to be created, got nil") - } -} - -func TestPoolGet(t *testing.T) { - pool := NewPool(func() *mockResetter { - return &mockResetter{} - }) - - item := pool.Get() - if item == nil { - t.Fatal("expected item to be non-nil, got nil") - } -} - -func TestPoolPut(t *testing.T) { - pool := NewPool(func() *mockResetter { - return &mockResetter{} - }) - - // Put an item into the pool - Reset should be called on it - item := &mockResetter{resetCalled: false} - pool.Put(item) - - // After Put, the item should have Reset called - if !item.resetCalled { - t.Fatal("expected Reset to be called on item when Put, but it wasn't") - } -} - -func TestPoolConcurrency(t *testing.T) { - pool := NewPool(func() *mockResetter { - return &mockResetter{} - }) - - var wg sync.WaitGroup - const numGoroutines = 100 - - wg.Add(numGoroutines) - for range numGoroutines { - go func() { - defer wg.Done() - item := pool.Get() - pool.Put(item) - }() - } - - wg.Wait() -} diff --git a/internal/misc/retry_test.go b/internal/misc/retry_test.go deleted file mode 100644 index ed693cc..0000000 --- a/internal/misc/retry_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package misc - -import ( - "context" - "errors" - "testing" - "time" -) - -var ( - errRetriable = errors.New("retriable") - errPermanent = errors.New("permanent") -) - -func isRetriable(err error) bool { - return errors.Is(err, errRetriable) -} - -func makeOp(steps []error) (func() error, *int) { - attempt := 0 - return func() error { - defer func() { attempt++ }() - idx := attempt - if idx >= len(steps) { - idx = len(steps) - 1 - } - return steps[idx] - }, &attempt -} - -func TestRetry(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - delays []time.Duration - steps []error - timeout time.Duration - cancelBefore bool - wantAttempts int - wantErrCheck func(error) bool - }{ - { - name: "success_immediate", - delays: []time.Duration{10 * time.Millisecond, 10 * time.Millisecond, 10 * time.Millisecond}, - steps: []error{nil}, - wantAttempts: 1, - wantErrCheck: func(err error) bool { return err == nil }, - }, - { - name: "non_retryable_immediate", - delays: []time.Duration{10 * time.Millisecond, 10 * time.Millisecond, 10 * time.Millisecond}, - steps: []error{errPermanent}, - wantAttempts: 1, - wantErrCheck: func(err error) bool { return errors.Is(err, errPermanent) }, - }, - { - name: "success_after_two_retries", - delays: []time.Duration{5 * time.Millisecond, 5 * time.Millisecond, 5 * time.Millisecond}, - steps: []error{errRetriable, errRetriable, nil}, - wantAttempts: 3, - wantErrCheck: func(err error) bool { return err == nil }, - }, - { - name: "exhausted_retries_returns_last_error", - delays: []time.Duration{5 * time.Millisecond, 5 * time.Millisecond, 5 * time.Millisecond}, - steps: []error{errRetriable}, - wantAttempts: 4, - wantErrCheck: func(err error) bool { return errors.Is(err, errRetriable) }, - }, - { - name: "stops_on_permanent_midway", - delays: []time.Duration{5 * time.Millisecond, 5 * time.Millisecond, 5 * time.Millisecond}, - steps: []error{errRetriable, errPermanent, errRetriable}, - wantAttempts: 2, - wantErrCheck: func(err error) bool { return errors.Is(err, errPermanent) }, - }, - { - name: "context_timeout_during_backoff", - delays: []time.Duration{50 * time.Millisecond, 50 * time.Millisecond, 50 * time.Millisecond}, - steps: []error{errRetriable}, - timeout: 10 * time.Millisecond, - wantAttempts: 1, - wantErrCheck: func(err error) bool { return errors.Is(err, context.DeadlineExceeded) }, - }, - { - name: "context_canceled_before_start", - delays: []time.Duration{50 * time.Millisecond, 50 * time.Millisecond, 50 * time.Millisecond}, - steps: []error{errRetriable}, - cancelBefore: true, - wantAttempts: 1, - wantErrCheck: func(err error) bool { return errors.Is(err, context.Canceled) }, - }, - { - name: "no_delays_means_no_retries", - delays: nil, - steps: []error{errRetriable}, - wantAttempts: 1, - wantErrCheck: func(err error) bool { return errors.Is(err, errRetriable) }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - var cancel context.CancelFunc - if tc.timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, tc.timeout) - defer cancel() - } - if tc.cancelBefore { - var c context.CancelFunc - ctx, c = context.WithCancel(ctx) - c() - } - - op, attempts := makeOp(tc.steps) - err := Retry(ctx, tc.delays, isRetriable, op) - - if !tc.wantErrCheck(err) { - t.Fatalf("unexpected error: %v", err) - } - if *attempts != tc.wantAttempts { - t.Fatalf("attempts=%d want %d", *attempts, tc.wantAttempts) - } - }) - } -} diff --git a/internal/ports/audit.go b/internal/ports/audit.go new file mode 100644 index 0000000..c8608d2 --- /dev/null +++ b/internal/ports/audit.go @@ -0,0 +1,12 @@ +package ports + +import ( + "context" + + "github.com/vshulcz/Golectra/internal/domain" +) + +// AuditPublisher publishes audit events to an external sink. +type AuditPublisher interface { + Publish(context.Context, domain.AuditEvent) error +} diff --git a/internal/ports/crypto.go b/internal/ports/crypto.go new file mode 100644 index 0000000..2abe1e1 --- /dev/null +++ b/internal/ports/crypto.go @@ -0,0 +1,15 @@ +package ports + +// PayloadEncrypter encrypts outbound payloads and advertises its header metadata. +type PayloadEncrypter interface { + Encrypt([]byte) ([]byte, error) + HeaderKey() string + HeaderValue() string +} + +// PayloadDecrypter decrypts inbound payloads and advertises its header metadata. +type PayloadDecrypter interface { + Decrypt([]byte) ([]byte, error) + HeaderKey() string + HeaderValue() string +} diff --git a/internal/services/audit/audit.go b/internal/services/audit/audit.go deleted file mode 100644 index 84b6f47..0000000 --- a/internal/services/audit/audit.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package audit provides functionalities for auditing user actions within the application. -package audit diff --git a/internal/services/audit/context.go b/internal/services/audit/context.go deleted file mode 100644 index ccf24ee..0000000 --- a/internal/services/audit/context.go +++ /dev/null @@ -1,21 +0,0 @@ -package audit - -import "context" - -type ctxKey string - -const clientIPKey ctxKey = "audit_client_ip" - -// WithClientIP stores the originating request IP inside the context for later audit fan-out. -func WithClientIP(ctx context.Context, ip string) context.Context { - return context.WithValue(ctx, clientIPKey, ip) -} - -// ClientIPFromContext extracts the stored client IP, returning an empty string when missing. -func ClientIPFromContext(ctx context.Context) string { - if ctx == nil { - return "" - } - v, _ := ctx.Value(clientIPKey).(string) - return v -} diff --git a/internal/services/audit/context_test.go b/internal/services/audit/context_test.go deleted file mode 100644 index b19ff99..0000000 --- a/internal/services/audit/context_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package audit - -import ( - "context" - "testing" -) - -func TestWithClientIP(t *testing.T) { - ctx := context.Background() - ctx = WithClientIP(ctx, "127.0.0.1") - if got := ClientIPFromContext(ctx); got != "127.0.0.1" { - t.Fatalf("ClientIPFromContext returned %q", got) - } - - if got := ClientIPFromContext(context.Background()); got != "" { - t.Fatalf("expected empty ip, got %q", got) - } -} diff --git a/internal/services/audit/event.go b/internal/services/audit/event.go deleted file mode 100644 index 7276a10..0000000 --- a/internal/services/audit/event.go +++ /dev/null @@ -1,8 +0,0 @@ -package audit - -// Event describes which metrics changed, when, and from which IP address. -type Event struct { - Timestamp int64 `json:"ts"` - Metrics []string `json:"metrics"` - IPAddress string `json:"ip_address"` -} diff --git a/internal/services/audit/subject.go b/internal/services/audit/subject.go deleted file mode 100644 index 76014b9..0000000 --- a/internal/services/audit/subject.go +++ /dev/null @@ -1,20 +0,0 @@ -package audit - -import "github.com/vshulcz/Golectra/pkg/observer" - -// Observer receives audit events. -type Observer = observer.Observer[Event] - -// ObserverFunc adapts a plain function to the Observer interface. -type ObserverFunc = observer.ObserverFunc[Event] - -// Publisher broadcasts audit events. -type Publisher = observer.Publisher[Event] - -// Subject fans out events to registered observers. -type Subject = observer.Subject[Event] - -// NewSubject creates a subject optionally pre-populated with observers. -func NewSubject(observers ...Observer) *Subject { - return observer.NewSubject[Event](observers...) -} diff --git a/internal/services/audit/subject_test.go b/internal/services/audit/subject_test.go deleted file mode 100644 index bd000cd..0000000 --- a/internal/services/audit/subject_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package audit - -import ( - "context" - "errors" - "sync" - "testing" -) - -func TestSubject_Publish_NotifiesAll(t *testing.T) { - s := NewSubject() - var mu sync.Mutex - var called []Event - - s.Attach(ObserverFunc(func(_ context.Context, evt Event) error { - mu.Lock() - defer mu.Unlock() - called = append(called, evt) - return nil - })) - - evt := Event{Timestamp: 1, Metrics: []string{"Alloc"}, IPAddress: "1.1.1.1"} - s.Publish(context.Background(), evt) - - mu.Lock() - defer mu.Unlock() - if len(called) != 1 { - t.Fatalf("expected 1 call, got %d", len(called)) - } - if called[0].IPAddress != evt.IPAddress { - t.Fatalf("event mismatch: %+v", called[0]) - } -} - -func TestSubject_ErrorHandler(t *testing.T) { - s := NewSubject() - var mu sync.Mutex - var errs []error - - s.SetErrorHandler(func(err error) { - mu.Lock() - defer mu.Unlock() - errs = append(errs, err) - }) - - s.Attach(ObserverFunc(func(_ context.Context, _ Event) error { - return errors.New("boom") - })) - - s.Publish(context.Background(), Event{}) - - mu.Lock() - defer mu.Unlock() - if len(errs) != 1 || errs[0].Error() != "boom" { - t.Fatalf("expected error handler to capture boom, got %+v", errs) - } -} diff --git a/profiles/README.md b/profiles/README.md index 47b9443..08feada 100644 --- a/profiles/README.md +++ b/profiles/README.md @@ -6,7 +6,7 @@ This folder documents how to re-run every performance benchmark and interpret th 1. Build the standalone benchmark once (re-usable between runs): ```bash - go test -c ./internal/services/metrics -o metrics_bench.test + go test -c ./internal/application/metrics -o metrics_bench.test ``` 2. Capture a baseline profile (before changes): ```bash @@ -28,8 +28,8 @@ Latest diff (after trimming IDs once + deduping names in-place): ``` Showing nodes accounting for 389.16MB, 13.35% of 2915.98MB total flat flat% sum% cum cum% - -741.44MB 25.43% 13.35% -741.44MB 25.43% github.com/vshulcz/Golectra/internal/services/metrics.metricNames - 389.16MB 13.35% 13.35% 389.16MB 13.35% github.com/vshulcz/Golectra/internal/services/metrics.(*Service).UpsertBatch + -741.44MB 25.43% 13.35% -741.44MB 25.43% github.com/vshulcz/Golectra/internal/application/metrics.metricNames + 389.16MB 13.35% 13.35% 389.16MB 13.35% github.com/vshulcz/Golectra/internal/application/metrics.(*Service).UpsertBatch ``` Resulting benchmark: `BenchmarkServiceUpsertBatch-8 257097 4506 ns/op 13184 B/op 2 allocs/op`. @@ -37,7 +37,7 @@ Resulting benchmark: `BenchmarkServiceUpsertBatch-8 257097 4506 ns/op 13184 B/op ### HTTP API – `/updates` JSON handler ``` -go test ./internal/adapters/http/ginserver \ +go test ./internal/infra/http/ginserver \ -run ^$ \ -bench BenchmarkHandlerUpdateMetricsBatchJSON \ -benchmem \ @@ -55,7 +55,7 @@ Top allocators: encoding/json.Decoder.refill (53%), Service.UpsertBatch (25%). ### HTTP publisher – gzip JSON client ``` -go test ./internal/adapters/publisher/httpjson \ +go test ./internal/infra/publisher/httpjson \ -run ^$ \ -bench BenchmarkClientSendBatch \ -benchmem \ @@ -73,7 +73,7 @@ compress/flate.NewWriter now accounts for ~37% of alloc_space (was ~74%). ### Agent service – `reportOnce` ``` -go test ./internal/services/agent \ +go test ./internal/application/agent \ -run ^$ \ -bench BenchmarkAgentReportOnce \ -benchmem \ @@ -86,4 +86,4 @@ Reusing the batch slice instead of allocating per tick reduced the footprint to ``` BenchmarkAgentReportOnce-8 211834 5672 ns/op 3200 B/op 400 allocs/op alloc_space is now entirely attributed to Service.buildBatch (expected, since it controls the reusable buffer). -``` \ No newline at end of file +``` From bcc2e50f6e5929f3e539120811e4b45f4f7660a0 Mon Sep 17 00:00:00 2001 From: vshulcz Date: Fri, 16 Jan 2026 14:30:48 +0300 Subject: [PATCH 2/2] fix: ci issues --- cmd/agent/main.go | 15 +++-- cmd/agent/main_test.go | 21 ++++++ cmd/server/main_test.go | 67 +++++++++++++++++++ internal/infra/audit/fanout.go | 1 + internal/infra/config/helpers_test.go | 16 +++++ internal/infra/crypto/rsaenvelope/envelope.go | 1 + .../infra/crypto/rsaenvelope/envelope_test.go | 12 ++++ .../middlewares/crypto_decrypt_test.go | 58 ++++++++++++++++ internal/infra/retry/retry.go | 1 + 9 files changed, 186 insertions(+), 6 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index fd3f0aa..13ea720 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -45,12 +45,7 @@ func main() { log.Fatalf("failed to init publisher: %v", err) } collector := runtime.New() - appCfg := agentsvc.Config{ - PollInterval: cfg.PollInterval, - ReportInterval: cfg.ReportInterval, - RateLimit: cfg.RateLimit, - } - runner := agentsvc.New(appCfg, collector, pub) + runner := agentsvc.New(mapAgentConfig(cfg), collector, pub) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() @@ -65,3 +60,11 @@ func main() { func printBuildInfo() { util.PrintBuildInfo(buildVersion, buildDate, buildCommit) } + +func mapAgentConfig(cfg config.AgentConfig) agentsvc.Config { + return agentsvc.Config{ + PollInterval: cfg.PollInterval, + ReportInterval: cfg.ReportInterval, + RateLimit: cfg.RateLimit, + } +} diff --git a/cmd/agent/main_test.go b/cmd/agent/main_test.go index 66af1ec..5de4c01 100644 --- a/cmd/agent/main_test.go +++ b/cmd/agent/main_test.go @@ -2,6 +2,10 @@ package main import ( "testing" + "time" + + "github.com/vshulcz/Golectra/internal/application/agent" + "github.com/vshulcz/Golectra/internal/infra/config" ) func TestBuildVariablesExist(t *testing.T) { @@ -9,3 +13,20 @@ func TestBuildVariablesExist(t *testing.T) { _ = buildDate _ = buildCommit } + +func TestMapAgentConfig(t *testing.T) { + cfg := config.AgentConfig{ + PollInterval: 2 * time.Second, + ReportInterval: 5 * time.Second, + RateLimit: 3, + } + got := mapAgentConfig(cfg) + want := agent.Config{ + PollInterval: 2 * time.Second, + ReportInterval: 5 * time.Second, + RateLimit: 3, + } + if got != want { + t.Fatalf("mapAgentConfig=%+v want %+v", got, want) + } +} diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go index 66af1ec..ad6263a 100644 --- a/cmd/server/main_test.go +++ b/cmd/server/main_test.go @@ -1,11 +1,78 @@ package main import ( + "context" + "net/http" "testing" + "time" + + "github.com/vshulcz/Golectra/internal/domain" + "github.com/vshulcz/Golectra/internal/infra/config" + "github.com/vshulcz/Golectra/internal/ports" + "go.uber.org/zap" ) +type fakePersister struct { + calls int + err error +} + +func (f *fakePersister) Save(_ context.Context, _ domain.Snapshot) error { + f.calls++ + return f.err +} + +func (f *fakePersister) Restore(context.Context, ports.MetricsRepo) error { + return nil +} + func TestBuildVariablesExist(t *testing.T) { _ = buildVersion _ = buildDate _ = buildCommit } + +func TestBuildSnapshotHook(t *testing.T) { + logger := zap.NewNop() + p := &fakePersister{} + hook := buildSnapshotHook(p, logger) + hook(context.Background(), domain.Snapshot{}) + if p.calls != 1 { + t.Fatalf("calls=%d want 1", p.calls) + } +} + +func TestLoadDecrypter_Empty(t *testing.T) { + cfg := config.ServerConfig{} + if dec, err := loadDecrypter(cfg); err != nil || dec != nil { + t.Fatalf("loadDecrypter=%v err=%v", dec, err) + } +} + +func TestNewHTTPServer(t *testing.T) { + cfg := config.ServerConfig{Address: "127.0.0.1:0"} + srv := newHTTPServer(cfg, http.NewServeMux()) + if srv.Addr != cfg.Address { + t.Fatalf("Addr=%q want %q", srv.Addr, cfg.Address) + } + if srv.ReadTimeout == 0 || srv.WriteTimeout == 0 { + t.Fatal("timeouts must be set") + } +} + +func TestServe_StartAndClose(t *testing.T) { + cfg := config.ServerConfig{Address: "127.0.0.1:0"} + srv := newHTTPServer(cfg, http.NewServeMux()) + + errCh := make(chan error, 1) + go func() { + errCh <- serve(srv) + }() + + time.Sleep(10 * time.Millisecond) + _ = srv.Close() + + if err := <-errCh; err != nil { + t.Fatalf("serve error: %v", err) + } +} diff --git a/internal/infra/audit/fanout.go b/internal/infra/audit/fanout.go index bc48dae..bec9ac3 100644 --- a/internal/infra/audit/fanout.go +++ b/internal/infra/audit/fanout.go @@ -1,3 +1,4 @@ +// Package audit provides infrastructure helpers for audit fan-out. package audit import ( diff --git a/internal/infra/config/helpers_test.go b/internal/infra/config/helpers_test.go index 80b11bd..d05cc7a 100644 --- a/internal/infra/config/helpers_test.go +++ b/internal/infra/config/helpers_test.go @@ -78,6 +78,13 @@ func TestHelpers_FromEnvOrFlagBool(t *testing.T) { def: true, expect: true, }, + { + name: "env explicit false wins", + env: "false", + flag: true, + def: true, + expect: false, + }, { name: "no env -> flag true used", env: "", @@ -221,6 +228,15 @@ func TestHelpers_FromEnvOrFlagDuration(t *testing.T) { expectDur: 300 * time.Second, expectCustom: true, }, + { + name: "env zero -> zero duration", + env: "0", + flagSeconds: 10, + sentinel: 0, + defSeconds: 300, + expectDur: 0 * time.Second, + expectCustom: true, + }, { name: "env with spaces numeric -> trimmed and used", env: " 7 ", diff --git a/internal/infra/crypto/rsaenvelope/envelope.go b/internal/infra/crypto/rsaenvelope/envelope.go index 88d2407..fdebf85 100644 --- a/internal/infra/crypto/rsaenvelope/envelope.go +++ b/internal/infra/crypto/rsaenvelope/envelope.go @@ -1,3 +1,4 @@ +// Package rsaenvelope provides RSA-OAEP + AES-GCM envelope encryption helpers. package rsaenvelope import ( diff --git a/internal/infra/crypto/rsaenvelope/envelope_test.go b/internal/infra/crypto/rsaenvelope/envelope_test.go index 985a8b6..807ad5f 100644 --- a/internal/infra/crypto/rsaenvelope/envelope_test.go +++ b/internal/infra/crypto/rsaenvelope/envelope_test.go @@ -78,3 +78,15 @@ func TestLoadKeys(t *testing.T) { t.Fatal("private key mismatch") } } + +func TestReadKeyFile_Errors(t *testing.T) { + if _, err := readKeyFile(""); err == nil { + t.Fatal("expected error for empty path") + } + if _, err := readKeyFile("../"); err == nil { + t.Fatal("expected error for invalid filename") + } + if _, err := readKeyFile(filepath.Join(t.TempDir(), "missing.pem")); err == nil { + t.Fatal("expected error for missing file") + } +} diff --git a/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go b/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go index d420de5..614abaa 100644 --- a/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go +++ b/internal/infra/http/ginserver/middlewares/crypto_decrypt_test.go @@ -52,3 +52,61 @@ func TestDecryptPayload(t *testing.T) { t.Fatalf("body=%q want %q", rec.Body.Bytes(), plain) } } + +func TestDecryptPayload_NoHeaderPassthrough(t *testing.T) { + gin.SetMode(gin.TestMode) + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + decrypter := rsaenvelope.NewDecrypter(priv) + + r := gin.New() + r.Use(DecryptPayload(decrypter)) + r.POST("/plain", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + c.Data(http.StatusOK, "text/plain", body) + }) + + req := httptest.NewRequest(http.MethodPost, "/plain", bytes.NewReader([]byte("plain"))) + rec := httptest.NewRecorder() + + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d want %d", rec.Code, http.StatusOK) + } + if got := rec.Body.String(); got != "plain" { + t.Fatalf("body=%q want %q", got, "plain") + } +} + +func TestDecryptPayload_InvalidCiphertext(t *testing.T) { + gin.SetMode(gin.TestMode) + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + decrypter := rsaenvelope.NewDecrypter(priv) + + r := gin.New() + r.Use(DecryptPayload(decrypter)) + r.POST("/decrypt", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodPost, "/decrypt", bytes.NewReader([]byte("bad"))) + req.Header.Set(decrypter.HeaderKey(), decrypter.HeaderValue()) + rec := httptest.NewRecorder() + + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status=%d want %d", rec.Code, http.StatusBadRequest) + } +} diff --git a/internal/infra/retry/retry.go b/internal/infra/retry/retry.go index 75b5dfd..89fc9f9 100644 --- a/internal/infra/retry/retry.go +++ b/internal/infra/retry/retry.go @@ -1,3 +1,4 @@ +// Package retry provides a simple retry helper with backoff. package retry import (