diff --git a/go.work b/go.work index 6ebafbdb..842dc1c6 100644 --- a/go.work +++ b/go.work @@ -10,6 +10,7 @@ use ( ./tools/prow-job-executor ./tools/registration ./tools/release + ./tools/image-mirror ./tools/secret-sync ./tools/yamlwrap ) diff --git a/tools/image-mirror/command.go b/tools/image-mirror/command.go new file mode 100644 index 00000000..051b5af1 --- /dev/null +++ b/tools/image-mirror/command.go @@ -0,0 +1,31 @@ +package imagemirror + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/Azure/ARO-Tools/tools/image-mirror/sync" +) + +func NewCommand() (*cobra.Command, error) { + cmd := &cobra.Command{ + Use: "image-mirror", + Short: "Mirror container images to Azure Container Registry.", + SilenceUsage: true, + SilenceErrors: true, + } + + commands := []func() (*cobra.Command, error){ + sync.NewCommand, + } + for _, newCmd := range commands { + c, err := newCmd() + if err != nil { + return nil, fmt.Errorf("failed to create subcommand: %w", err) + } + cmd.AddCommand(c) + } + + return cmd, nil +} diff --git a/tools/image-mirror/command_test.go b/tools/image-mirror/command_test.go new file mode 100644 index 00000000..87b706f3 --- /dev/null +++ b/tools/image-mirror/command_test.go @@ -0,0 +1,27 @@ +package imagemirror + +import ( + "testing" +) + +func TestNewCommand(t *testing.T) { + cmd, err := NewCommand() + if err != nil { + t.Fatalf("unexpected error creating command: %v", err) + } + if cmd.Use != "image-mirror" { + t.Errorf("expected Use 'image-mirror', got %q", cmd.Use) + } + + // Verify the "sync" subcommand is registered + found := false + for _, sub := range cmd.Commands() { + if sub.Use == "sync" { + found = true + break + } + } + if !found { + t.Error("expected 'sync' subcommand not found") + } +} diff --git a/tools/image-mirror/go.mod b/tools/image-mirror/go.mod new file mode 100644 index 00000000..91bb69e8 --- /dev/null +++ b/tools/image-mirror/go.mod @@ -0,0 +1,13 @@ +module github.com/Azure/ARO-Tools/tools/image-mirror + +go 1.25.0 + +require ( + github.com/go-logr/logr v1.4.3 + github.com/spf13/cobra v1.10.2 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect +) diff --git a/tools/image-mirror/go.sum b/tools/image-mirror/go.sum new file mode 100644 index 00000000..5ef5b04d --- /dev/null +++ b/tools/image-mirror/go.sum @@ -0,0 +1,13 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/tools/image-mirror/sync/command.go b/tools/image-mirror/sync/command.go new file mode 100644 index 00000000..2b5901db --- /dev/null +++ b/tools/image-mirror/sync/command.go @@ -0,0 +1,56 @@ +package sync + +import ( + "fmt" + "log" + "os" + "os/signal" + + "github.com/go-logr/logr" + "github.com/go-logr/logr/funcr" + "github.com/spf13/cobra" +) + +// NewCommand creates the "sync" subcommand. +func NewCommand() (*cobra.Command, error) { + cmd := &cobra.Command{ + Use: "sync", + Short: "Sync a container image to an Azure Container Registry.", + SilenceUsage: true, + SilenceErrors: true, + } + + opts := DefaultOptions() + BindOptions(opts, cmd) + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt) + defer cancel() + + validated, err := opts.Validate() + if err != nil { + return fmt.Errorf("validation failed: %w", err) + } + completed, err := validated.Complete() + if err != nil { + return fmt.Errorf("completion failed: %w", err) + } + + logger := newLogger() + runner := completed.NewRunner(logger) + return runner.Run(ctx) + } + + return cmd, nil +} + +func newLogger() logr.Logger { + return funcr.New(func(prefix, args string) { + if prefix != "" { + log.Printf("%s: %s", prefix, args) + } else { + log.Print(args) + } + }, funcr.Options{ + Verbosity: 1, + }) +} diff --git a/tools/image-mirror/sync/command_test.go b/tools/image-mirror/sync/command_test.go new file mode 100644 index 00000000..f2b80376 --- /dev/null +++ b/tools/image-mirror/sync/command_test.go @@ -0,0 +1,37 @@ +package sync + +import ( + "testing" +) + +func TestNewCommand(t *testing.T) { + cmd, err := NewCommand() + if err != nil { + t.Fatalf("unexpected error creating command: %v", err) + } + if cmd.Use != "sync" { + t.Errorf("expected Use 'sync', got %q", cmd.Use) + } + + // Verify all expected flags exist + expectedFlags := []string{ + "target-acr", + "source-registry", + "repository", + "digest", + "copy-from", + "image-file-path", + "image-tar-file", + "image-metadata-file", + "image-tar-sas", + "image-metadata-sas", + "pull-secret-kv", + "pull-secret", + "dry-run", + } + for _, flag := range expectedFlags { + if cmd.Flags().Lookup(flag) == nil { + t.Errorf("expected flag %q not found", flag) + } + } +} diff --git a/tools/image-mirror/sync/options.go b/tools/image-mirror/sync/options.go new file mode 100644 index 00000000..e15d0ab5 --- /dev/null +++ b/tools/image-mirror/sync/options.go @@ -0,0 +1,146 @@ +package sync + +import ( + "fmt" + + "github.com/go-logr/logr" + "github.com/spf13/cobra" +) + +// RawOptions holds the raw CLI input values. +type RawOptions struct { + TargetACR string + SourceRegistry string + Repository string + Digest string + CopyFrom string + ImageFilePath string + ImageTarFileName string + ImageMetadataFileName string + ImageTarSAS string + ImageMetadataSAS string + PullSecretKV string + PullSecretName string + DryRun bool +} + +// validatedOptions is a private wrapper that enforces a call of Validate() before Complete() can be invoked. +type validatedOptions struct { + *RawOptions +} + +// ValidatedOptions wraps validatedOptions to enforce the Validate() -> Complete() flow. +type ValidatedOptions struct { + *validatedOptions +} + +// completedOptions holds the finalized, ready-to-use options. +type completedOptions struct { + TargetACR string + SourceRegistry string + Repository string + Digest string + CopyFrom string + ImageFilePath string + ImageTarFileName string + ImageMetadataFileName string + ImageTarSAS string + ImageMetadataSAS string + PullSecretKV string + PullSecretName string + DryRun bool +} + +// Options wraps completedOptions to enforce the Validate() -> Complete() -> Run() flow. +type Options struct { + *completedOptions +} + +// DefaultOptions returns a new RawOptions with defaults. +func DefaultOptions() *RawOptions { + return &RawOptions{} +} + +// BindOptions binds CLI flags to the raw options. +func BindOptions(opts *RawOptions, cmd *cobra.Command) { + cmd.Flags().StringVar(&opts.TargetACR, "target-acr", opts.TargetACR, "Target Azure Container Registry name.") + cmd.Flags().StringVar(&opts.SourceRegistry, "source-registry", opts.SourceRegistry, "Source container registry host (for registry copy mode).") + cmd.Flags().StringVar(&opts.Repository, "repository", opts.Repository, "Image repository name.") + cmd.Flags().StringVar(&opts.Digest, "digest", opts.Digest, "Image digest (e.g. sha256:...).") + cmd.Flags().StringVar(&opts.CopyFrom, "copy-from", opts.CopyFrom, "Copy mode: 'oci-layout' for file-based or empty for registry-based.") + cmd.Flags().StringVar(&opts.ImageFilePath, "image-file-path", opts.ImageFilePath, "Directory path containing image tar and metadata files.") + cmd.Flags().StringVar(&opts.ImageTarFileName, "image-tar-file", opts.ImageTarFileName, "Image tar file name for OCI layout mode.") + cmd.Flags().StringVar(&opts.ImageMetadataFileName, "image-metadata-file", opts.ImageMetadataFileName, "Image metadata JSON file name for OCI layout mode.") + cmd.Flags().StringVar(&opts.ImageTarSAS, "image-tar-sas", opts.ImageTarSAS, "SAS URL for downloading the image tar file.") + cmd.Flags().StringVar(&opts.ImageMetadataSAS, "image-metadata-sas", opts.ImageMetadataSAS, "SAS URL for downloading the image metadata file.") + cmd.Flags().StringVar(&opts.PullSecretKV, "pull-secret-kv", opts.PullSecretKV, "KeyVault name containing pull secret (for registry copy mode).") + cmd.Flags().StringVar(&opts.PullSecretName, "pull-secret", opts.PullSecretName, "Pull secret name in KeyVault (for registry copy mode).") + cmd.Flags().BoolVar(&opts.DryRun, "dry-run", opts.DryRun, "If true, validate inputs without making changes.") +} + +// Validate validates the raw options. +func (o *RawOptions) Validate() (*ValidatedOptions, error) { + if o.TargetACR == "" { + return nil, fmt.Errorf("the target ACR must be provided with --target-acr") + } + if o.Repository == "" { + return nil, fmt.Errorf("the repository must be provided with --repository") + } + + if o.CopyFrom == copyFromOCI { + if o.ImageTarFileName == "" { + return nil, fmt.Errorf("the image tar file name must be provided with --image-tar-file for oci-layout mode") + } + if o.ImageMetadataFileName == "" { + return nil, fmt.Errorf("the image metadata file name must be provided with --image-metadata-file for oci-layout mode") + } + } else { + if o.SourceRegistry == "" { + return nil, fmt.Errorf("the source registry must be provided with --source-registry for registry mode") + } + if o.Digest == "" { + return nil, fmt.Errorf("the digest must be provided with --digest for registry mode") + } + if o.PullSecretKV == "" { + return nil, fmt.Errorf("the pull secret KeyVault must be provided with --pull-secret-kv for registry mode") + } + if o.PullSecretName == "" { + return nil, fmt.Errorf("the pull secret name must be provided with --pull-secret for registry mode") + } + } + + return &ValidatedOptions{ + validatedOptions: &validatedOptions{ + RawOptions: o, + }, + }, nil +} + +// Complete builds the finalized options. +func (o *ValidatedOptions) Complete() (*Options, error) { + return &Options{ + completedOptions: &completedOptions{ + TargetACR: o.TargetACR, + SourceRegistry: o.SourceRegistry, + Repository: o.Repository, + Digest: o.Digest, + CopyFrom: o.CopyFrom, + ImageFilePath: o.ImageFilePath, + ImageTarFileName: o.ImageTarFileName, + ImageMetadataFileName: o.ImageMetadataFileName, + ImageTarSAS: o.ImageTarSAS, + ImageMetadataSAS: o.ImageMetadataSAS, + PullSecretKV: o.PullSecretKV, + PullSecretName: o.PullSecretName, + DryRun: o.DryRun, + }, + }, nil +} + +// NewRunner creates a runner from completed options. +func (o *Options) NewRunner(logger logr.Logger) *Runner { + return &Runner{ + opts: o, + logger: logger, + } +} diff --git a/tools/image-mirror/sync/options_test.go b/tools/image-mirror/sync/options_test.go new file mode 100644 index 00000000..ccc61059 --- /dev/null +++ b/tools/image-mirror/sync/options_test.go @@ -0,0 +1,190 @@ +package sync + +import ( + "testing" +) + +func TestRawOptions_Validate_RegistryMode(t *testing.T) { + tests := []struct { + name string + opts RawOptions + wantErr string + }{ + { + name: "missing target ACR", + opts: RawOptions{}, + wantErr: "the target ACR must be provided with --target-acr", + }, + { + name: "missing repository", + opts: RawOptions{ + TargetACR: "myacr", + }, + wantErr: "the repository must be provided with --repository", + }, + { + name: "missing source registry", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + }, + wantErr: "the source registry must be provided with --source-registry for registry mode", + }, + { + name: "missing digest", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + SourceRegistry: "source.azurecr.io", + }, + wantErr: "the digest must be provided with --digest for registry mode", + }, + { + name: "missing pull secret KV", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + SourceRegistry: "source.azurecr.io", + Digest: "sha256:abc123", + }, + wantErr: "the pull secret KeyVault must be provided with --pull-secret-kv for registry mode", + }, + { + name: "missing pull secret name", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + SourceRegistry: "source.azurecr.io", + Digest: "sha256:abc123", + PullSecretKV: "mykeyvault", + }, + wantErr: "the pull secret name must be provided with --pull-secret for registry mode", + }, + { + name: "valid registry mode options", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + SourceRegistry: "source.azurecr.io", + Digest: "sha256:abc123", + PullSecretKV: "mykeyvault", + PullSecretName: "mysecret", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.opts.Validate() + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error %q, got nil", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("expected error %q, got %q", tt.wantErr, err.Error()) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + }) + } +} + +func TestRawOptions_Validate_OCILayoutMode(t *testing.T) { + tests := []struct { + name string + opts RawOptions + wantErr string + }{ + { + name: "missing image tar file name", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + CopyFrom: "oci-layout", + }, + wantErr: "the image tar file name must be provided with --image-tar-file for oci-layout mode", + }, + { + name: "missing image metadata file name", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + CopyFrom: "oci-layout", + ImageTarFileName: "image.tar", + }, + wantErr: "the image metadata file name must be provided with --image-metadata-file for oci-layout mode", + }, + { + name: "valid oci-layout mode options", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + CopyFrom: "oci-layout", + ImageTarFileName: "image.tar", + ImageMetadataFileName: "metadata.json", + }, + }, + { + name: "valid oci-layout mode with SAS URLs", + opts: RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + CopyFrom: "oci-layout", + ImageTarFileName: "image.tar", + ImageMetadataFileName: "metadata.json", + ImageTarSAS: "https://storage.blob.core.windows.net/container/image.tar?sig=abc", + ImageMetadataSAS: "https://storage.blob.core.windows.net/container/metadata.json?sig=abc", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.opts.Validate() + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error %q, got nil", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("expected error %q, got %q", tt.wantErr, err.Error()) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + }) + } +} + +func TestValidatedOptions_Complete(t *testing.T) { + raw := &RawOptions{ + TargetACR: "myacr", + Repository: "myrepo", + SourceRegistry: "source.azurecr.io", + Digest: "sha256:abc123", + PullSecretKV: "mykeyvault", + PullSecretName: "mysecret", + DryRun: true, + } + validated, err := raw.Validate() + if err != nil { + t.Fatalf("unexpected validation error: %v", err) + } + completed, err := validated.Complete() + if err != nil { + t.Fatalf("unexpected completion error: %v", err) + } + if completed.TargetACR != "myacr" { + t.Errorf("expected TargetACR 'myacr', got %q", completed.TargetACR) + } + if completed.Repository != "myrepo" { + t.Errorf("expected Repository 'myrepo', got %q", completed.Repository) + } + if completed.DryRun != true { + t.Errorf("expected DryRun true, got false") + } +} diff --git a/tools/image-mirror/sync/runner.go b/tools/image-mirror/sync/runner.go new file mode 100644 index 00000000..508ae00c --- /dev/null +++ b/tools/image-mirror/sync/runner.go @@ -0,0 +1,340 @@ +package sync + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/go-logr/logr" +) + +const ( + maxRetries = 5 + orasUsername = "00000000-0000-0000-0000-000000000000" + copyFromOCI = "oci-layout" +) + +// Runner executes the image mirror sync operation. +type Runner struct { + opts *Options + logger logr.Logger +} + +// Run executes the image mirror sync operation. +func (r *Runner) Run(ctx context.Context) error { + if r.opts.CopyFrom == copyFromOCI { + return r.copyImageFromOCILayout(ctx) + } + return r.copyImageFromRegistry(ctx) +} + +func (r *Runner) copyImageFromRegistry(ctx context.Context) error { + // Check if source and target are the same + acrDomainSuffix, err := r.getACRDomainSuffix(ctx) + if err != nil { + return fmt.Errorf("failed to get ACR domain suffix: %w", err) + } + if r.opts.SourceRegistry == r.opts.TargetACR+acrDomainSuffix { + r.logger.Info("Source and target registry are the same. No mirroring needed.") + return nil + } + + // Create temp directory for auth config + tmpDir, err := os.MkdirTemp("", "image-mirror-*") + if err != nil { + return fmt.Errorf("failed to create temp directory: %w", err) + } + defer func() { _ = os.RemoveAll(tmpDir) }() + + containersDir := filepath.Join(tmpDir, "containers") + if err := os.MkdirAll(containersDir, 0o700); err != nil { + return fmt.Errorf("failed to create containers directory: %w", err) + } + authJSON := filepath.Join(containersDir, "auth.json") + + // Fetch pull secret from KeyVault + r.logger.Info("Fetching pull secret from KeyVault", "vault", r.opts.PullSecretKV, "secret", r.opts.PullSecretName) + if err := r.runCommand(ctx, "az", "keyvault", "secret", "download", + "--vault-name", r.opts.PullSecretKV, + "--name", r.opts.PullSecretName, + "-e", "base64", + "--file", authJSON, + ); err != nil { + return fmt.Errorf("failed to download pull secret: %w", err) + } + + // ACR login to target registry + r.logger.Info("Logging into target ACR", "acr", r.opts.TargetACR) + loginServer, accessToken, err := r.acrLogin(ctx) + if err != nil { + return fmt.Errorf("failed to login to ACR: %w", err) + } + + // oras login to target + if err := r.orasLogin(ctx, loginServer, accessToken, authJSON); err != nil { + return fmt.Errorf("failed to oras login to target ACR: %w", err) + } + + if r.opts.DryRun { + r.logger.Info("DRY_RUN is enabled. Exiting without making changes.") + return nil + } + + // Mirror image + digestNoPrefix := strings.TrimPrefix(r.opts.Digest, "sha256:") + srcImage := fmt.Sprintf("%s/%s@%s", r.opts.SourceRegistry, r.opts.Repository, r.opts.Digest) + targetImage := fmt.Sprintf("%s/%s:%s", loginServer, r.opts.Repository, digestNoPrefix) + r.logger.Info("Mirroring image", "src", srcImage, "target", targetImage) + r.logger.Info("The image will still be available under its original digest in the target registry", "digest", r.opts.Digest) + + return r.runCommand(ctx, "oras", "cp", srcImage, targetImage, + "--from-registry-config", authJSON, + "--to-registry-config", authJSON, + ) +} + +func (r *Runner) copyImageFromOCILayout(ctx context.Context) error { + imageTarFile, imageMetadataFile, err := r.resolveImageFiles(ctx) + if err != nil { + return err + } + + // Read build_tag from metadata + buildTag, err := r.readBuildTag(imageMetadataFile) + if err != nil { + return err + } + r.logger.Info("Resolved build tag", "buildTag", buildTag) + + // Get ACR login server + acrDomainSuffix, err := r.getACRDomainSuffix(ctx) + if err != nil { + return fmt.Errorf("failed to get ACR domain suffix: %w", err) + } + targetACRLoginServer := r.opts.TargetACR + acrDomainSuffix + + // ACR login + r.logger.Info("Logging into target ACR", "acr", r.opts.TargetACR) + _, accessToken, err := r.acrLogin(ctx) + if err != nil { + return fmt.Errorf("failed to login to ACR: %w", err) + } + + // oras login + if err := r.orasLogin(ctx, targetACRLoginServer, accessToken, ""); err != nil { + return fmt.Errorf("failed to oras login to target ACR: %w", err) + } + + if r.opts.DryRun { + r.logger.Info("DRY_RUN is enabled. Exiting without making changes.") + return nil + } + + // Copy from OCI layout + targetImage := fmt.Sprintf("%s/%s:%s", targetACRLoginServer, r.opts.Repository, buildTag) + r.logger.Info("Copying image from OCI layout", "source", imageTarFile, "target", targetImage) + return r.runCommand(ctx, "oras", "cp", + "--from-oci-layout", fmt.Sprintf("%s:%s", imageTarFile, buildTag), + targetImage, + ) +} + +// resolveImageFiles locates or downloads image tar and metadata files. +func (r *Runner) resolveImageFiles(ctx context.Context) (imageTarFile, imageMetadataFile string, err error) { + imageFilePath := r.opts.ImageFilePath + if imageFilePath == "" { + imageFilePath, err = os.Getwd() + if err != nil { + return "", "", fmt.Errorf("failed to get working directory: %w", err) + } + } + + // If SAS URLs are provided, download files + if r.opts.ImageTarSAS != "" { + r.logger.Info("Downloading image tar from SAS URL") + imageTarFile = filepath.Join(imageFilePath, r.opts.ImageTarFileName) + if err := r.downloadFromSAS(ctx, r.opts.ImageTarSAS, imageTarFile); err != nil { + return "", "", fmt.Errorf("failed to download image tar: %w", err) + } + } else { + imageTarFile = filepath.Join(imageFilePath, r.opts.ImageTarFileName) + } + + if r.opts.ImageMetadataSAS != "" { + r.logger.Info("Downloading image metadata from SAS URL") + imageMetadataFile = filepath.Join(imageFilePath, r.opts.ImageMetadataFileName) + if err := r.downloadFromSAS(ctx, r.opts.ImageMetadataSAS, imageMetadataFile); err != nil { + return "", "", fmt.Errorf("failed to download image metadata: %w", err) + } + } else { + imageMetadataFile = filepath.Join(imageFilePath, r.opts.ImageMetadataFileName) + } + + // Validate files exist + if _, err := os.Stat(imageTarFile); err != nil { + return "", "", fmt.Errorf("image tar file %s does not exist at path %s: %w", r.opts.ImageTarFileName, imageFilePath, err) + } + if _, err := os.Stat(imageMetadataFile); err != nil { + return "", "", fmt.Errorf("image metadata file %s does not exist at path %s: %w", r.opts.ImageMetadataFileName, imageFilePath, err) + } + + return imageTarFile, imageMetadataFile, nil +} + +// downloadFromSAS downloads a file from a SAS URL with retry. +func (r *Runner) downloadFromSAS(ctx context.Context, sasURL, destPath string) error { + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("failed to create directory for %s: %w", destPath, err) + } + + return retry(ctx, maxRetries, func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sasURL, nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to download from SAS URL: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code %d downloading from SAS URL", resp.StatusCode) + } + + f, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create destination file %s: %w", destPath, err) + } + defer func() { _ = f.Close() }() + + if _, err := io.Copy(f, resp.Body); err != nil { + return fmt.Errorf("failed to write file %s: %w", destPath, err) + } + + return nil + }) +} + +// readBuildTag reads the build_tag field from the image metadata JSON file. +func (r *Runner) readBuildTag(metadataFile string) (string, error) { + data, err := os.ReadFile(metadataFile) + if err != nil { + return "", fmt.Errorf("failed to read metadata file: %w", err) + } + var metadata struct { + BuildTag string `json:"build_tag"` + } + if err := json.Unmarshal(data, &metadata); err != nil { + return "", fmt.Errorf("failed to parse metadata file: %w", err) + } + if metadata.BuildTag == "" { + return "", fmt.Errorf("build_tag not found in %s", metadataFile) + } + return metadata.BuildTag, nil +} + +// acrLogin performs az acr login with retry, returning the login server and access token. +func (r *Runner) acrLogin(ctx context.Context) (loginServer, accessToken string, err error) { + var output []byte + err = retry(ctx, maxRetries, func() error { + cmd := exec.CommandContext(ctx, "az", "acr", "login", + "--name", r.opts.TargetACR, + "--expose-token", + "--only-show-errors", + "--output", "json", + ) + var execErr error + output, execErr = cmd.Output() + if execErr != nil { + var exitErr *exec.ExitError + if errors.As(execErr, &exitErr) { + return fmt.Errorf("az acr login failed: %s", string(exitErr.Stderr)) + } + return fmt.Errorf("az acr login failed: %w", execErr) + } + return nil + }) + if err != nil { + return "", "", err + } + + var response struct { + LoginServer string `json:"loginServer"` + AccessToken string `json:"accessToken"` + } + if err := json.Unmarshal(output, &response); err != nil { + return "", "", fmt.Errorf("failed to parse ACR login response: %w", err) + } + + return response.LoginServer, response.AccessToken, nil +} + +// orasLogin performs oras login to the target registry. +func (r *Runner) orasLogin(ctx context.Context, loginServer, accessToken, registryConfig string) error { + args := []string{"login", loginServer, + "--username", orasUsername, + "--password-stdin", + } + if registryConfig != "" { + args = append(args, "--registry-config", registryConfig) + } + + cmd := exec.CommandContext(ctx, "oras", args...) + cmd.Stdin = strings.NewReader(accessToken) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + +// getACRDomainSuffix returns the ACR domain suffix for the current cloud. +func (r *Runner) getACRDomainSuffix(ctx context.Context) (string, error) { + cmd := exec.CommandContext(ctx, "az", "cloud", "show", + "--query", "suffixes.acrLoginServerEndpoint", + "--output", "tsv", + ) + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("failed to get ACR domain suffix: %w", err) + } + return strings.TrimSpace(string(output)), nil +} + +// runCommand runs an external command, streaming stdout/stderr. +func (r *Runner) runCommand(ctx context.Context, name string, args ...string) error { + r.logger.V(1).Info("Running command", "cmd", name, "args", args) + cmd := exec.CommandContext(ctx, name, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + +// retry executes fn with exponential backoff. +func retry(ctx context.Context, maxAttempts int, fn func() error) error { + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + lastErr = fn() + if lastErr == nil { + return nil + } + if attempt < maxAttempts { + delay := time.Duration(math.Pow(2, float64(attempt))) * time.Second + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + } + } + } + return fmt.Errorf("command failed after %d attempts: %w", maxAttempts, lastErr) +} diff --git a/tools/image-mirror/sync/runner_test.go b/tools/image-mirror/sync/runner_test.go new file mode 100644 index 00000000..3efea078 --- /dev/null +++ b/tools/image-mirror/sync/runner_test.go @@ -0,0 +1,336 @@ +package sync + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/go-logr/logr/funcr" +) + +func testLogger() logr.Logger { + return funcr.New(func(prefix, args string) { + // discard for tests + }, funcr.Options{}) +} + +func TestReadBuildTag(t *testing.T) { + tests := []struct { + name string + content string + wantTag string + wantErr bool + }{ + { + name: "valid metadata", + content: `{"build_tag": "v1.0.0-abc123"}`, + wantTag: "v1.0.0-abc123", + }, + { + name: "missing build_tag", + content: `{"other": "value"}`, + wantErr: true, + }, + { + name: "empty build_tag", + content: `{"build_tag": ""}`, + wantErr: true, + }, + { + name: "invalid JSON", + content: `not json`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + metadataFile := filepath.Join(tmpDir, "metadata.json") + if err := os.WriteFile(metadataFile, []byte(tt.content), 0o644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + r := &Runner{logger: testLogger()} + tag, err := r.readBuildTag(metadataFile) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tag != tt.wantTag { + t.Errorf("expected tag %q, got %q", tt.wantTag, tag) + } + }) + } +} + +func TestReadBuildTag_FileNotFound(t *testing.T) { + r := &Runner{logger: testLogger()} + _, err := r.readBuildTag("/nonexistent/path/metadata.json") + if err == nil { + t.Fatal("expected error for nonexistent file, got nil") + } +} + +func TestDownloadFromSAS(t *testing.T) { + expectedContent := "test file content for image tar" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, expectedContent) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "subdir", "downloaded.tar") + + r := &Runner{ + opts: &Options{completedOptions: &completedOptions{}}, + logger: testLogger(), + } + + ctx := context.Background() + err := r.downloadFromSAS(ctx, server.URL, destPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + if string(data) != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, string(data)) + } +} + +func TestDownloadFromSAS_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.tar") + + r := &Runner{ + opts: &Options{completedOptions: &completedOptions{}}, + logger: testLogger(), + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := r.downloadFromSAS(ctx, server.URL, destPath) + if err == nil { + t.Fatal("expected error for HTTP 403, got nil") + } +} + +func TestDownloadFromSAS_ContextCancel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Delay so the context can cancel first + time.Sleep(2 * time.Second) + _, _ = fmt.Fprint(w, "should not complete") + })) + defer server.Close() + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.tar") + + r := &Runner{ + opts: &Options{completedOptions: &completedOptions{}}, + logger: testLogger(), + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + err := r.downloadFromSAS(ctx, server.URL, destPath) + if err == nil { + t.Fatal("expected error for cancelled context, got nil") + } +} + +func TestResolveImageFiles_LocalFiles(t *testing.T) { + tmpDir := t.TempDir() + + tarFile := filepath.Join(tmpDir, "image.tar") + metadataFile := filepath.Join(tmpDir, "metadata.json") + if err := os.WriteFile(tarFile, []byte("tar content"), 0o644); err != nil { + t.Fatalf("failed to write tar file: %v", err) + } + if err := os.WriteFile(metadataFile, []byte(`{"build_tag":"v1"}`), 0o644); err != nil { + t.Fatalf("failed to write metadata file: %v", err) + } + + r := &Runner{ + opts: &Options{completedOptions: &completedOptions{ + ImageFilePath: tmpDir, + ImageTarFileName: "image.tar", + ImageMetadataFileName: "metadata.json", + }}, + logger: testLogger(), + } + + ctx := context.Background() + gotTar, gotMeta, err := r.resolveImageFiles(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotTar != tarFile { + t.Errorf("expected tar file %q, got %q", tarFile, gotTar) + } + if gotMeta != metadataFile { + t.Errorf("expected metadata file %q, got %q", metadataFile, gotMeta) + } +} + +func TestResolveImageFiles_MissingTarFile(t *testing.T) { + tmpDir := t.TempDir() + + metadataFile := filepath.Join(tmpDir, "metadata.json") + if err := os.WriteFile(metadataFile, []byte(`{"build_tag":"v1"}`), 0o644); err != nil { + t.Fatalf("failed to write metadata file: %v", err) + } + + r := &Runner{ + opts: &Options{completedOptions: &completedOptions{ + ImageFilePath: tmpDir, + ImageTarFileName: "image.tar", + ImageMetadataFileName: "metadata.json", + }}, + logger: testLogger(), + } + + ctx := context.Background() + _, _, err := r.resolveImageFiles(ctx) + if err == nil { + t.Fatal("expected error for missing tar file, got nil") + } +} + +func TestResolveImageFiles_SASDownload(t *testing.T) { + tarContent := "fake tar content" + metadataContent, _ := json.Marshal(map[string]string{"build_tag": "v1.0.0"}) + + mux := http.NewServeMux() + mux.HandleFunc("/image.tar", func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, tarContent) + }) + mux.HandleFunc("/metadata.json", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(metadataContent) + }) + server := httptest.NewServer(mux) + defer server.Close() + + tmpDir := t.TempDir() + + r := &Runner{ + opts: &Options{completedOptions: &completedOptions{ + ImageFilePath: tmpDir, + ImageTarFileName: "image.tar", + ImageMetadataFileName: "metadata.json", + ImageTarSAS: server.URL + "/image.tar", + ImageMetadataSAS: server.URL + "/metadata.json", + }}, + logger: testLogger(), + } + + ctx := context.Background() + gotTar, gotMeta, err := r.resolveImageFiles(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify files were downloaded + data, err := os.ReadFile(gotTar) + if err != nil { + t.Fatalf("failed to read downloaded tar: %v", err) + } + if string(data) != tarContent { + t.Errorf("expected tar content %q, got %q", tarContent, string(data)) + } + + data, err = os.ReadFile(gotMeta) + if err != nil { + t.Fatalf("failed to read downloaded metadata: %v", err) + } + if string(data) != string(metadataContent) { + t.Errorf("expected metadata content %q, got %q", string(metadataContent), string(data)) + } +} + +func TestRetry_Success(t *testing.T) { + attempts := 0 + err := retry(context.Background(), 3, func() error { + attempts++ + if attempts < 3 { + return fmt.Errorf("not yet") + } + return nil + }) + if err != nil { + t.Fatalf("expected success after retries, got: %v", err) + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestRetry_AllFail(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + attempts := 0 + err := retry(ctx, 3, func() error { + attempts++ + return fmt.Errorf("always fail") + }) + if err == nil { + t.Fatal("expected error after all retries, got nil") + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestRetry_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + attempts := 0 + err := retry(ctx, 5, func() error { + attempts++ + return fmt.Errorf("fail") + }) + if err == nil { + t.Fatal("expected error for cancelled context, got nil") + } + // Should have done 1 attempt, then context was cancelled before the sleep + if attempts != 1 { + t.Errorf("expected 1 attempt before context cancel, got %d", attempts) + } +} + +func TestRetry_ImmediateSuccess(t *testing.T) { + attempts := 0 + err := retry(context.Background(), 5, func() error { + attempts++ + return nil + }) + if err != nil { + t.Fatalf("expected immediate success, got: %v", err) + } + if attempts != 1 { + t.Errorf("expected 1 attempt, got %d", attempts) + } +}