diff --git a/cmd/signup.go b/cmd/signup.go index 916f663..642d2c1 100644 --- a/cmd/signup.go +++ b/cmd/signup.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "net/mail" "os" "strings" "time" @@ -11,27 +12,51 @@ import ( ) var ( - signupGit bool - signupOrg string + signupGit bool + signupOrg string + signupEmail string + signupSSHKey string ) // signupWithGit authenticates using the Shell Auth flow with git email + SSH key. func signupWithGit() error { - // Step 1: Get git email - fmt.Print("Looking up git email... ") - email, err := auth.GetGitEmail() - if err != nil { - fmt.Println("✗") - return err + // Step 1: Get email + var email string + if signupEmail != "" { + // Validate the provided email + if _, err := mail.ParseAddress(signupEmail); err != nil { + return fmt.Errorf("invalid email address %q: %w", signupEmail, err) + } + email = signupEmail + fmt.Printf("Using email: %s\n", email) + } else { + fmt.Print("Looking up git email... ") + var err error + email, err = auth.GetGitEmail() + if err != nil { + fmt.Println("✗") + return err + } + fmt.Println(email) } - fmt.Println(email) - // Step 2: Find SSH public key - fmt.Print("Looking up SSH public key... ") - sshPubKey, err := auth.FindSSHPublicKey() - if err != nil { - fmt.Println("✗") - return err + // Step 2: Get SSH public key + var sshPubKey string + if signupSSHKey != "" { + // Read and validate the provided SSH key file + pubKey, err := auth.ReadAndValidateSSHPublicKey(signupSSHKey) + if err != nil { + return err + } + sshPubKey = pubKey + } else { + fmt.Print("Looking up SSH public key... ") + var err error + sshPubKey, err = auth.FindSSHPublicKey() + if err != nil { + fmt.Println("✗") + return err + } } // Show truncated key for confirmation keyParts := strings.Fields(sshPubKey) @@ -40,7 +65,7 @@ func signupWithGit() error { if len(keyPreview) > 16 { keyPreview = keyPreview[:8] + "..." + keyPreview[len(keyPreview)-8:] } - fmt.Printf("%s %s\n", keyType, keyPreview) + fmt.Printf("SSH key: %s %s\n", keyType, keyPreview) // Step 3: Initiate shell auth fmt.Println("\nInitiating authentication...") @@ -147,9 +172,11 @@ var signupCmd = &cobra.Command{ By default, signup uses your git email and SSH public key to create an account. A verification email is sent — click the link and you're in. - vers signup Sign up with git email + SSH key (default) - vers signup --org myorg Pick org non-interactively (for scripts/agents) - vers signup --git=false Prompt for an API key instead + vers signup Auto-detect git email + SSH key + vers signup --email me@co.com Use a specific email + vers signup --ssh-key ~/.ssh/id_rsa.pub Use a specific SSH public key + vers signup --org myorg Pick org non-interactively + vers signup --git=false Prompt for an API key instead If you already have an account, this will log you in.`, RunE: func(cmd *cobra.Command, args []string) error { @@ -183,4 +210,6 @@ func init() { rootCmd.AddCommand(signupCmd) signupCmd.Flags().BoolVar(&signupGit, "git", true, "Authenticate using your git email and SSH key (default: true)") signupCmd.Flags().StringVar(&signupOrg, "org", "", "Organization name (skips interactive selection)") + signupCmd.Flags().StringVar(&signupEmail, "email", "", "Email address (default: git config user.email)") + signupCmd.Flags().StringVar(&signupSSHKey, "ssh-key", "", "Path to SSH public key file (default: auto-detect)") } diff --git a/cmd/signup_test.go b/cmd/signup_test.go index a783660..1dd0a0b 100644 --- a/cmd/signup_test.go +++ b/cmd/signup_test.go @@ -73,3 +73,35 @@ func TestSignupNoUnexpectedFlags(t *testing.T) { t.Error("signup should not have a --token flag (that's login's job)") } } + +// TestSignupEmailFlag verifies --email flag exists with empty default. +func TestSignupEmailFlag(t *testing.T) { + cmd, _, err := rootCmd.Find([]string{"signup"}) + if err != nil { + t.Fatalf("Find(signup) returned error: %v", err) + } + + flag := cmd.Flags().Lookup("email") + if flag == nil { + t.Fatal("signup command has no --email flag") + } + if flag.DefValue != "" { + t.Errorf("expected --email default value %q, got %q", "", flag.DefValue) + } +} + +// TestSignupSSHKeyFlag verifies --ssh-key flag exists with empty default. +func TestSignupSSHKeyFlag(t *testing.T) { + cmd, _, err := rootCmd.Find([]string{"signup"}) + if err != nil { + t.Fatalf("Find(signup) returned error: %v", err) + } + + flag := cmd.Flags().Lookup("ssh-key") + if flag == nil { + t.Fatal("signup command has no --ssh-key flag") + } + if flag.DefValue != "" { + t.Errorf("expected --ssh-key default value %q, got %q", "", flag.DefValue) + } +} diff --git a/internal/auth/shellauth.go b/internal/auth/shellauth.go index 1484d67..0209f6e 100644 --- a/internal/auth/shellauth.go +++ b/internal/auth/shellauth.go @@ -119,6 +119,54 @@ func FindSSHPublicKey() (string, error) { return "", fmt.Errorf("no SSH public key found — checked: %s\nGenerate one with: ssh-keygen -t ed25519", strings.Join(candidates, ", ")) } +// validSSHKeyTypes are the accepted SSH public key type prefixes. +var validSSHKeyTypes = map[string]bool{ + "ssh-ed25519": true, + "ssh-rsa": true, + "ecdsa-sha2-nistp256": true, + "ecdsa-sha2-nistp384": true, + "ecdsa-sha2-nistp521": true, + "sk-ssh-ed25519@openssh.com": true, + "sk-ecdsa-sha2-nistp256@openssh.com": true, +} + +// ReadAndValidateSSHPublicKey reads an SSH public key from a file path and validates it. +// Returns the key contents on success. +func ReadAndValidateSSHPublicKey(path string) (string, error) { + info, err := os.Stat(path) + if err != nil { + return "", fmt.Errorf("SSH key file not found: %s", path) + } + if info.IsDir() { + return "", fmt.Errorf("SSH key path is a directory, not a file: %s", path) + } + if info.Size() > 16*1024 { + return "", fmt.Errorf("SSH key file too large (%d bytes) — expected a public key", info.Size()) + } + + data, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read SSH key file: %w", err) + } + + key := strings.TrimSpace(string(data)) + if key == "" { + return "", fmt.Errorf("SSH key file is empty: %s", path) + } + + // Validate format: should be " [comment]" + parts := strings.Fields(key) + if len(parts) < 2 { + return "", fmt.Errorf("invalid SSH public key format in %s — expected \" [comment]\"", path) + } + + if !validSSHKeyTypes[parts[0]] { + return "", fmt.Errorf("unrecognized SSH key type %q in %s — expected one of: ssh-ed25519, ssh-rsa, ecdsa-sha2-*", parts[0], path) + } + + return key, nil +} + // shellAuthBaseURL returns the base URL for shell auth endpoints. func shellAuthBaseURL() (string, error) { versURL, err := GetVersUrl() diff --git a/internal/auth/shellauth_test.go b/internal/auth/shellauth_test.go new file mode 100644 index 0000000..3f6dfb3 --- /dev/null +++ b/internal/auth/shellauth_test.go @@ -0,0 +1,146 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" +) + +func TestReadAndValidateSSHPublicKey_Valid(t *testing.T) { + tests := []struct { + name string + content string + }{ + {"ed25519", "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIExampleKeyDataHere user@host"}, + {"rsa", "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQExample user@host"}, + {"ecdsa", "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTY= user@host"}, + {"no comment", "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIExampleKeyDataHere"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := filepath.Join(t.TempDir(), "id_test.pub") + os.WriteFile(path, []byte(tt.content), 0644) + + key, err := ReadAndValidateSSHPublicKey(path) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if key != tt.content { + t.Errorf("expected key %q, got %q", tt.content, key) + } + }) + } +} + +func TestReadAndValidateSSHPublicKey_Invalid(t *testing.T) { + tests := []struct { + name string + setup func(dir string) string // returns path + wantErr string + }{ + { + name: "file not found", + setup: func(dir string) string { + return filepath.Join(dir, "nonexistent") + }, + wantErr: "not found", + }, + { + name: "is a directory", + setup: func(dir string) string { + p := filepath.Join(dir, "subdir") + os.Mkdir(p, 0755) + return p + }, + wantErr: "directory", + }, + { + name: "empty file", + setup: func(dir string) string { + p := filepath.Join(dir, "empty.pub") + os.WriteFile(p, []byte(""), 0644) + return p + }, + wantErr: "empty", + }, + { + name: "whitespace only", + setup: func(dir string) string { + p := filepath.Join(dir, "blank.pub") + os.WriteFile(p, []byte(" \n \n"), 0644) + return p + }, + wantErr: "empty", + }, + { + name: "single field no base64", + setup: func(dir string) string { + p := filepath.Join(dir, "bad.pub") + os.WriteFile(p, []byte("ssh-ed25519"), 0644) + return p + }, + wantErr: "invalid SSH public key format", + }, + { + name: "unknown key type", + setup: func(dir string) string { + p := filepath.Join(dir, "bad.pub") + os.WriteFile(p, []byte("ssh-dsa AAAAB3NzaC1kc3MAAACB user@host"), 0644) + return p + }, + wantErr: "unrecognized SSH key type", + }, + { + name: "private key (too large isn't the check here, but wrong format)", + setup: func(dir string) string { + p := filepath.Join(dir, "id_rsa") + os.WriteFile(p, []byte("-----BEGIN OPENSSH PRIVATE KEY-----\ndata\n-----END OPENSSH PRIVATE KEY-----"), 0644) + return p + }, + wantErr: "unrecognized SSH key type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + path := tt.setup(dir) + + _, err := ReadAndValidateSSHPublicKey(path) + if err == nil { + t.Fatal("expected error, got nil") + } + if !contains(err.Error(), tt.wantErr) { + t.Errorf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestReadAndValidateSSHPublicKey_TooLarge(t *testing.T) { + path := filepath.Join(t.TempDir(), "big.pub") + // 17KB file + os.WriteFile(path, make([]byte, 17*1024), 0644) + + _, err := ReadAndValidateSSHPublicKey(path) + if err == nil { + t.Fatal("expected error for oversized file") + } + if !contains(err.Error(), "too large") { + t.Errorf("expected 'too large' error, got: %v", err) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && containsStr(s, substr) +} + +func containsStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +}