diff --git a/cmd/root.go b/cmd/root.go index 5ea3495..4ac4fb7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,14 +4,14 @@ import ( "fmt" "os" "strconv" - "time" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/anchore/ecs-inventory/internal/config" "github.com/anchore/ecs-inventory/pkg" - "github.com/anchore/ecs-inventory/pkg/reporter" + "github.com/anchore/ecs-inventory/pkg/healthreporter" + "github.com/anchore/ecs-inventory/pkg/integration" ) var ErrMissingDefaultConfigValue = fmt.Errorf("missing default config value") @@ -42,30 +42,34 @@ var rootCmd = &cobra.Command{ os.Exit(1) } - // Validate anchore connection & credentials, using a dummy report to post but this will be - // replaced in the future with a health check endpoint for the agents - if appConfig.AnchoreDetails.IsValid() { - dummyReport := reporter.Report{ - ClusterARN: "validating-creds", - Timestamp: time.Now().UTC().Format(time.RFC3339), - } - err := reporter.Post(dummyReport, appConfig.AnchoreDetails) - if err != nil { - log.Error("Failed to validate connection to Anchore", err) - } else { - log.Info("Successfully validated connection to Anchore") - } - } else { + if !appConfig.AnchoreDetails.IsValid() { log.Warn("Anchore details not specified, will not report inventory") + pkg.PeriodicallyGetInventoryReportSimple( + appConfig.PollingIntervalSeconds, + appConfig.AnchoreDetails, + appConfig.Region, + appConfig.Quiet, + appConfig.DryRun, + ) + return + } + + // Channel-coordinated startup with health reporting + neverDone := make(chan bool, 1) + + ch := integration.GetChannels() + gatedReportInfo := healthreporter.GetGatedReportInfo() + + go healthreporter.PeriodicallySendHealthReport(appConfig, ch, gatedReportInfo) + go pkg.PeriodicallyGetInventoryReport(appConfig, ch, gatedReportInfo) + + _, err := integration.PerformRegistration(appConfig, ch) + if err != nil { + log.Error("Failed to perform registration with Anchore Enterprise", err) + os.Exit(1) } - pkg.PeriodicallyGetInventoryReport( - appConfig.PollingIntervalSeconds, - appConfig.AnchoreDetails, - appConfig.Region, - appConfig.Quiet, - appConfig.DryRun, - ) + <-neverDone }, } diff --git a/docker-compose/anchore-ecs-inventory.yaml b/docker-compose/anchore-ecs-inventory.yaml index fa0fd13..0436363 100644 --- a/docker-compose/anchore-ecs-inventory.yaml +++ b/docker-compose/anchore-ecs-inventory.yaml @@ -22,10 +22,22 @@ anchore: insecure: true timeout-seconds: 10 +anchore-registration: + # The id to register the agent as with Enterprise, so Enterprise can map the agent to its integration uuid. + # If left unspecified, the agent will generate a UUID to use as registration-id. + registration-id: + # The name that the agent should have. If left unspecified it will be empty. + integration-name: + # A short description for the agent + integration-description: + # the aws region region: $ANCHORE_ECS_INVENTORY_REGION # frequency of which to poll the region polling-interval-seconds: 300 +# frequency of which to send health reports to Anchore Enterprise (seconds, range: 30-600) +health-report-interval-seconds: 60 + quiet: false \ No newline at end of file diff --git a/go.mod b/go.mod index d6b55ca..61de896 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,13 @@ go 1.24.0 require ( github.com/adrg/xdg v0.5.3 - github.com/aws/aws-sdk-go v1.55.7 + github.com/aws/aws-sdk-go-v2 v1.41.5 + github.com/aws/aws-sdk-go-v2/config v1.32.15 + github.com/aws/aws-sdk-go-v2/service/ecs v1.78.0 + github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 + github.com/google/uuid v1.6.0 github.com/h2non/gock v1.2.0 + github.com/hashicorp/go-version v1.9.0 github.com/mitchellh/go-homedir v1.1.0 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 @@ -15,12 +20,21 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index b2bbcb2..ff90010 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,34 @@ github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= -github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= -github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= +github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/config v1.32.15 h1:i7rHbaySnBXGvCkDndaBU8f3EAlRVgViwNfkwFUrXgE= +github.com/aws/aws-sdk-go-v2/config v1.32.15/go.mod h1:yLJzL0IkI9+4BwjPSOueyHzppJj3t0dhK5tbmmcFk5Q= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= +github.com/aws/aws-sdk-go-v2/service/ecs v1.78.0 h1:P8s4jrrYr9CUPhoYXS0dI4Zi5oKXa6DWHUkeJ9m/gDQ= +github.com/aws/aws-sdk-go-v2/service/ecs v1.78.0/go.mod h1:QkWmubOYmjj3cHn7A4CoUU7BKJhVeo39Gp6NH7IyhZw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -14,16 +39,16 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/h2non/gock v1.2.0 h1:K6ol8rfrRkUOefooBC8elXoaNGYkpp7y2qcxGG6BzUE= github.com/h2non/gock v1.2.0/go.mod h1:tNhoxHYW2W42cYkYb1WqzdbYIieALC99kpYr7rH/BQk= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/hashicorp/go-version v1.9.0 h1:CeOIz6k+LoN3qX9Z0tyQrPtiB1DFYRPfCIBtaXPSCnA= +github.com/hashicorp/go-version v1.9.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -38,7 +63,6 @@ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWb github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= @@ -58,7 +82,6 @@ github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= @@ -76,7 +99,6 @@ golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/anchore/anchoreclient.go b/internal/anchore/anchoreclient.go new file mode 100644 index 0000000..6d9f0e7 --- /dev/null +++ b/internal/anchore/anchoreclient.go @@ -0,0 +1,312 @@ +package anchore + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "syscall" + "time" + + "github.com/h2non/gock" + + "github.com/anchore/ecs-inventory/internal/logger" + "github.com/anchore/ecs-inventory/internal/tracker" + "github.com/anchore/ecs-inventory/pkg/connection" +) + +type Version struct { + API struct { + Version string `json:"version"` + } `json:"api"` + DB struct { + SchemaVersion string `json:"schema_version"` + } `json:"db"` + Service struct { + Version string `json:"version"` + } `json:"service"` +} + +type ControllerErrorDetails struct { + Type string `json:"type"` + Title string `json:"title"` + Detail string `json:"detail"` + Status int `json:"status"` +} + +type APIErrorDetails struct { + Message string `json:"message"` + Detail map[string]interface{} `json:"detail"` + HTTPCode int `json:"httpcode"` +} + +type APIClientError struct { + HTTPStatusCode int + Message string + Path string + Method string + Body *[]byte + APIErrorDetails *APIErrorDetails + ControllerErrorDetails *ControllerErrorDetails +} + +func (e *APIClientError) Error() string { + return fmt.Sprintf("API errorMsg(%d): %s Path: %q %v %v", e.HTTPStatusCode, e.Message, e.Path, + e.APIErrorDetails, e.ControllerErrorDetails) +} + +func GetVersion(anchoreDetails connection.AnchoreInfo) (*Version, error) { + operation := "version get" + defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Sent %s request to Anchore", operation)) + + logger.Log.Debug("Determining Anchore service version") + + client := getClient(anchoreDetails) + + response, err := client.Get(anchoreDetails.URL + "/version") + if err != nil { + return nil, err + } + defer response.Body.Close() + + err = checkHTTPErrors(response, operation) + if err != nil { + return nil, err + } + + responseBody, err := getBody(response, operation) + if err != nil { + return nil, err + } + + ver := Version{} + err = json.Unmarshal(*responseBody, &ver) + if err != nil { + return nil, fmt.Errorf("failed to parse API version: %w", err) + } + return &ver, nil +} + +func Post(requestBody []byte, id string, path string, anchoreDetails connection.AnchoreInfo, operation string) (*[]byte, error) { + defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Sent %s request to Anchore", operation)) + + logger.Log.Debugf("Performing %s to Anchore using endpoint: %s", operation, strings.Replace(path, "{{id}}", id, 1)) + + client := getClient(anchoreDetails) + + anchoreURL, err := getURL(anchoreDetails, path, id) + if err != nil { + return nil, err + } + + request, err := getPostRequest(anchoreDetails, anchoreURL, requestBody, operation) + if err != nil { + return nil, err + } + + return doPost(client, request, operation) +} + +func getClient(anchoreDetails connection.AnchoreInfo) *http.Client { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: anchoreDetails.HTTP.Insecure}, + } // #nosec G402 + + client := &http.Client{ + Transport: tr, + Timeout: time.Duration(anchoreDetails.HTTP.TimeoutSeconds) * time.Second, + } + gock.InterceptClient(client) // Required to use gock for testing custom client + + return client +} + +func getURL(anchoreDetails connection.AnchoreInfo, path string, id string) (string, error) { + anchoreURL, err := url.Parse(anchoreDetails.URL) + if err != nil { + return "", fmt.Errorf("failed to build path (%s) url: %w", path, err) + } + + anchoreURL.Path += strings.Replace(path, "{{id}}", id, 1) + return anchoreURL.String(), nil +} + +func getPostRequest(anchoreDetails connection.AnchoreInfo, endpointURL string, reqBody []byte, operation string) (*http.Request, error) { + request, err := http.NewRequest("POST", endpointURL, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to prepare %s request to Anchore: %w", operation, err) + } + + request.SetBasicAuth(anchoreDetails.User, anchoreDetails.Password) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("x-anchore-account", anchoreDetails.Account) + return request, nil +} + +func doPost(client *http.Client, request *http.Request, operation string) (*[]byte, error) { + response, err := client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + err = checkHTTPErrors(response, operation) + if err != nil { + return nil, err + } + + responseBody, err := getBody(response, operation) + return responseBody, err +} + +func checkHTTPErrors(response *http.Response, operation string) error { + switch { + case response.StatusCode >= 400 && response.StatusCode <= 599: + msg := fmt.Sprintf("%s response from Anchore (during %s)", response.Status, operation) + logger.Log.Errorf(msg) + + respBody, _ := getBody(response, operation) + if respBody == nil { + return &APIClientError{Message: msg, Path: response.Request.URL.Path, Method: response.Request.Method, + Body: nil, HTTPStatusCode: response.StatusCode} + } + + // Depending on where an error is discovered during request processing on the server, the + // error information in the response will be either an APIErrorDetails or a ControllerErrorDetails + apiError := APIErrorDetails{} + err := json.Unmarshal(*respBody, &apiError) + if err == nil { + return &APIClientError{Message: msg, Path: response.Request.URL.Path, Method: response.Request.Method, + Body: nil, HTTPStatusCode: response.StatusCode, APIErrorDetails: &apiError} + } + + controllerError := ControllerErrorDetails{} + err = json.Unmarshal(*respBody, &controllerError) + if err == nil { + return &APIClientError{Message: msg, Path: response.Request.URL.Path, Method: response.Request.Method, + Body: nil, HTTPStatusCode: response.StatusCode, ControllerErrorDetails: &controllerError} + } + + return &APIClientError{Message: msg, Path: response.Request.URL.Path, Method: response.Request.Method, + Body: nil, HTTPStatusCode: response.StatusCode} + case response.StatusCode < 200 || response.StatusCode > 299: + msg := fmt.Sprintf("failed to perform %s to Anchore: %+v", operation, response) + logger.Log.Debugf(msg) + return &APIClientError{Message: msg, Path: response.Request.URL.Path, Method: response.Request.Method, + Body: nil, HTTPStatusCode: response.StatusCode} + } + return nil +} + +func getBody(response *http.Response, operation string) (*[]byte, error) { + responseBody, err := io.ReadAll(response.Body) + if err != nil { + errMsg := fmt.Sprintf("failed to read %s response body from Anchore:", operation) + logger.Log.Debugf("%s %v", operation, errMsg) + return nil, fmt.Errorf("%s %w", errMsg, err) + } + + // Check we received a valid JSON response from Anchore, this will help catch + // any redirect responses where it returns HTML login pages e.g. Enterprise + // running behind cloudflare where a login page is returned with the status 200 + if len(responseBody) > 0 && !json.Valid(responseBody) { + logger.Log.Debugf("Anchore %s response body: %s", operation, string(responseBody)) + return nil, fmt.Errorf("%s response from Anchore is not valid json: %+v", operation, response) + } + return &responseBody, nil +} + +func ServerIsOffline(err error) bool { + if os.IsTimeout(err) { + return true + } + + offlineErrors := []error{ + syscall.ENETDOWN, + syscall.ENETUNREACH, + syscall.ENETRESET, + syscall.ECONNABORTED, + syscall.ECONNRESET, + syscall.ETIMEDOUT, + syscall.ECONNREFUSED, + syscall.EHOSTDOWN, + syscall.EHOSTUNREACH, + } + + for _, e := range offlineErrors { + if errors.Is(err, e) { + return true + } + } + + var dnsError *net.DNSError + if errors.As(err, &dnsError) { + return true + } + + var apiClientError *APIClientError + if errors.As(err, &apiClientError) { + if apiClientError.HTTPStatusCode == http.StatusBadGateway || + apiClientError.HTTPStatusCode == http.StatusServiceUnavailable || + apiClientError.HTTPStatusCode == http.StatusGatewayTimeout { + return true + } + } + + return false +} + +func ServerLacksAgentHealthAPISupport(err error) bool { + var apiClientError *APIClientError + if errors.As(err, &apiClientError) { + if apiClientError.ControllerErrorDetails == nil { + return false + } + + if apiClientError.HTTPStatusCode == http.StatusNotFound && + strings.Contains(apiClientError.ControllerErrorDetails.Detail, "The requested URL was not found") { + return true + } + + if apiClientError.HTTPStatusCode == http.StatusMethodNotAllowed && + apiClientError.ControllerErrorDetails.Detail == "Method Not Allowed" { + return true + } + } + + return false +} + +func UserLacksAPIPrivileges(err error) bool { + var apiClientError *APIClientError + + if errors.As(err, &apiClientError) { + if apiClientError.APIErrorDetails == nil { + return false + } + + if apiClientError.HTTPStatusCode == http.StatusForbidden && + strings.Contains(apiClientError.APIErrorDetails.Message, "Not authorized. Requires permissions") { + return true + } + } + return false +} + +func IncorrectCredentials(err error) bool { + // This covers user that does not exist or incorrect password for user + var apiClientError *APIClientError + + if errors.As(err, &apiClientError) && apiClientError.HTTPStatusCode == http.StatusUnauthorized { + return true + } + + return false +} diff --git a/internal/config/config.go b/internal/config/config.go index 92cb8e9..ff5a4ac 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -33,13 +33,21 @@ type CliOnlyOptions struct { } type AppConfig struct { - Log Logging `mapstructure:"log"` - CliOptions CliOnlyOptions - PollingIntervalSeconds int `mapstructure:"polling-interval-seconds"` - AnchoreDetails connection.AnchoreInfo `mapstructure:"anchore"` - Region string `mapstructure:"region"` - Quiet bool `mapstructure:"quiet"` // if true do not log the inventory report to stdout - DryRun bool `mapstructure:"dry-run"` // if true do not report inventory to Anchore + Log Logging `mapstructure:"log"` + CliOptions CliOnlyOptions + PollingIntervalSeconds int `mapstructure:"polling-interval-seconds"` + HealthReportIntervalSeconds int `mapstructure:"health-report-interval-seconds"` + AnchoreDetails connection.AnchoreInfo `mapstructure:"anchore"` + Registration RegistrationOptions `mapstructure:"anchore-registration"` + Region string `mapstructure:"region"` + Quiet bool `mapstructure:"quiet"` // if true do not log the inventory report to stdout + DryRun bool `mapstructure:"dry-run"` // if true do not report inventory to Anchore +} + +type RegistrationOptions struct { + RegistrationID string `mapstructure:"registration-id"` + IntegrationName string `mapstructure:"integration-name"` + IntegrationDescription string `mapstructure:"integration-description"` } // Logging Configuration @@ -60,10 +68,11 @@ var DefaultConfigValues = AppConfig{ TimeoutSeconds: 60, }, }, - Region: "", - PollingIntervalSeconds: 300, - Quiet: false, - DryRun: false, + Region: "", + PollingIntervalSeconds: 300, + HealthReportIntervalSeconds: 60, + Quiet: false, + DryRun: false, } var ErrConfigFileNotFound = fmt.Errorf("application config file not found") @@ -71,9 +80,20 @@ var ErrConfigFileNotFound = fmt.Errorf("application config file not found") func setDefaultValues(v *viper.Viper) { v.SetDefault("log.level", DefaultConfigValues.Log.Level) v.SetDefault("log.file", DefaultConfigValues.Log.FileLocation) + v.SetDefault("anchore.url", DefaultConfigValues.AnchoreDetails.URL) + v.SetDefault("anchore.user", DefaultConfigValues.AnchoreDetails.User) + v.SetDefault("anchore.password", DefaultConfigValues.AnchoreDetails.Password) v.SetDefault("anchore.account", DefaultConfigValues.AnchoreDetails.Account) v.SetDefault("anchore.http.insecure", DefaultConfigValues.AnchoreDetails.HTTP.Insecure) v.SetDefault("anchore.http.timeout-seconds", DefaultConfigValues.AnchoreDetails.HTTP.TimeoutSeconds) + v.SetDefault("region", DefaultConfigValues.Region) + v.SetDefault("polling-interval-seconds", DefaultConfigValues.PollingIntervalSeconds) + v.SetDefault("health-report-interval-seconds", DefaultConfigValues.HealthReportIntervalSeconds) + v.SetDefault("quiet", DefaultConfigValues.Quiet) + v.SetDefault("dry-run", DefaultConfigValues.DryRun) + v.SetDefault("anchore-registration.registration-id", DefaultConfigValues.Registration.RegistrationID) + v.SetDefault("anchore-registration.integration-name", DefaultConfigValues.Registration.IntegrationName) + v.SetDefault("anchore-registration.integration-description", DefaultConfigValues.Registration.IntegrationDescription) } // Load the Application Configuration from the Viper specifications @@ -133,6 +153,10 @@ func (cfg *AppConfig) Build() error { } } + if cfg.HealthReportIntervalSeconds < 30 || cfg.HealthReportIntervalSeconds > 600 { + return fmt.Errorf("health-report-interval-seconds must be between 30 and 600") + } + return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 13199eb..9112ddc 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -37,9 +37,10 @@ func TestLoadConfigFromFileCliConfigPath(t *testing.T) { TimeoutSeconds: 10, }, }, - Region: "us-east-1", - PollingIntervalSeconds: 60, - Quiet: true, + Region: "us-east-1", + PollingIntervalSeconds: 60, + HealthReportIntervalSeconds: 60, + Quiet: true, } assert.EqualValues(t, expectedCfg, appCfg) @@ -89,6 +90,7 @@ clioptions: configpath: testdata/config.yaml verbosity: 0 pollingintervalseconds: 300 +healthreportintervalseconds: 0 anchoredetails: url: http://localhost:8228/v1 user: admin @@ -97,6 +99,10 @@ anchoredetails: http: insecure: false timeoutseconds: 0 +registration: + registrationid: "" + integrationname: "" + integrationdescription: "" region: "" quiet: false dryrun: false @@ -124,6 +130,8 @@ func TestDefaultValuesSuppliedForEmptyConfig(t *testing.T) { Log: Logging{ Level: "info", }, + PollingIntervalSeconds: 300, + HealthReportIntervalSeconds: 60, AnchoreDetails: connection.AnchoreInfo{ Account: "admin", Password: "", diff --git a/internal/logger/logger.go b/internal/logger/logger.go index d613c69..e819885 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -17,12 +17,16 @@ func (log NoOpLogger) Debugf(string, ...interface{}) {} func (log NoOpLogger) Info(string, ...interface{}) {} +func (log NoOpLogger) Infof(string, ...interface{}) {} + func (log NoOpLogger) Warn(string, ...interface{}) {} func (log NoOpLogger) Warnf(string, ...interface{}) {} func (log NoOpLogger) Error(string, error, ...interface{}) {} +func (log NoOpLogger) Errorf(string, ...interface{}) {} + type ZapLogger struct { zap *zap.SugaredLogger } @@ -47,12 +51,20 @@ func (log ZapLogger) Warnf(msg string, args ...interface{}) { log.zap.Warnf(msg, args...) } +func (log ZapLogger) Infof(msg string, args ...interface{}) { + log.zap.Infof(msg, args...) +} + func (log ZapLogger) Error(msg string, err error, args ...interface{}) { args = append(args, "err", err) log.zap.Errorw(msg, args...) } +func (log ZapLogger) Errorf(msg string, args ...interface{}) { + log.zap.Errorf(msg, args...) +} + type LogConfig struct { Level string FileLocation string diff --git a/internal/time/time.go b/internal/time/time.go new file mode 100644 index 0000000..58b984a --- /dev/null +++ b/internal/time/time.go @@ -0,0 +1,42 @@ +package time + +import ( + "encoding/json" + "errors" + "fmt" + "time" +) + +// time with json marshalling/unmarshalling support +const nanoSeconds = 1000000000 + +type Datetime struct { + time.Time +} + +type Duration struct { + time.Duration +} + +func (d Datetime) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("\"%s\"", d.Format(time.RFC3339))), nil +} + +func (d *Duration) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("%f", d.Seconds())), nil +} + +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + // enterprise sends durations in seconds + d.Duration = time.Duration(value * nanoSeconds) + return nil + default: + return errors.New("invalid duration") + } +} diff --git a/pkg/healthreporter/healthreporter.go b/pkg/healthreporter/healthreporter.go new file mode 100644 index 0000000..db32f66 --- /dev/null +++ b/pkg/healthreporter/healthreporter.go @@ -0,0 +1,181 @@ +package healthreporter + +import ( + "encoding/json" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/anchore/ecs-inventory/internal/anchore" + "github.com/anchore/ecs-inventory/internal/config" + "github.com/anchore/ecs-inventory/internal/logger" + jstime "github.com/anchore/ecs-inventory/internal/time" + intg "github.com/anchore/ecs-inventory/pkg/integration" +) + +const healthProtocolVersion = 1 +const healthDataVersion = 1 +const healthDataType = "ecs_inventory_agent" +const HealthReportAPIPathV2 = "v2/system/integrations/{{id}}/health-report" + +type HealthReport struct { + UUID string `json:"uuid,omitempty"` + ProtocolVersion int `json:"protocol_version,omitempty"` + Timestamp jstime.Datetime `json:"timestamp,omitempty"` + Uptime *jstime.Duration `json:"uptime,omitempty"` + HealthReportInterval int `json:"health_report_interval,omitempty"` + HealthData HealthData `json:"health_data,omitempty"` +} + +type HealthData struct { + Type string `json:"type,omitempty"` + Version int `json:"version,omitempty"` + Errors HealthReportErrors `json:"errors,omitempty"` + // ECS-specific: latest inventory reports per account/region + AccountECSInventoryReports AccountECSInventoryReports `json:"account_ecs_inventory_reports,omitempty"` +} + +type HealthReportErrors []string + +// AccountECSInventoryReports holds per account information about latest inventory reports from the same batch set +type AccountECSInventoryReports map[string]InventoryReportInfo + +type InventoryReportInfo struct { + ReportTimestamp string `json:"report_timestamp"` + Account string `json:"account_name"` + Region string `json:"region"` + SentAsUser string `json:"sent_as_user"` + BatchSize int `json:"batch_size"` + LastSuccessfulIndex int `json:"last_successful_index"` + HasErrors bool `json:"has_errors"` + Batches []BatchInfo `json:"batches"` +} + +type BatchInfo struct { + BatchIndex int `json:"batch_index,omitempty"` + SendTimestamp jstime.Datetime `json:"send_timestamp,omitempty"` + Error string `json:"error,omitempty"` +} + +// GatedReportInfo The go routine that generates the inventory report must inform the go routine +// that sends health reports about the *latest* sent inventory reports. +// We use a map (keyed by account) to store information about the latest sent inventory +// reports. This map is shared by the go routine that generates inventory reports and the go +// routine that sends health reports. Access to the map is coordinated by a mutex. +type GatedReportInfo struct { + AccessGate sync.RWMutex + AccountInventoryReports AccountECSInventoryReports +} + +type _NewUUID func() uuid.UUID +type _Now func() time.Time + +func GetGatedReportInfo() *GatedReportInfo { + return &GatedReportInfo{ + AccountInventoryReports: make(AccountECSInventoryReports), + } +} + +func PeriodicallySendHealthReport(cfg *config.AppConfig, ch intg.Channels, gatedReportInfo *GatedReportInfo) { + // Wait for registration with Enterprise to be completed + integration := <-ch.IntegrationObj + logger.Log.Info("Health reporting started") + + ticker := time.NewTicker(time.Duration(cfg.HealthReportIntervalSeconds) * time.Second) + + for { + logger.Log.Infof("Waiting %d seconds to send health report...", cfg.HealthReportIntervalSeconds) + + _, _ = sendHealthReport(cfg, integration, gatedReportInfo, uuid.New, time.Now) + <-ticker.C + } +} + +func sendHealthReport(cfg *config.AppConfig, integration *intg.Integration, gatedReportInfo *GatedReportInfo, newUUID _NewUUID, _now _Now) (*HealthReport, error) { + healthReportID := newUUID().String() + lastReports := GetAccountReportInfoNoBlocking(gatedReportInfo, cfg, _now) + + now := _now().UTC() + integration.Uptime = &jstime.Duration{Duration: now.Sub(integration.StartedAt.Time)} + healthReport := HealthReport{ + UUID: healthReportID, + ProtocolVersion: healthProtocolVersion, + Timestamp: jstime.Datetime{Time: now}, + Uptime: integration.Uptime, + HealthData: HealthData{ + Type: healthDataType, + Version: healthDataVersion, + Errors: make(HealthReportErrors, 0), + AccountECSInventoryReports: lastReports, + }, + HealthReportInterval: cfg.HealthReportIntervalSeconds, + } + + logger.Log.Infof("Sending health report (uuid:%s) covering %d accounts", healthReport.UUID, len(healthReport.HealthData.AccountECSInventoryReports)) + requestBody, err := json.Marshal(healthReport) + if err != nil { + logger.Log.Errorf("failed to serialize health report as JSON: %v", err) + return nil, err + } + _, err = anchore.Post(requestBody, integration.UUID, HealthReportAPIPathV2, cfg.AnchoreDetails, "health report") + if err != nil { + logger.Log.Errorf("Failed to send health report to Anchore: %v", err) + return nil, err + } + return &healthReport, nil +} + +func GetAccountReportInfoNoBlocking(gatedReportInfo *GatedReportInfo, cfg *config.AppConfig, _now _Now) AccountECSInventoryReports { + locked := gatedReportInfo.AccessGate.TryLock() + + if locked { + defer gatedReportInfo.AccessGate.Unlock() + + logger.Log.Debugf("Removing inventory report info for accounts that are no longer active") + accountsToRemove := make(map[string]bool) + now := _now().UTC() + inactiveAge := 2 * float64(cfg.PollingIntervalSeconds) + + for account, reportInfo := range gatedReportInfo.AccountInventoryReports { + for _, batchInfo := range reportInfo.Batches { + logger.Log.Debugf("Last inv.report (time:%s, account:%s, batch:%d/%d, sent:%s error:'%s')", + reportInfo.ReportTimestamp, account, batchInfo.BatchIndex, reportInfo.BatchSize, + batchInfo.SendTimestamp, batchInfo.Error) + reportTime, err := time.Parse(time.RFC3339, reportInfo.ReportTimestamp) + if err != nil { + logger.Log.Errorf("failed to parse report_timestamp: %v", err) + continue + } + if now.Sub(reportTime).Seconds() > inactiveAge { + accountsToRemove[account] = true + } + } + } + + for accountToRemove := range accountsToRemove { + logger.Log.Debugf("Accounts no longer considered active: %s", accountToRemove) + delete(gatedReportInfo.AccountInventoryReports, accountToRemove) + } + + return gatedReportInfo.AccountInventoryReports + } + logger.Log.Debugf("Unable to obtain mutex lock to get account inventory report information. Continuing.") + return AccountECSInventoryReports{} +} + +func SetReportInfoNoBlocking(accountName string, count int, reportInfo InventoryReportInfo, gatedReportInfo *GatedReportInfo) { + logger.Log.Debugf("Setting report (%s) for account name '%s': %d/%d %s %s", reportInfo.ReportTimestamp, accountName, + reportInfo.Batches[count].BatchIndex, reportInfo.BatchSize, reportInfo.Batches[count].SendTimestamp, + reportInfo.Batches[count].Error) + locked := gatedReportInfo.AccessGate.TryLock() + if locked { + defer gatedReportInfo.AccessGate.Unlock() + gatedReportInfo.AccountInventoryReports[accountName] = reportInfo + } else { + // we prioritize no blocking over actually bookkeeping info for every sent inventory report + logger.Log.Debugf("Unable to obtain mutex lock to include inventory report timestamped %s for %s: %d/%d %s in health report. Continuing.", + reportInfo.ReportTimestamp, accountName, reportInfo.Batches[count].BatchIndex, reportInfo.BatchSize, + reportInfo.Batches[count].SendTimestamp) + } +} diff --git a/pkg/healthreporter/healthreporter_test.go b/pkg/healthreporter/healthreporter_test.go new file mode 100644 index 0000000..2db5c26 --- /dev/null +++ b/pkg/healthreporter/healthreporter_test.go @@ -0,0 +1,282 @@ +package healthreporter + +import ( + "fmt" + "net/http" + "reflect" + "testing" + "time" + + "github.com/anchore/ecs-inventory/internal/anchore" + "github.com/anchore/ecs-inventory/internal/config" + jstime "github.com/anchore/ecs-inventory/internal/time" + "github.com/anchore/ecs-inventory/pkg/connection" + "github.com/anchore/ecs-inventory/pkg/integration" + "github.com/google/uuid" + "github.com/h2non/gock" + "github.com/stretchr/testify/assert" +) + +const mutexLocked = int64(1 << iota) // mutex is locked + +var ( + now = time.Date(2024, 10, 4, 10, 11, 12, 0, time.Local) + timestamps = []time.Time{now.Add(time.Millisecond * 10), now.Add(time.Millisecond * 20), now.Add(time.Millisecond * 30)} + uuids = []uuid.UUID{uuid.New(), uuid.New()} + + reportInfo = InventoryReportInfo{ + ReportTimestamp: now.UTC().Format(time.RFC3339), + Account: "testAccount", + Region: "us-east-1", + SentAsUser: "testAccountUser", + BatchSize: 1, + LastSuccessfulIndex: 1, + HasErrors: false, + Batches: []BatchInfo{ + { + BatchIndex: 0, + SendTimestamp: jstime.Datetime{Time: time.Now().UTC()}, + Error: "", + }, + }, + } + reportInfoExpired = InventoryReportInfo{ + ReportTimestamp: now.Add(time.Second * (-3800)).UTC().Format(time.RFC3339), + Account: "testAccount2", + Region: "us-west-2", + SentAsUser: "testAccount2User", + BatchSize: 1, + LastSuccessfulIndex: 1, + HasErrors: false, + Batches: []BatchInfo{ + { + BatchIndex: 0, + SendTimestamp: jstime.Datetime{Time: time.Now().UTC()}, + Error: "", + }, + }, + } +) + +func TestSendHealthReport(t *testing.T) { + defer gock.Off() + + integrationUUID := uuid.New().String() + postURL := fmt.Sprintf("/v2/system/integrations/%s/health-report", integrationUUID) + type want struct { + healthReport *HealthReport + err error + } + tests := []struct { + name string + want want + }{ + { + name: "successful health report", + want: want{ + healthReport: &HealthReport{ + UUID: uuids[0].String(), + ProtocolVersion: 1, + Timestamp: jstime.Datetime{Time: timestamps[1].UTC()}, + Uptime: &jstime.Duration{Duration: time.Millisecond * 20}, + HealthData: HealthData{ + Type: healthDataType, + Version: healthDataVersion, + Errors: make(HealthReportErrors, 0), + AccountECSInventoryReports: AccountECSInventoryReports{ + reportInfo.Account: reportInfo, + }, + }, + HealthReportInterval: 60, + }, + err: nil, + }, + }, + { + name: "failed health report", + want: want{ + healthReport: nil, + err: &anchore.APIClientError{ + HTTPStatusCode: http.StatusUnauthorized, + Message: "401 Unauthorized response from Anchore (during health report)", + Path: postURL, + Method: "POST", + }, + }, + }, + } + for _, tt := range tests { + cfg := config.AppConfig{ + AnchoreDetails: connection.AnchoreInfo{ + URL: "https://ancho.re", + User: "admin", + }, + PollingIntervalSeconds: 30 * 60, + HealthReportIntervalSeconds: 60, + } + integrationInstance := &integration.Integration{ + UUID: integrationUUID, + StartedAt: jstime.Datetime{Time: now.UTC()}, + Uptime: &jstime.Duration{Duration: time.Millisecond * 20}, + HealthReportInterval: 60, + } + gatedReportInfo := GetGatedReportInfo() + gatedReportInfo.AccountInventoryReports["testAccount"] = reportInfo + i := 0 + newUUIDMock := func() uuid.UUID { + _uuid := uuids[i] + i++ + return _uuid + } + j := 0 + nowMock := func() time.Time { + timestamp := timestamps[j] + j++ + return timestamp + } + switch tt.name { + case "successful health report": + gock.New("https://ancho.re"). + Post(postURL). + Reply(200) + case "failed health report": + gock.New("https://ancho.re"). + Post(postURL). + Reply(http.StatusUnauthorized) + } + t.Run(tt.name, func(t *testing.T) { + result, resultErr := sendHealthReport(&cfg, integrationInstance, gatedReportInfo, newUUIDMock, nowMock) + if tt.want.err != nil { + assert.Equal(t, tt.want.err, resultErr) + assert.Nil(t, result) + } else { + assert.NoError(t, resultErr) + assert.Equal(t, tt.want.healthReport, result) + } + }) + } +} + +func TestGetAccountReportInfoNoBlockingWhenObtainingLockRemovesExpired(t *testing.T) { + gatedReportInfo := GatedReportInfo{ + AccountInventoryReports: make(AccountECSInventoryReports, 2), + } + gatedReportInfo.AccountInventoryReports[reportInfo.Account] = reportInfo + gatedReportInfo.AccountInventoryReports[reportInfoExpired.Account] = reportInfoExpired + + cfg := config.AppConfig{ + PollingIntervalSeconds: 30 * 60, + } + + nowMock := func() time.Time { + return now + } + + result := GetAccountReportInfoNoBlocking(&gatedReportInfo, &cfg, nowMock) + assert.Equal(t, len(result), 1) + assert.Contains(t, result, reportInfo.Account) + assert.Equal(t, len(gatedReportInfo.AccountInventoryReports), 1) + assert.Contains(t, gatedReportInfo.AccountInventoryReports, reportInfo.Account) +} + +func TestGetAccountReportInfoBlockingWhenNotObtainingLockExpiredUnaffected(t *testing.T) { + gatedReportInfo := GatedReportInfo{ + AccountInventoryReports: make(AccountECSInventoryReports, 2), + } + gatedReportInfo.AccountInventoryReports[reportInfo.Account] = reportInfo + gatedReportInfo.AccountInventoryReports[reportInfoExpired.Account] = reportInfoExpired + gatedReportInfo.AccessGate.Lock() + + cfg := config.AppConfig{ + PollingIntervalSeconds: 3 * 60, + } + + nowMock := func() time.Time { + return now + } + + result := GetAccountReportInfoNoBlocking(&gatedReportInfo, &cfg, nowMock) + assert.Equal(t, len(result), 0) + assert.Equal(t, len(gatedReportInfo.AccountInventoryReports), 2) + assert.Contains(t, gatedReportInfo.AccountInventoryReports, reportInfo.Account) + assert.Contains(t, gatedReportInfo.AccountInventoryReports, reportInfoExpired.Account) + // check mutex is still locked after operation + wField := reflect.ValueOf(&gatedReportInfo.AccessGate).Elem().FieldByName("w") + if wField.IsValid() { + stateField := wField.FieldByName("state") + if stateField.IsValid() { + if stateField.Kind() == reflect.Int { + assert.Equal(t, stateField.Int()&mutexLocked, mutexLocked) + } else { + t.Errorf("Expected field 'state' to be of type int, but got %s", stateField.Kind()) + } + } else { + // We don't want to error here as we're expecting one empty struct to be returned from go 1.24 onwards + t.Logf("Field 'state' does not exist in the 'w' field") + } + } else { + t.Errorf("Field 'w' does not exist in the AccessGate struct") + } +} + +func TestSetReportInfoNoBlockingSetsWhenObtainingLock(t *testing.T) { + gatedReportInfo := GatedReportInfo{ + AccountInventoryReports: make(AccountECSInventoryReports, 1), + } + accountName := "testAccount" + count := 0 + + SetReportInfoNoBlocking(accountName, count, reportInfo, &gatedReportInfo) + + assert.Equal(t, reportInfo, gatedReportInfo.AccountInventoryReports[accountName]) + // check mutex is unlocked after operation + wField := reflect.ValueOf(&gatedReportInfo.AccessGate).Elem().FieldByName("w") + if wField.IsValid() { + stateField := wField.FieldByName("state") + if stateField.IsValid() { + if stateField.Kind() == reflect.Int { + mutexState := reflect.ValueOf(&gatedReportInfo.AccessGate).Elem().FieldByName("w").FieldByName("state") + assert.Equal(t, mutexState.Int()&mutexLocked, int64(0)) + } else { + t.Errorf("Expected field 'state' to be of type int, but got %s", stateField.Kind()) + } + } else { + // We don't want to error here as we're expecting one empty struct to be returned from go 1.24 onwards + t.Logf("Field 'state' does not exist in the 'w' field") + } + } else { + t.Errorf("Field 'w' does not exist in the AccessGate struct") + } +} + +func TestSetReportInfoNoBlockingSkipsWhenLockAlreadyTaken(t *testing.T) { + gatedReportInfo := GatedReportInfo{ + AccountInventoryReports: make(AccountECSInventoryReports, 1), + } + gatedReportInfo.AccessGate.Lock() + + accountName := "testAccount" + count := 0 + + SetReportInfoNoBlocking(accountName, count, reportInfo, &gatedReportInfo) + + assert.NotContains(t, gatedReportInfo.AccountInventoryReports, accountName) + // check mutex is still locked after operation + wField := reflect.ValueOf(&gatedReportInfo.AccessGate).Elem().FieldByName("w") + if wField.IsValid() { + stateField := wField.FieldByName("state") + if stateField.IsValid() { + if stateField.Kind() == reflect.Int { + mutexState := reflect.ValueOf(&gatedReportInfo.AccessGate).Elem().FieldByName("w").FieldByName("state") + assert.Equal(t, mutexState.Int()&mutexLocked, mutexLocked) + } else { + t.Errorf("Expected field 'state' to be of type int, but got %s", stateField.Kind()) + } + } else { + // We don't want to error here as we're expecting one empty struct to be returned from go 1.24 onwards + t.Logf("Field 'state' does not exist in the 'w' field") + } + } else { + t.Errorf("Field 'w' does not exist in the AccessGate struct") + } +} diff --git a/pkg/integration/integration.go b/pkg/integration/integration.go new file mode 100644 index 0000000..25026c3 --- /dev/null +++ b/pkg/integration/integration.go @@ -0,0 +1,333 @@ +package integration + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/google/uuid" + "github.com/hashicorp/go-version" + + "github.com/anchore/ecs-inventory/internal/anchore" + "github.com/anchore/ecs-inventory/internal/config" + "github.com/anchore/ecs-inventory/internal/logger" + jstime "github.com/anchore/ecs-inventory/internal/time" + ecsVersion "github.com/anchore/ecs-inventory/internal/version" + "github.com/anchore/ecs-inventory/pkg/connection" +) + +var requiredAnchoreVersion, _ = version.NewVersion("5.11") + +var inventoryReportingActive = false + +const Type = "ecs_inventory_agent" +const RegisterAPIPathV2 = "v2/system/integrations/registration" + +type Channels struct { + IntegrationObj chan *Integration + HealthReportingEnabled chan bool + InventoryReportingEnabled chan bool +} + +// HealthStatus reflects the state of the Integration wrt any errors +// encountered when performing its tasks +type HealthStatus struct { + State string `json:"state,omitempty"` // state of the integration HEALTHY or UNHEALTHY + Reason string `json:"reason,omitempty"` + Details any `json:"details,omitempty"` +} + +// LifeCycleStatus reflects the state of the Integration from the perspective of Enterprise +type LifeCycleStatus struct { + State string `json:"state,omitempty"` // lifecycle state REGISTERED, ACTIVE, DEGRADED, DEACTIVATED + Reason string `json:"reason,omitempty"` + Details any `json:"details,omitempty"` + UpdatedAt jstime.Datetime `json:"updated_at,omitempty"` +} + +type Integration struct { + UUID string `json:"uuid,omitempty"` + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Version string `json:"version,omitempty"` + ReportedStatus *HealthStatus `json:"reported_status,omitempty"` + IntegrationStatus *LifeCycleStatus `json:"integration_status,omitempty"` + StartedAt jstime.Datetime `json:"started_at,omitempty"` + LastSeen *jstime.Datetime `json:"last_seen,omitempty"` + Uptime *jstime.Duration `json:"uptime,omitempty"` + Username string `json:"username,omitempty"` + AccountName string `json:"account_name,omitempty"` + Region string `json:"region,omitempty"` + Configuration map[string]interface{} `json:"configuration,omitempty"` + HealthReportInterval int `json:"health_report_interval,omitempty"` + RegistrationID string `json:"registration_id,omitempty"` + RegistrationInstanceID string `json:"registration_instance_id,omitempty"` +} + +type Registration struct { + RegistrationID string `json:"registration_id,omitempty"` + RegistrationInstanceID string `json:"registration_instance_id,omitempty"` + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Version string `json:"version,omitempty"` + StartedAt jstime.Datetime `json:"started_at,omitempty"` + Uptime *jstime.Duration `json:"uptime,omitempty"` + Username string `json:"username,omitempty"` + Region string `json:"region,omitempty"` + Configuration *config.AppConfig `json:"configuration,omitempty"` + HealthReportInterval int `json:"health_report_interval,omitempty"` +} + +type _NewUUID func() uuid.UUID + +type _Now func() time.Time + +func PerformRegistration(appConfig *config.AppConfig, ch Channels) (*Integration, error) { + defer closeChannels(ch) + + _, err := awaitVersion(appConfig.AnchoreDetails, ch, -1, 2*time.Second, 1*time.Hour) + if err != nil { + return nil, err + } + + registrationInfo := getRegistrationInfo(appConfig, uuid.New, time.Now) + + // Register this agent with enterprise + registeredIntegration, err := register(registrationInfo, appConfig.AnchoreDetails, -1, + 2*time.Second, 10*time.Minute, time.Now) + if err != nil { + logger.Log.Errorf("Unable to register agent: %v", err) + return nil, err + } + + enableHealthReporting(ch, registeredIntegration) + + if !inventoryReportingActive { + enableInventoryReporting(ch) + } + + return registeredIntegration, nil +} + +func awaitVersion(anchoreDetails connection.AnchoreInfo, ch Channels, maxRetry int, startBackoff, maxBackoff time.Duration) (*anchore.Version, error) { + attempt := 0 + for { + retry := false + + anchoreVersion, err := anchore.GetVersion(anchoreDetails) + if err == nil { + ver, vErr := version.NewVersion(anchoreVersion.Service.Version) + if vErr != nil { + logger.Log.Infof("Failed to parse received service version: %v. Will try again in %s", vErr, startBackoff) + retry = true + } else { + logger.Log.Infof("Successfully determined service version: %s for Enterprise: %s", + anchoreVersion.Service.Version, anchoreDetails.URL) + if ver.GreaterThanOrEqual(requiredAnchoreVersion) { + logger.Log.Infof("Proceeding with integration registration since Enterprise v%s supports that", anchoreVersion.Service.Version) + return anchoreVersion, nil + } + if !inventoryReportingActive { + logger.Log.Infof("Proceeding without integration registration and health reporting since Enterprise v%s does not support that", + anchoreVersion.Service.Version) + enableInventoryReporting(ch) + } + retry = true + } + } + + attempt++ + if maxRetry >= 0 && attempt > maxRetry { + logger.Log.Infof("Failed to get Enterprise version after %d attempts", attempt) + return nil, fmt.Errorf("failed to get Enterprise version after %d attempts", attempt) + } + + if anchore.ServerIsOffline(err) { + logger.Log.Infof("Anchore is offline. Will try again in %s", startBackoff) + retry = true + } + + if retry { + time.Sleep(startBackoff) + if startBackoff < maxBackoff { + startBackoff = min(startBackoff*2, maxBackoff) + } + continue + } + + logger.Log.Errorf("Failed to get service version for Enterprise: %s, %v", anchoreDetails.URL, err) + return nil, err + } +} + +func GetChannels() Channels { + return Channels{ + IntegrationObj: make(chan *Integration), + HealthReportingEnabled: make(chan bool, 1), // buffered to prevent registration from blocking + InventoryReportingEnabled: make(chan bool), + } +} + +func closeChannels(ch Channels) { + close(ch.IntegrationObj) + close(ch.HealthReportingEnabled) + close(ch.InventoryReportingEnabled) +} + +func enableHealthReporting(ch Channels, integration *Integration) { + logger.Log.Info("Activating health reporting") + // signal health reporting to start by providing it with the integration + ch.IntegrationObj <- integration + // signal inventory reporting to populate health report info when generating inventory reports + ch.HealthReportingEnabled <- true +} + +func enableInventoryReporting(ch Channels) { + inventoryReportingActive = true + logger.Log.Info("Activating inventory reporting") + // signal inventory reporting to start + ch.InventoryReportingEnabled <- true +} + +func register(registrationInfo *Registration, anchoreDetails connection.AnchoreInfo, maxRetry int, + startBackoff, maxBackoff time.Duration, now _Now) (*Integration, error) { + var err error + + attempt := 0 + for { + var registeredIntegration *Integration + + registeredIntegration, err = doRegister(registrationInfo, anchoreDetails, now) + if err == nil { + logger.Log.Infof("Successfully registered %s agent: %s (registration_id:%s / registration_instance_id:%s) with %s", + registrationInfo.Type, registrationInfo.Name, registrationInfo.RegistrationID, + registrationInfo.RegistrationInstanceID, anchoreDetails.URL) + logger.Log.Infof("This agent's integration uuid is %s", registeredIntegration.UUID) + return registeredIntegration, nil + } + + attempt++ + if maxRetry >= 0 && attempt > maxRetry { + logger.Log.Errorf("Failed to register agent (registration_id:%s / registration_instance_id:%s) after %d attempts", + registrationInfo.RegistrationID, registrationInfo.RegistrationInstanceID, attempt) + return nil, fmt.Errorf("failed to register after %d attempts", attempt) + } + + if anchore.ServerIsOffline(err) { + logger.Log.Infof("Anchore is offline. Will try again in %s", startBackoff) + time.Sleep(startBackoff) + if startBackoff < maxBackoff { + startBackoff = min(startBackoff*2, maxBackoff) + } + continue + } + + if anchore.UserLacksAPIPrivileges(err) { + logger.Log.Errorf("Specified user lacks required privileges to register and send health reports %v", err) + return nil, err + } + + if anchore.IncorrectCredentials(err) { + logger.Log.Errorf("Failed to register due to invalid credentials (wrong username or password)") + return nil, err + } + + logger.Log.Errorf("Failed to register integration agent (registration_id:%s / registration_instance_id:%s): %v", + registrationInfo.RegistrationID, registrationInfo.RegistrationInstanceID, err) + return nil, err + } +} + +func doRegister(registrationInfo *Registration, anchoreDetails connection.AnchoreInfo, now _Now) (*Integration, error) { + logger.Log.Infof("Registering %s agent: %s (registration_id:%s / registration_instance_id:%s) with %s", + registrationInfo.Type, registrationInfo.Name, registrationInfo.RegistrationID, + registrationInfo.RegistrationInstanceID, anchoreDetails.URL) + + registrationInfo.Uptime = &jstime.Duration{Duration: now().UTC().Sub(registrationInfo.StartedAt.Time)} + requestBody, err := json.Marshal(registrationInfo) + if err != nil { + return nil, fmt.Errorf("failed to serialize integration registration as JSON: %w", err) + } + responseBody, err := anchore.Post(requestBody, "", RegisterAPIPathV2, anchoreDetails, "integration registration") + if err != nil { + return nil, err + } + registeredIntegration := Integration{} + err = json.Unmarshal(*responseBody, ®isteredIntegration) + return ®isteredIntegration, err +} + +func getRegistrationInfo(appConfig *config.AppConfig, newUUID _NewUUID, now _Now) *Registration { + registrationID := appConfig.Registration.RegistrationID + if registrationID == "" { + logger.Log.Debugf("The registration_id value is not set. Generating UUIDv4 to use as registration_id") + registrationID = newUUID().String() + } else { + logger.Log.Debugf("Using registration_id specified in config: %s", registrationID) + } + + registrationInstanceID := newUUID().String() + logger.Log.Debugf("Generated registration_instance_id: %s", registrationInstanceID) + + instanceName := appConfig.Registration.IntegrationName + if instanceName == "" { + instanceName = deriveIntegrationName(appConfig.Region) + } + description := appConfig.Registration.IntegrationDescription + + appVersion := ecsVersion.FromBuild().Version + if appVersion == "[not provided]" { + appVersion = "dev" + } + + logger.Log.Debugf("Integration registration_id: %s, registration_instance_id: %s, name: %s, description: %s", + registrationID, registrationInstanceID, instanceName, description) + + instance := Registration{ + RegistrationID: registrationID, + RegistrationInstanceID: registrationInstanceID, + Type: Type, + Name: instanceName, + Description: description, + Version: appVersion, + StartedAt: jstime.Datetime{Time: now().UTC()}, + Uptime: new(jstime.Duration), + Username: appConfig.AnchoreDetails.User, + Region: appConfig.Region, + Configuration: nil, + HealthReportInterval: appConfig.HealthReportIntervalSeconds, + } + return &instance +} + +// deriveIntegrationName builds a default integration name from the AWS account ID and region. +// Falls back to just the region if the account ID cannot be determined. +func deriveIntegrationName(region string) string { + ctx := context.Background() + optFns := []func(*awsconfig.LoadOptions) error{} + if region != "" { + optFns = append(optFns, awsconfig.WithRegion(region)) + } + cfg, err := awsconfig.LoadDefaultConfig(ctx, optFns...) + if err != nil { + logger.Log.Debugf("Failed to load AWS config for integration name derivation: %v", err) + return fmt.Sprintf("ecs-inventory-%s", region) + } + + stsClient := sts.NewFromConfig(cfg) + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + logger.Log.Debugf("Failed to get AWS caller identity for integration name derivation: %v", err) + return fmt.Sprintf("ecs-inventory-%s", region) + } + + accountID := aws.ToString(identity.Account) + logger.Log.Infof("Derived integration name from AWS account %s in region %s", accountID, region) + return fmt.Sprintf("ecs-inventory-%s-%s", accountID, region) +} diff --git a/pkg/inventory/ecs.go b/pkg/inventory/ecs.go index b71a1c1..a8e6e98 100644 --- a/pkg/inventory/ecs.go +++ b/pkg/inventory/ecs.go @@ -1,23 +1,33 @@ package inventory import ( + "context" "fmt" "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/anchore/ecs-inventory/internal/logger" "github.com/anchore/ecs-inventory/internal/tracker" "github.com/anchore/ecs-inventory/pkg/reporter" ) -// Check if AWS are present, should be stored in ~/.aws/credentials -func checkAWSCredentials(sess *session.Session) error { - _, err := sess.Config.Credentials.Get() +// ECSClient defines the subset of the ECS API used by this package. +type ECSClient interface { + ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) + ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) + DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) + ListTagsForResource(ctx context.Context, params *ecs.ListTagsForResourceInput, optFns ...func(*ecs.Options)) (*ecs.ListTagsForResourceOutput, error) +} + +// Check if AWS credentials are present +func checkAWSCredentials(ctx context.Context, cfg aws.Config) error { + _, err := cfg.Credentials.Retrieve(ctx) if err != nil { fmt.Println( "Unable to get AWS credentials, please check ~/.aws/credentials file or environment variables are set correctly.", @@ -27,11 +37,11 @@ func checkAWSCredentials(sess *session.Session) error { return nil } -func fetchClusters(client ecsiface.ECSAPI) ([]*string, error) { +func fetchClusters(ctx context.Context, client ECSClient) ([]string, error) { defer tracker.TrackFunctionTime(time.Now(), "Fetching list of clusters") input := &ecs.ListClustersInput{} - result, err := client.ListClusters(input) + result, err := client.ListClusters(ctx, input) if err != nil { return nil, err } @@ -39,13 +49,13 @@ func fetchClusters(client ecsiface.ECSAPI) ([]*string, error) { return result.ClusterArns, nil } -func fetchTasksFromCluster(client ecsiface.ECSAPI, cluster string) ([]*string, error) { +func fetchTasksFromCluster(ctx context.Context, client ECSClient, cluster string) ([]string, error) { defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Fetching tasks from cluster: %s", cluster)) input := &ecs.ListTasksInput{ Cluster: aws.String(cluster), } - result, err := client.ListTasks(input) + result, err := client.ListTasks(ctx, input) if err != nil { return nil, err } @@ -53,13 +63,13 @@ func fetchTasksFromCluster(client ecsiface.ECSAPI, cluster string) ([]*string, e return result.TaskArns, nil } -func fetchServicesFromCluster(client ecsiface.ECSAPI, cluster string) ([]*string, error) { +func fetchServicesFromCluster(ctx context.Context, client ECSClient, cluster string) ([]string, error) { defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Fetching services from cluster: %s", cluster)) input := &ecs.ListServicesInput{ Cluster: aws.String(cluster), } - result, err := client.ListServices(input) + result, err := client.ListServices(ctx, input) if err != nil { return nil, err } @@ -67,14 +77,14 @@ func fetchServicesFromCluster(client ecsiface.ECSAPI, cluster string) ([]*string return result.ServiceArns, nil } -func fetchContainersFromTasks(client ecsiface.ECSAPI, cluster string, tasks []*string) ([]reporter.Container, error) { +func fetchContainersFromTasks(ctx context.Context, client ECSClient, cluster string, tasks []string) ([]reporter.Container, error) { defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Fetching Containers from tasks for cluster: %s", cluster)) input := &ecs.DescribeTasksInput{ Cluster: aws.String(cluster), Tasks: tasks, } - results, err := client.DescribeTasks(input) + results, err := client.DescribeTasks(ctx, input) if err != nil { return nil, err } @@ -102,7 +112,7 @@ func fetchContainersFromTasks(client ecsiface.ECSAPI, cluster string, tasks []*s return containers, nil } -func getContainerImageTag(containerTagMap map[string]string, container *ecs.Container) string { +func getContainerImageTag(containerTagMap map[string]string, container types.Container) string { // Fix container image tag if it contains an @ symbol if strings.Contains(*container.Image, "@") { // replace the image tag with the correct one @@ -116,7 +126,7 @@ func getContainerImageTag(containerTagMap map[string]string, container *ecs.Cont } // Build a map of container image digests to image tags -func buildContainerTagMap(tasks []*ecs.Task) map[string]string { +func buildContainerTagMap(tasks []types.Task) map[string]string { containerMap := make(map[string]string) for _, task := range tasks { for _, container := range task.Containers { @@ -155,13 +165,13 @@ func constructServiceARN(clusterARN string, serviceName string) (string, error) return fmt.Sprintf("arn:aws:ecs:%s:%s:service/%s/%s", region, accountID, clusterName, serviceName), nil } -func fetchTasksMetadata(client ecsiface.ECSAPI, cluster string, tasks []*string) ([]reporter.Task, error) { +func fetchTasksMetadata(ctx context.Context, client ECSClient, cluster string, tasks []string) ([]reporter.Task, error) { input := &ecs.DescribeTasksInput{ Cluster: aws.String(cluster), Tasks: tasks, } - results, err := client.DescribeTasks(input) + results, err := client.DescribeTasks(ctx, input) if err != nil { return nil, err } @@ -169,7 +179,7 @@ func fetchTasksMetadata(client ecsiface.ECSAPI, cluster string, tasks []*string) var tasksMetadata []reporter.Task for _, task := range results.Tasks { // Tags may not be present in the task response so we need to fetch them explicitly - tagMap, err := fetchTagsForResource(client, *task.TaskArn) + tagMap, err := fetchTagsForResource(ctx, client, *task.TaskArn) if err != nil { return nil, err } @@ -202,13 +212,13 @@ func fetchTasksMetadata(client ecsiface.ECSAPI, cluster string, tasks []*string) return tasksMetadata, nil } -func fetchServicesMetadata(client ecsiface.ECSAPI, cluster string, services []*string) ([]reporter.Service, error) { +func fetchServicesMetadata(ctx context.Context, client ECSClient, cluster string, services []string) ([]reporter.Service, error) { input := &ecs.DescribeServicesInput{ Cluster: aws.String(cluster), Services: services, } - results, err := client.DescribeServices(input) + results, err := client.DescribeServices(ctx, input) if err != nil { return nil, err } @@ -216,7 +226,7 @@ func fetchServicesMetadata(client ecsiface.ECSAPI, cluster string, services []*s var servicesMetadata []reporter.Service for _, service := range results.Services { // Tags may not be present in the service response so we need to fetch them explicitly - tagMap, err := fetchTagsForResource(client, *service.ServiceArn) + tagMap, err := fetchTagsForResource(ctx, client, *service.ServiceArn) if err != nil { return nil, err } @@ -230,12 +240,12 @@ func fetchServicesMetadata(client ecsiface.ECSAPI, cluster string, services []*s return servicesMetadata, nil } -func fetchTagsForResource(client ecsiface.ECSAPI, resourceARN string) (map[string]string, error) { +func fetchTagsForResource(ctx context.Context, client ECSClient, resourceARN string) (map[string]string, error) { input := &ecs.ListTagsForResourceInput{ ResourceArn: aws.String(resourceARN), } - result, err := client.ListTagsForResource(input) + result, err := client.ListTagsForResource(ctx, input) if err != nil { return nil, err } diff --git a/pkg/inventory/ecs_test.go b/pkg/inventory/ecs_test.go index 01249d6..5e46bc4 100644 --- a/pkg/inventory/ecs_test.go +++ b/pkg/inventory/ecs_test.go @@ -1,26 +1,24 @@ package inventory import ( + "context" "testing" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/stretchr/testify/assert" "github.com/anchore/ecs-inventory/pkg/reporter" ) -// Return a pointer to the passed value -func GetPointerToValue[T any](t T) *T { return &t } - func Test_fetchClusters(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient } tests := []struct { name string args args - want []*string + want []string wantErr bool }{ { @@ -37,15 +35,15 @@ func Test_fetchClusters(t *testing.T) { args: args{ client: &mockECSClient{}, }, - want: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:cluster/cluster-1"), - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:cluster/cluster-2"), + want: []string{ + "arn:aws:ecs:us-east-1:123456789012:cluster/cluster-1", + "arn:aws:ecs:us-east-1:123456789012:cluster/cluster-2", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchClusters(tt.args.client) + got, err := fetchClusters(context.Background(), tt.args.client) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -56,13 +54,13 @@ func Test_fetchClusters(t *testing.T) { func Test_fetchTasksFromCluster(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient cluster string } tests := []struct { name string args args - want []*string + want []string wantErr bool }{ { @@ -80,15 +78,15 @@ func Test_fetchTasksFromCluster(t *testing.T) { client: &mockECSClient{}, cluster: "cluster-1", }, - want: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111"), + want: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchTasksFromCluster(tt.args.client, tt.args.cluster) + got, err := fetchTasksFromCluster(context.Background(), tt.args.client, tt.args.cluster) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -99,9 +97,9 @@ func Test_fetchTasksFromCluster(t *testing.T) { func Test_fetchContainersFromTasks(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient cluster string - tasks []*string + tasks []string } tests := []struct { name string @@ -116,8 +114,8 @@ func Test_fetchContainersFromTasks(t *testing.T) { ErrorOnDescribeTasks: true, }, cluster: "cluster-1", - tasks: []*string{ - GetPointerToValue("BAD-ARN"), + tasks: []string{ + "BAD-ARN", }, }, wantErr: true, @@ -127,8 +125,8 @@ func Test_fetchContainersFromTasks(t *testing.T) { args: args{ client: &mockECSClient{}, cluster: "cluster-1", - tasks: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), + tasks: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", }, }, want: []reporter.Container{ @@ -151,9 +149,9 @@ func Test_fetchContainersFromTasks(t *testing.T) { args: args{ client: &mockECSClient{}, cluster: "cluster-1", - tasks: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111"), + tasks: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111", }, }, want: []reporter.Container{ @@ -186,7 +184,7 @@ func Test_fetchContainersFromTasks(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchContainersFromTasks(tt.args.client, tt.args.cluster, tt.args.tasks) + got, err := fetchContainersFromTasks(context.Background(), tt.args.client, tt.args.cluster, tt.args.tasks) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -197,9 +195,9 @@ func Test_fetchContainersFromTasks(t *testing.T) { func Test_fetchTasksMetadata(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient cluster string - tasks []*string + tasks []string } tests := []struct { name string @@ -214,8 +212,8 @@ func Test_fetchTasksMetadata(t *testing.T) { ErrorOnDescribeTasks: true, }, cluster: "cluster-1", - tasks: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), + tasks: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", }, }, wantErr: true, @@ -227,8 +225,8 @@ func Test_fetchTasksMetadata(t *testing.T) { ErrorOnListTagsForResource: true, }, cluster: "cluster-1", - tasks: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), + tasks: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", }, }, wantErr: true, @@ -238,9 +236,9 @@ func Test_fetchTasksMetadata(t *testing.T) { args: args{ client: &mockECSClient{}, cluster: "cluster-1", - tasks: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111"), + tasks: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111", }, }, want: []reporter.Task{ @@ -264,7 +262,7 @@ func Test_fetchTasksMetadata(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchTasksMetadata(tt.args.client, tt.args.cluster, tt.args.tasks) + got, err := fetchTasksMetadata(context.Background(), tt.args.client, tt.args.cluster, tt.args.tasks) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -275,7 +273,7 @@ func Test_fetchTasksMetadata(t *testing.T) { func Test_fetchTagsForResource(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient resourceARN string } tests := []struct { @@ -316,7 +314,7 @@ func Test_fetchTagsForResource(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchTagsForResource(tt.args.client, tt.args.resourceARN) + got, err := fetchTagsForResource(context.Background(), tt.args.client, tt.args.resourceARN) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -327,13 +325,13 @@ func Test_fetchTagsForResource(t *testing.T) { func Test_fetchServicesFromCluster(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient cluster string } tests := []struct { name string args args - want []*string + want []string wantErr bool }{ { @@ -352,15 +350,15 @@ func Test_fetchServicesFromCluster(t *testing.T) { client: &mockECSClient{}, cluster: "cluster-1", }, - want: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1"), - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2"), + want: []string{ + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1", + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchServicesFromCluster(tt.args.client, tt.args.cluster) + got, err := fetchServicesFromCluster(context.Background(), tt.args.client, tt.args.cluster) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -371,9 +369,9 @@ func Test_fetchServicesFromCluster(t *testing.T) { func Test_fetchServicesMetadata(t *testing.T) { type args struct { - client ecsiface.ECSAPI + client ECSClient cluster string - services []*string + services []string } tests := []struct { name string @@ -388,8 +386,8 @@ func Test_fetchServicesMetadata(t *testing.T) { ErrorOnDescribeServices: true, }, cluster: "cluster-1", - services: []*string{ - GetPointerToValue("arn"), + services: []string{ + "arn", }, }, wantErr: true, @@ -401,8 +399,8 @@ func Test_fetchServicesMetadata(t *testing.T) { ErrorOnListTagsForResource: true, }, cluster: "cluster-1", - services: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1"), + services: []string{ + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1", }, }, wantErr: true, @@ -412,9 +410,9 @@ func Test_fetchServicesMetadata(t *testing.T) { args: args{ client: &mockECSClient{}, cluster: "cluster-1", - services: []*string{ - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1"), - GetPointerToValue("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2"), + services: []string{ + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1", + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2", }, }, want: []reporter.Service{ @@ -434,7 +432,7 @@ func Test_fetchServicesMetadata(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fetchServicesMetadata(tt.args.client, tt.args.cluster, tt.args.services) + got, err := fetchServicesMetadata(context.Background(), tt.args.client, tt.args.cluster, tt.args.services) if (err != nil) != tt.wantErr { assert.Error(t, err) } @@ -501,7 +499,7 @@ func Test_constructServiceARN(t *testing.T) { func Test_getContainerImageTag(t *testing.T) { type args struct { containerTagMap map[string]string - container *ecs.Container + container types.Container } tests := []struct { name string @@ -515,9 +513,9 @@ func Test_getContainerImageTag(t *testing.T) { "sha256:1234567890123456789012345678901234567890123456789012345678901111": "image-1:latest", "sha256:1234567890123456789012345678901234567890123456789012345678902222": "image-2:latest", }, - container: &ecs.Container{ - Image: GetPointerToValue("image-1:latest"), - ImageDigest: GetPointerToValue("sha256:1234567890123456789012345678901234567890123456789012345678901111"), + container: types.Container{ + Image: aws.String("image-1:latest"), + ImageDigest: aws.String("sha256:1234567890123456789012345678901234567890123456789012345678901111"), }, }, want: "image-1:latest", @@ -529,9 +527,9 @@ func Test_getContainerImageTag(t *testing.T) { "sha256:1234567890123456789012345678901234567890123456789012345678901111": "image-1:latest", "sha256:1234567890123456789012345678901234567890123456789012345678902222": "image-2:latest", }, - container: &ecs.Container{ - Image: GetPointerToValue("image-1@sha256:1234567890123456789012345678901234567890123456789012345678901111"), - ImageDigest: GetPointerToValue("sha256:1234567890123456789012345678901234567890123456789012345678901111"), + container: types.Container{ + Image: aws.String("image-1@sha256:1234567890123456789012345678901234567890123456789012345678901111"), + ImageDigest: aws.String("sha256:1234567890123456789012345678901234567890123456789012345678901111"), }, }, want: "image-1:latest", @@ -543,9 +541,9 @@ func Test_getContainerImageTag(t *testing.T) { "sha256:1234567890123456789012345678901234567890123456789012345678901111": "image-1:latest", "sha256:1234567890123456789012345678901234567890123456789012345678902222": "image-2:latest", }, - container: &ecs.Container{ - Image: GetPointerToValue("image-1@sha256:0000"), - ImageDigest: GetPointerToValue("sha256:11"), + container: types.Container{ + Image: aws.String("image-1@sha256:0000"), + ImageDigest: aws.String("sha256:11"), }, }, want: "image-1:UNKNOWN", diff --git a/pkg/inventory/mock_ecs_test.go b/pkg/inventory/mock_ecs_test.go index 35b7626..8f68527 100644 --- a/pkg/inventory/mock_ecs_test.go +++ b/pkg/inventory/mock_ecs_test.go @@ -1,15 +1,15 @@ package inventory import ( + "context" "errors" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" ) type mockECSClient struct { - ecsiface.ECSAPI ErrorOnListCluster bool ErrorOnListTasks bool ErrorOnListServices bool @@ -18,51 +18,51 @@ type mockECSClient struct { ErrorOnDescribeServices bool } -func (m *mockECSClient) ListClusters(*ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { +func (m *mockECSClient) ListClusters(_ context.Context, _ *ecs.ListClustersInput, _ ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) { if m.ErrorOnListCluster { return nil, errors.New("list cluster error") } return &ecs.ListClustersOutput{ - ClusterArns: []*string{ - aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster-1"), - aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster-2"), + ClusterArns: []string{ + "arn:aws:ecs:us-east-1:123456789012:cluster/cluster-1", + "arn:aws:ecs:us-east-1:123456789012:cluster/cluster-2", }, }, nil } -func (m *mockECSClient) ListTasks(*ecs.ListTasksInput) (*ecs.ListTasksOutput, error) { +func (m *mockECSClient) ListTasks(_ context.Context, _ *ecs.ListTasksInput, _ ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) { if m.ErrorOnListTasks { return nil, errors.New("list tasks error") } return &ecs.ListTasksOutput{ - TaskArns: []*string{ - aws.String("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000"), - aws.String("arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111"), + TaskArns: []string{ + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", + "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111", }, }, nil } -func (m *mockECSClient) ListServices(*ecs.ListServicesInput) (*ecs.ListServicesOutput, error) { +func (m *mockECSClient) ListServices(_ context.Context, _ *ecs.ListServicesInput, _ ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { if m.ErrorOnListServices { return nil, errors.New("list services error") } return &ecs.ListServicesOutput{ - ServiceArns: []*string{ - aws.String("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1"), - aws.String("arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2"), + ServiceArns: []string{ + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1", + "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2", }, }, nil } -func (m *mockECSClient) DescribeTasks(input *ecs.DescribeTasksInput) (*ecs.DescribeTasksOutput, error) { +func (m *mockECSClient) DescribeTasks(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { if m.ErrorOnDescribeTasks { return nil, errors.New("describe tasks error") } - tasks := []*ecs.Task{} + tasks := []types.Task{} for _, t := range input.Tasks { - switch *t { + switch t { case "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000": - tasks = append(tasks, &ecs.Task{ + tasks = append(tasks, types.Task{ TaskArn: aws.String( "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000", ), @@ -71,7 +71,7 @@ func (m *mockECSClient) DescribeTasks(input *ecs.DescribeTasksInput) (*ecs.Descr "arn:aws:ecs:us-east-1:123456789012:task-definition/task-definition-1:1", ), Group: aws.String("service:service-1"), - Containers: []*ecs.Container{ + Containers: []types.Container{ { ContainerArn: aws.String( "arn:aws:ecs:us-east-1:123456789012:container/12345678-1234-1234-1234-111111111111", @@ -91,7 +91,7 @@ func (m *mockECSClient) DescribeTasks(input *ecs.DescribeTasksInput) (*ecs.Descr }, }) case "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111": - tasks = append(tasks, &ecs.Task{ + tasks = append(tasks, types.Task{ TaskArn: aws.String( "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-111111111111", ), @@ -100,7 +100,7 @@ func (m *mockECSClient) DescribeTasks(input *ecs.DescribeTasksInput) (*ecs.Descr "arn:aws:ecs:us-east-1:123456789012:task-definition/task-definition-1:1", ), Group: aws.String("service:service-1"), - Containers: []*ecs.Container{ + Containers: []types.Container{ { ContainerArn: aws.String( "arn:aws:ecs:us-east-1:123456789012:container/12345678-1234-1234-1234-111111111113", @@ -125,14 +125,14 @@ func (m *mockECSClient) DescribeTasks(input *ecs.DescribeTasksInput) (*ecs.Descr return &ecs.DescribeTasksOutput{Tasks: tasks}, nil } -func (m *mockECSClient) ListTagsForResource(input *ecs.ListTagsForResourceInput) (*ecs.ListTagsForResourceOutput, error) { +func (m *mockECSClient) ListTagsForResource(_ context.Context, input *ecs.ListTagsForResourceInput, _ ...func(*ecs.Options)) (*ecs.ListTagsForResourceOutput, error) { if m.ErrorOnListTagsForResource { return nil, errors.New("list tags for resource error") } switch *input.ResourceArn { case "arn:aws:ecs:us-east-1:123456789012:task/cluster-1/12345678-1234-1234-1234-000000000000": return &ecs.ListTagsForResourceOutput{ - Tags: []*ecs.Tag{ + Tags: []types.Tag{ { Key: aws.String("key-1"), Value: aws.String("value-1"), @@ -145,7 +145,7 @@ func (m *mockECSClient) ListTagsForResource(input *ecs.ListTagsForResourceInput) }, nil case "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1": return &ecs.ListTagsForResourceOutput{ - Tags: []*ecs.Tag{ + Tags: []types.Tag{ { Key: aws.String("svc-key-1"), Value: aws.String("svc-value-1"), @@ -161,23 +161,23 @@ func (m *mockECSClient) ListTagsForResource(input *ecs.ListTagsForResourceInput) } } -func (m *mockECSClient) DescribeServices(input *ecs.DescribeServicesInput) (*ecs.DescribeServicesOutput, error) { +func (m *mockECSClient) DescribeServices(_ context.Context, input *ecs.DescribeServicesInput, _ ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { if m.ErrorOnDescribeServices { return nil, errors.New("describe services error") } - services := []*ecs.Service{} + services := []types.Service{} for _, s := range input.Services { - switch *s { + switch s { case "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1": - services = append(services, &ecs.Service{ + services = append(services, types.Service{ ServiceArn: aws.String( "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-1", ), ClusterArn: aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster-1"), }) case "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2": - services = append(services, &ecs.Service{ + services = append(services, types.Service{ ServiceArn: aws.String( "arn:aws:ecs:us-east-1:123456789012:service/cluster-1/service-2", ), diff --git a/pkg/inventory/report.go b/pkg/inventory/report.go index 6a87e3e..211b412 100644 --- a/pkg/inventory/report.go +++ b/pkg/inventory/report.go @@ -1,16 +1,15 @@ package inventory import ( + "context" "encoding/json" "fmt" "os" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/anchore/ecs-inventory/internal/logger" "github.com/anchore/ecs-inventory/internal/tracker" @@ -50,23 +49,26 @@ func HandleReport(report reporter.Report, anchoreDetails connection.AnchoreInfo, func GetInventoryReportsForRegion(region string, anchoreDetails connection.AnchoreInfo, quiet, dryRun bool) error { defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Getting Inventory Reports for region: %s", region)) logger.Log.Info("Getting Inventory Reports for region", "region", region) - sessConfig := &aws.Config{} + + ctx := context.Background() + optFns := []func(*config.LoadOptions) error{} if region != "" { - sessConfig.Region = aws.String(region) + optFns = append(optFns, config.WithRegion(region)) } - sess, err := session.NewSession(sessConfig) + cfg, err := config.LoadDefaultConfig(ctx, optFns...) if err != nil { - logger.Log.Error("Failed to create AWS session", err) + logger.Log.Error("Failed to load AWS config", err) + return fmt.Errorf("failed to load AWS config: %w", err) } - err = checkAWSCredentials(sess) + err = checkAWSCredentials(ctx, cfg) if err != nil { return err } - ecsClient := ecs.New(sess) + ecsClient := ecs.NewFromConfig(cfg) - clusters, err := fetchClusters(ecsClient) + clusters, err := fetchClusters(ctx, ecsClient) if err != nil { return err } @@ -78,9 +80,7 @@ func GetInventoryReportsForRegion(region string, anchoreDetails connection.Ancho go func(cluster string) { defer wg.Done() - ecsClient := ecs.New(sess) - - report, err := GetInventoryReportForCluster(cluster, ecsClient) + report, err := GetInventoryReportForCluster(ctx, cluster, ecsClient) if err != nil { logger.Log.Error("Failed to get inventory report for cluster", err) } @@ -94,7 +94,7 @@ func GetInventoryReportsForRegion(region string, anchoreDetails connection.Ancho logger.Log.Error("Failed payload", fmt.Errorf("report %s", jsonReport)) } } - }(*cluster) + }(cluster) } wg.Wait() @@ -174,7 +174,7 @@ func ensureReferencedObjectsExist(report reporter.Report) reporter.Report { } // GetInventoryReportForCluster is an atomic method for getting in-use image results, for a cluster -func GetInventoryReportForCluster(clusterARN string, ecsClient ecsiface.ECSAPI) (reporter.Report, error) { +func GetInventoryReportForCluster(ctx context.Context, clusterARN string, ecsClient ECSClient) (reporter.Report, error) { defer tracker.TrackFunctionTime(time.Now(), fmt.Sprintf("Getting Inventory Report for cluster: %s", clusterARN)) logger.Log.Debug("Found cluster", "cluster", clusterARN) @@ -182,20 +182,20 @@ func GetInventoryReportForCluster(clusterARN string, ecsClient ecsiface.ECSAPI) Timestamp: time.Now().UTC().Format(time.RFC3339), ClusterARN: clusterARN, } - tasks, err := fetchTasksFromCluster(ecsClient, clusterARN) + tasks, err := fetchTasksFromCluster(ctx, ecsClient, clusterARN) if err != nil { return reporter.Report{}, err } servicesMeta := []reporter.Service{} - services, err := fetchServicesFromCluster(ecsClient, clusterARN) + services, err := fetchServicesFromCluster(ctx, ecsClient, clusterARN) if err != nil { return reporter.Report{}, err } if len(services) == 0 { logger.Log.Debug("No services found in cluster", "cluster", clusterARN) } else { - servicesMeta, err = fetchServicesMetadata(ecsClient, clusterARN, services) + servicesMeta, err = fetchServicesMetadata(ctx, ecsClient, clusterARN, services) if err != nil { return reporter.Report{}, err } @@ -208,13 +208,13 @@ func GetInventoryReportForCluster(clusterARN string, ecsClient ecsiface.ECSAPI) } else { logger.Log.Debug("Found tasks in cluster", "cluster", clusterARN, "taskCount", len(tasks)) - taskMeta, err := fetchTasksMetadata(ecsClient, clusterARN, tasks) + taskMeta, err := fetchTasksMetadata(ctx, ecsClient, clusterARN, tasks) if err != nil { return reporter.Report{}, err } report.Tasks = taskMeta - containers, err := fetchContainersFromTasks(ecsClient, clusterARN, tasks) + containers, err := fetchContainersFromTasks(ctx, ecsClient, clusterARN, tasks) if err != nil { return reporter.Report{}, err } diff --git a/pkg/inventory/report_test.go b/pkg/inventory/report_test.go index abc9793..8f3c86f 100644 --- a/pkg/inventory/report_test.go +++ b/pkg/inventory/report_test.go @@ -1,6 +1,7 @@ package inventory import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -11,7 +12,7 @@ import ( func TestGetInventoryReportForCluster(t *testing.T) { mockSvc := &mockECSClient{} - report, err := GetInventoryReportForCluster("cluster-1", mockSvc) + report, err := GetInventoryReportForCluster(context.Background(), "cluster-1", mockSvc) assert.NoError(t, err) assert.Equal(t, 4, len(report.Containers)) diff --git a/pkg/lib.go b/pkg/lib.go index 31b819c..4c2d66a 100644 --- a/pkg/lib.go +++ b/pkg/lib.go @@ -3,16 +3,80 @@ package pkg import ( "time" + "github.com/anchore/ecs-inventory/internal/config" + "github.com/anchore/ecs-inventory/internal/logger" + jstime "github.com/anchore/ecs-inventory/internal/time" "github.com/anchore/ecs-inventory/pkg/connection" + "github.com/anchore/ecs-inventory/pkg/healthreporter" + "github.com/anchore/ecs-inventory/pkg/integration" "github.com/anchore/ecs-inventory/pkg/inventory" - "github.com/anchore/ecs-inventory/pkg/logger" + pkgLogger "github.com/anchore/ecs-inventory/pkg/logger" ) -var log logger.Logger +var log pkgLogger.Logger -// PeriodicallyGetInventoryReport periodically retrieve image results and report/output them according to the configuration. -// Note: Errors do not cause the function to exit, since this is periodically running +// PeriodicallyGetInventoryReport periodically retrieves image results with channel-based coordination +// for health reporting integration. Waits for registration to complete before starting. +// Note: Errors do not cause the function to exit, since this is periodically running. func PeriodicallyGetInventoryReport( + cfg *config.AppConfig, + ch integration.Channels, + gatedReportInfo *healthreporter.GatedReportInfo, +) { + // Wait for registration with Enterprise to be disabled or completed + <-ch.InventoryReportingEnabled + logger.Log.Info("Inventory reporting started") + healthReportingEnabled := false + + // Fire off a ticker that reports according to a configurable polling interval + ticker := time.NewTicker(time.Duration(cfg.PollingIntervalSeconds) * time.Second) + + for { + reportTimestamp := time.Now().UTC().Format(time.RFC3339) + err := inventory.GetInventoryReportsForRegion(cfg.Region, cfg.AnchoreDetails, cfg.Quiet, cfg.DryRun) + if err != nil { + logger.Log.Error("Failed to get Inventory Reports for region", err) + } else { + // Track batch info for health reporting + reportInfo := healthreporter.InventoryReportInfo{ + Account: cfg.AnchoreDetails.Account, + Region: cfg.Region, + BatchSize: 1, + LastSuccessfulIndex: 1, + Batches: make([]healthreporter.BatchInfo, 0), + HasErrors: false, + ReportTimestamp: reportTimestamp, + } + batchInfo := healthreporter.BatchInfo{ + SendTimestamp: jstime.Datetime{Time: time.Now().UTC()}, + BatchIndex: 1, + } + reportInfo.Batches = append(reportInfo.Batches, batchInfo) + + select { + case isEnabled, isNotClosed := <-ch.HealthReportingEnabled: + if isNotClosed { + healthReportingEnabled = isEnabled + } + logger.Log.Infof("Health reporting enabled: %t", healthReportingEnabled) + default: + } + if healthReportingEnabled { + healthreporter.SetReportInfoNoBlocking(cfg.AnchoreDetails.Account, 0, reportInfo, gatedReportInfo) + } + } + + logger.Log.Infof("Waiting %d seconds for next poll...", cfg.PollingIntervalSeconds) + + // Wait at least as long as the ticker + logger.Log.Debugf("Start new gather: %s", <-ticker.C) + } +} + +// PeriodicallyGetInventoryReportSimple is the simple polling loop used when Anchore details +// are not configured (no health reporting or registration). +// Note: Errors do not cause the function to exit, since this is periodically running. +func PeriodicallyGetInventoryReportSimple( pollingIntervalSeconds int, anchoreDetails connection.AnchoreInfo, region string, @@ -24,14 +88,14 @@ func PeriodicallyGetInventoryReport( for { err := inventory.GetInventoryReportsForRegion(region, anchoreDetails, quiet, dryRun) if err != nil { - log.Error("Failed to get Inventory Reports for region", err) + logger.Log.Error("Failed to get Inventory Reports for region", err) } // Wait at least as long as the ticker - log.Debugf("Start new gather %s", <-ticker.C) + logger.Log.Debugf("Start new gather %s", <-ticker.C) } } -func SetLogger(logger logger.Logger) { - log = logger +func SetLogger(l pkgLogger.Logger) { + log = l } diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 24b45ba..17dcd24 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -2,9 +2,11 @@ package logger type Logger interface { Error(msg string, err error, args ...interface{}) + Errorf(msg string, args ...interface{}) Warn(msg string, args ...interface{}) Warnf(msg string, args ...interface{}) Info(msg string, args ...interface{}) + Infof(msg string, args ...interface{}) Debug(msg string, args ...interface{}) Debugf(msg string, args ...interface{}) }