diff --git a/cmd/tlsc/ca.go b/cmd/tlsc/ca.go index d3f6351..167b3ce 100644 --- a/cmd/tlsc/ca.go +++ b/cmd/tlsc/ca.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -12,7 +13,7 @@ import ( "github.com/dyaa/tlsc/internal/services/ca" ) -func runCA(args []string) { +func runCA(ctx context.Context, args []string) { if len(args) == 0 { printCAUsage() os.Exit(1) @@ -20,17 +21,17 @@ func runCA(args []string) { switch args[0] { case "init": - runCAInit(args[1:]) + runCAInit(ctx, args[1:]) case "install": - runCAInstall(args[1:]) + runCAInstall(ctx, args[1:]) case "uninstall": runCAUninstall(args[1:]) case "info": - runCAInfo(args[1:]) + runCAInfo(ctx, args[1:]) case "renew": - runCARenew(args[1:]) + runCARenew(ctx, args[1:]) case "export": - runCAExport(args[1:]) + runCAExport(ctx, args[1:]) case "help", "-h", "--help": printCAUsage() default: @@ -38,7 +39,7 @@ func runCA(args []string) { } } -func runCAInit(args []string) { +func runCAInit(ctx context.Context, args []string) { fs := flag.NewFlagSet("ca init", flag.ExitOnError) var ( caPath string @@ -75,7 +76,7 @@ func runCAInit(args []string) { opts.Validity = validity opts.KeyType = keyType - authority, err := ca.Init(caPath, opts) + authority, err := ca.Init(ctx, caPath, opts) if err != nil { fatal("%s", err) } @@ -94,13 +95,13 @@ func runCAInit(args []string) { fmt.Printf("\n Next: run '\033[1mtlsc ca install\033[0m' to trust it system-wide\n\n") } -func runCAInstall(args []string) { +func runCAInstall(ctx context.Context, args []string) { fs := flag.NewFlagSet("ca install", flag.ExitOnError) var caPath string fs.StringVar(&caPath, "ca-path", "", "CA directory (default: ~/.tlsc/ca/)") fs.Parse(normArgs(args)) - authority, err := ca.Load(caPath) + authority, err := ca.Load(ctx, caPath) if err != nil { fatal("%s", err) } @@ -117,13 +118,13 @@ func runCAUninstall(_ []string) { fmt.Printf("\n \033[32m✓ CA removed from system trust store\033[0m\n\n") } -func runCAInfo(args []string) { +func runCAInfo(ctx context.Context, args []string) { fs := flag.NewFlagSet("ca info", flag.ExitOnError) var caPath string fs.StringVar(&caPath, "ca-path", "", "CA directory (default: ~/.tlsc/ca/)") fs.Parse(normArgs(args)) - authority, err := ca.Load(caPath) + authority, err := ca.Load(ctx, caPath) if err != nil { fatal("%s", err) } @@ -147,13 +148,13 @@ func runCAInfo(args []string) { fmt.Println() } -func runCARenew(args []string) { +func runCARenew(ctx context.Context, args []string) { fs := flag.NewFlagSet("ca renew", flag.ExitOnError) var caPath string fs.StringVar(&caPath, "ca-path", "", "CA directory (default: ~/.tlsc/ca/)") fs.Parse(normArgs(args)) - authority, err := ca.Renew(caPath) + authority, err := ca.Renew(ctx, caPath) if err != nil { fatal("%s", err) } @@ -164,13 +165,13 @@ func runCARenew(args []string) { fmt.Println() } -func runCAExport(args []string) { +func runCAExport(ctx context.Context, args []string) { fs := flag.NewFlagSet("ca export", flag.ExitOnError) var caPath string fs.StringVar(&caPath, "ca-path", "", "CA directory (default: ~/.tlsc/ca/)") fs.Parse(normArgs(args)) - authority, err := ca.Load(caPath) + authority, err := ca.Load(ctx, caPath) if err != nil { fatal("%s", err) } diff --git a/cmd/tlsc/check.go b/cmd/tlsc/check.go index f8affdd..2ab74ec 100644 --- a/cmd/tlsc/check.go +++ b/cmd/tlsc/check.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -10,7 +11,7 @@ import ( "github.com/dyaa/tlsc/internal/output" ) -func runCheck(args []string) { +func runCheck(ctx context.Context, args []string) { fs := flag.NewFlagSet("check", flag.ExitOnError) var ( @@ -70,7 +71,7 @@ func runCheck(args []string) { exitCode := 0 if len(hosts) == 1 { - result, err := svc.Check(hosts[0], opts) + result, err := svc.Check(ctx, hosts[0], opts) if err != nil { if jsonOut { output.JSONError(hosts[0], err) @@ -92,7 +93,7 @@ func runCheck(args []string) { } } } else { - results := svc.CheckBatch(hosts, opts) + results := svc.CheckBatch(ctx, hosts, opts) for _, host := range hosts { r := results[host] if r.Err != nil { diff --git a/cmd/tlsc/convert.go b/cmd/tlsc/convert.go index 178471f..110b37c 100644 --- a/cmd/tlsc/convert.go +++ b/cmd/tlsc/convert.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -8,7 +9,7 @@ import ( "github.com/dyaa/tlsc/internal/services/convert" ) -func runConvert(args []string) { +func runConvert(ctx context.Context, args []string) { fs := flag.NewFlagSet("convert", flag.ExitOnError) var ( @@ -34,7 +35,7 @@ func runConvert(args []string) { inputPath := positional[0] convSvc := convert.New() - if err := convSvc.Convert(inputPath, outPath, format); err != nil { + if err := convSvc.Convert(ctx, inputPath, outPath, format); err != nil { fatal("%s", err) } diff --git a/cmd/tlsc/csr.go b/cmd/tlsc/csr.go index 1a9e867..eee56c3 100644 --- a/cmd/tlsc/csr.go +++ b/cmd/tlsc/csr.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -10,7 +11,7 @@ import ( "github.com/dyaa/tlsc/internal/services/generate" ) -func runCSR(args []string) { +func runCSR(ctx context.Context, args []string) { if len(args) == 0 { printCSRUsage() os.Exit(1) @@ -18,9 +19,9 @@ func runCSR(args []string) { switch args[0] { case "generate": - runCSRGenerate(args[1:]) + runCSRGenerate(ctx, args[1:]) case "sign": - runCSRSign(args[1:]) + runCSRSign(ctx, args[1:]) case "help", "-h", "--help": printCSRUsage() default: @@ -28,7 +29,7 @@ func runCSR(args []string) { } } -func runCSRGenerate(args []string) { +func runCSRGenerate(ctx context.Context, args []string) { fs := flag.NewFlagSet("csr generate", flag.ExitOnError) var ( @@ -71,7 +72,7 @@ func runCSRGenerate(args []string) { opts.State = state csrSvc := generate.New() - result, err := csrSvc.GenerateCSR(hosts, opts) + result, err := csrSvc.GenerateCSR(ctx, hosts, opts) if err != nil { fatal("%s", err) } @@ -83,7 +84,7 @@ func runCSRGenerate(args []string) { fmt.Printf("\n Next: run '\033[1mtlsc csr sign %s\033[0m' to sign with your CA\n\n", result.CSRPath) } -func runCSRSign(args []string) { +func runCSRSign(ctx context.Context, args []string) { fs := flag.NewFlagSet("csr sign", flag.ExitOnError) var ( @@ -126,7 +127,7 @@ func runCSRSign(args []string) { } csrSvc := generate.New() - result, err := csrSvc.SignCSR(csrPath, opts) + result, err := csrSvc.SignCSR(ctx, csrPath, opts) if err != nil { fatal("%s", err) } diff --git a/cmd/tlsc/generate.go b/cmd/tlsc/generate.go index 4decd27..5bddd30 100644 --- a/cmd/tlsc/generate.go +++ b/cmd/tlsc/generate.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -12,7 +13,7 @@ import ( var genSvc = generate.New() -func runGenerate(args []string) { +func runGenerate(ctx context.Context, args []string) { fs := flag.NewFlagSet("generate", flag.ExitOnError) var ( @@ -78,12 +79,12 @@ func runGenerate(args []string) { opts.CAPath = caPath } - result, err := genSvc.Generate(hosts, opts) + result, err := genSvc.Generate(ctx, hosts, opts) if err != nil { fatal("%s", err) } - inspected, err := svc.InspectFile(result.CertPath) + inspected, err := svc.InspectFile(ctx, result.CertPath) if err != nil { fmt.Printf("\n \033[32m✓ Created certificate\033[0m\n\n") fmt.Printf(" cert: %s\n key: %s\n\n", result.CertPath, result.KeyPath) diff --git a/cmd/tlsc/inspect.go b/cmd/tlsc/inspect.go index 8c33b2b..835b795 100644 --- a/cmd/tlsc/inspect.go +++ b/cmd/tlsc/inspect.go @@ -1,16 +1,18 @@ package main import ( + "context" "crypto/x509" "encoding/pem" "flag" "fmt" "os" + "github.com/dyaa/tlsc/internal/fileutil" "github.com/dyaa/tlsc/internal/output" ) -func runInspect(args []string) { +func runInspect(ctx context.Context, args []string) { fs := flag.NewFlagSet("inspect", flag.ExitOnError) var ( @@ -32,7 +34,7 @@ func runInspect(args []string) { } path := positional[0] - result, err := svc.InspectFile(path) + result, err := svc.InspectFile(ctx, path) if err != nil { fatal("%s", err) } @@ -53,7 +55,7 @@ func runInspect(args []string) { } func validateChain(certPath, caPath string) (bool, string) { - certData, err := os.ReadFile(certPath) + certData, err := fileutil.ReadLimited(certPath, fileutil.MaxCertFileSize) if err != nil { return false, err.Error() } @@ -66,7 +68,7 @@ func validateChain(certPath, caPath string) (bool, string) { return false, err.Error() } - caData, err := os.ReadFile(caPath) + caData, err := fileutil.ReadLimited(caPath, fileutil.MaxCertFileSize) if err != nil { return false, err.Error() } diff --git a/cmd/tlsc/list.go b/cmd/tlsc/list.go index b8f0de9..06259eb 100644 --- a/cmd/tlsc/list.go +++ b/cmd/tlsc/list.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "flag" "fmt" @@ -11,7 +12,7 @@ import ( "github.com/dyaa/tlsc/internal/domain" ) -func runList(args []string) { +func runList(ctx context.Context, args []string) { fs := flag.NewFlagSet("list", flag.ExitOnError) var jsonOut bool @@ -26,7 +27,7 @@ func runList(args []string) { fs.Parse(normArgs(args)) - certs, err := svc.ListCerts(dir) + certs, err := svc.ListCerts(ctx, dir) if err != nil { fatal("%s", err) } diff --git a/cmd/tlsc/main.go b/cmd/tlsc/main.go index 5f51a88..d54c046 100644 --- a/cmd/tlsc/main.go +++ b/cmd/tlsc/main.go @@ -1,10 +1,13 @@ package main import ( + "context" "flag" "fmt" "os" + "os/signal" "strings" + "syscall" "github.com/dyaa/tlsc/internal/services/checker" ) @@ -20,29 +23,32 @@ func main() { os.Exit(1) } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + switch os.Args[1] { case "inspect": - runInspect(os.Args[2:]) + runInspect(ctx, os.Args[2:]) case "generate": - runGenerate(os.Args[2:]) + runGenerate(ctx, os.Args[2:]) case "ca": - runCA(os.Args[2:]) + runCA(ctx, os.Args[2:]) case "verify": - runVerify(os.Args[2:]) + runVerify(ctx, os.Args[2:]) case "csr": - runCSR(os.Args[2:]) + runCSR(ctx, os.Args[2:]) case "renew": - runRenew(os.Args[2:]) + runRenew(ctx, os.Args[2:]) case "convert": - runConvert(os.Args[2:]) + runConvert(ctx, os.Args[2:]) case "list": - runList(os.Args[2:]) + runList(ctx, os.Args[2:]) case "help", "-h", "--help": printUsage() case "version", "-v", "--version": fmt.Println("tlsc " + version) default: - runCheck(os.Args[1:]) + runCheck(ctx, os.Args[1:]) } } diff --git a/cmd/tlsc/renew.go b/cmd/tlsc/renew.go index fe16231..2695180 100644 --- a/cmd/tlsc/renew.go +++ b/cmd/tlsc/renew.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -10,7 +11,7 @@ import ( "github.com/dyaa/tlsc/internal/services/generate" ) -func runRenew(args []string) { +func runRenew(ctx context.Context, args []string) { fs := flag.NewFlagSet("renew", flag.ExitOnError) var ( @@ -44,7 +45,7 @@ func runRenew(args []string) { } renewSvc := generate.New() - result, err := renewSvc.Renew(certPath, opts) + result, err := renewSvc.Renew(ctx, certPath, opts) if err != nil { fatal("%s", err) } diff --git a/cmd/tlsc/verify.go b/cmd/tlsc/verify.go index 7c9bc92..6af5713 100644 --- a/cmd/tlsc/verify.go +++ b/cmd/tlsc/verify.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "flag" "fmt" @@ -9,7 +10,7 @@ import ( "github.com/dyaa/tlsc/internal/domain" ) -func runVerify(args []string) { +func runVerify(ctx context.Context, args []string) { fs := flag.NewFlagSet("verify", flag.ExitOnError) var ( @@ -39,7 +40,7 @@ func runVerify(args []string) { KeyPath: keyPath, } - result, err := svc.VerifyFile(certPath, opts) + result, err := svc.VerifyFile(ctx, certPath, opts) if err != nil { fatal("%s", err) } diff --git a/internal/adapters/http/hsts.go b/internal/adapters/http/hsts.go index 8b3b0b9..c954260 100644 --- a/internal/adapters/http/hsts.go +++ b/internal/adapters/http/hsts.go @@ -1,8 +1,10 @@ package http import ( + "context" "crypto/tls" "fmt" + "net" "net/http" "regexp" "strconv" @@ -14,7 +16,7 @@ import ( var maxAgeRe = regexp.MustCompile(`max-age=(\d+)`) -func FetchHSTS(host string, port int, timeout time.Duration, serverName string) *domain.HSTS { +func FetchHSTS(ctx context.Context, host string, port int, timeout time.Duration, serverName string) *domain.HSTS { client := &http.Client{ Timeout: timeout, Transport: &http.Transport{ @@ -28,8 +30,12 @@ func FetchHSTS(host string, port int, timeout time.Duration, serverName string) }, } - url := fmt.Sprintf("https://%s:%d", host, port) - resp, err := client.Head(url) + hstsURL := "https://" + net.JoinHostPort(host, fmt.Sprintf("%d", port)) + req, err := http.NewRequestWithContext(ctx, "HEAD", hstsURL, nil) + if err != nil { + return nil + } + resp, err := client.Do(req) if err != nil { return nil } diff --git a/internal/adapters/tls/connect.go b/internal/adapters/tls/connect.go index d4e6816..7b6656c 100644 --- a/internal/adapters/tls/connect.go +++ b/internal/adapters/tls/connect.go @@ -2,9 +2,12 @@ package tls import ( "bufio" + "context" "crypto/tls" + "encoding/asn1" "encoding/binary" "fmt" + "html" "io" "net" "time" @@ -12,7 +15,17 @@ import ( "github.com/dyaa/tlsc/internal/domain" ) -type StarttlsHandler func(conn net.Conn, scanner *bufio.Scanner) error +const ( + // maxMySQLPayload caps the MySQL greeting payload allocation to prevent + // a malicious server from causing excessive memory allocation. + maxMySQLPayload = 65536 + + // maxMultiLineLines caps the number of continuation lines read in + // multi-line protocol responses to prevent unbounded reads. + maxMultiLineLines = 100 +) + +type StarttlsHandler func(conn net.Conn, scanner *bufio.Scanner, serverName string) error var starttlsHandlers = map[domain.Protocol]StarttlsHandler{ domain.SMTP: smtpHandshake, @@ -26,30 +39,56 @@ var starttlsHandlers = map[domain.Protocol]StarttlsHandler{ domain.Sieve: sieveHandshake, } -func Dial(host string, port int, serverName string, timeout time.Duration) (*tls.Conn, error) { - dialer := &net.Dialer{Timeout: timeout} - return tls.DialWithDialer(dialer, "tcp", net.JoinHostPort(host, fmt.Sprintf("%d", port)), &tls.Config{ +// Dial connects to host:port with TLS. InsecureSkipVerify is enabled +// because this is an inspection tool — the caller performs certificate +// validation separately via the returned connection state. +func Dial(ctx context.Context, host string, port int, serverName string, timeout time.Duration) (*tls.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, &domain.ErrConnection{Host: host, Port: port, Err: err} + } + + tlsConn := tls.Client(conn, &tls.Config{ ServerName: serverName, InsecureSkipVerify: true, }) + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() + return nil, &domain.ErrTLSHandshake{Host: host, Err: err} + } + return tlsConn, nil } -func DialStartTLS(host string, port int, proto domain.Protocol, serverName string, timeout time.Duration) (*tls.Conn, error) { +// DialStartTLS performs a STARTTLS upgrade for the given protocol. +// InsecureSkipVerify is enabled because this is an inspection tool. +func DialStartTLS(ctx context.Context, host string, port int, proto domain.Protocol, serverName string, timeout time.Duration) (*tls.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - conn, err := net.DialTimeout("tcp", addr, timeout) + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { - return nil, err + return nil, &domain.ErrConnection{Host: host, Port: port, Err: err} + } + + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) } - conn.SetDeadline(time.Now().Add(timeout)) handler, ok := starttlsHandlers[proto] if !ok { conn.Close() - return nil, fmt.Errorf("unsupported STARTTLS protocol: %s", proto) + return nil, &domain.ErrUnsupportedProtocol{Protocol: proto} } scanner := bufio.NewScanner(conn) - if err := handler(conn, scanner); err != nil { + if err := handler(conn, scanner, serverName); err != nil { conn.Close() return nil, fmt.Errorf("%s STARTTLS handshake failed: %w", proto, err) } @@ -58,16 +97,16 @@ func DialStartTLS(host string, port int, proto domain.Protocol, serverName strin ServerName: serverName, InsecureSkipVerify: true, }) - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { tlsConn.Close() - return nil, err + return nil, &domain.ErrTLSHandshake{Host: host, Err: err} } conn.SetDeadline(time.Time{}) return tlsConn, nil } -func smtpHandshake(conn net.Conn, scanner *bufio.Scanner) error { +func smtpHandshake(conn net.Conn, scanner *bufio.Scanner, _ string) error { if err := readMultiLine(scanner); err != nil { return err } @@ -79,7 +118,7 @@ func smtpHandshake(conn net.Conn, scanner *bufio.Scanner) error { return readLine(scanner) } -func imapHandshake(conn net.Conn, scanner *bufio.Scanner) error { +func imapHandshake(conn net.Conn, scanner *bufio.Scanner, _ string) error { if err := readLine(scanner); err != nil { return err } @@ -87,7 +126,7 @@ func imapHandshake(conn net.Conn, scanner *bufio.Scanner) error { return readLine(scanner) } -func pop3Handshake(conn net.Conn, scanner *bufio.Scanner) error { +func pop3Handshake(conn net.Conn, scanner *bufio.Scanner, _ string) error { if err := readLine(scanner); err != nil { return err } @@ -95,7 +134,7 @@ func pop3Handshake(conn net.Conn, scanner *bufio.Scanner) error { return readLine(scanner) } -func ftpHandshake(conn net.Conn, scanner *bufio.Scanner) error { +func ftpHandshake(conn net.Conn, scanner *bufio.Scanner, _ string) error { if err := readLine(scanner); err != nil { return err } @@ -103,7 +142,7 @@ func ftpHandshake(conn net.Conn, scanner *bufio.Scanner) error { return readLine(scanner) } -func ldapHandshake(conn net.Conn, _ *bufio.Scanner) error { +func ldapHandshake(conn net.Conn, _ *bufio.Scanner, _ string) error { // BER-encoded ExtendedRequest for OID 1.3.6.1.4.1.1466.20037 oid := []byte{0x06, 0x17, 0x31, 0x2e, 0x33, 0x2e, 0x36, 0x2e, 0x31, 0x2e, 0x34, 0x2e, 0x31, 0x2e, 0x31, 0x34, 0x36, 0x36, 0x2e, 0x32, 0x30, @@ -119,19 +158,50 @@ func ldapHandshake(conn net.Conn, _ *bufio.Scanner) error { } buf := make([]byte, 1024) - if _, err := io.ReadAtLeast(conn, buf, 1); err != nil { + n, err := io.ReadAtLeast(conn, buf, 1) + if err != nil { return fmt.Errorf("failed to read LDAP STARTTLS response: %w", err) } + if err := validateLDAPResponse(buf[:n]); err != nil { + return fmt.Errorf("LDAP STARTTLS rejected: %w", err) + } + return nil } -func mysqlHandshake(conn net.Conn, _ *bufio.Scanner) error { +// validateLDAPResponse performs a minimal parse of an LDAP ExtendedResponse +// to verify that the resultCode is 0 (success). +func validateLDAPResponse(data []byte) error { + var msg struct { + MessageID int + Response asn1.RawValue + } + if _, err := asn1.Unmarshal(data, &msg); err != nil { + return fmt.Errorf("failed to parse LDAP response: %w", err) + } + + var resultCode asn1.Enumerated + if _, err := asn1.Unmarshal(msg.Response.Bytes, &resultCode); err != nil { + return fmt.Errorf("failed to parse result code: %w", err) + } + + if resultCode != 0 { + return fmt.Errorf("server returned result code %d", resultCode) + } + + return nil +} + +func mysqlHandshake(conn net.Conn, _ *bufio.Scanner, _ string) error { header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return fmt.Errorf("failed to read MySQL greeting header: %w", err) } payloadLen := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 + if payloadLen > maxMySQLPayload { + return fmt.Errorf("MySQL greeting payload too large: %d bytes (max %d)", payloadLen, maxMySQLPayload) + } greeting := make([]byte, payloadLen) if _, err := io.ReadFull(conn, greeting); err != nil { return fmt.Errorf("failed to read MySQL greeting: %w", err) @@ -155,7 +225,7 @@ func mysqlHandshake(conn net.Conn, _ *bufio.Scanner) error { return nil } -func postgresHandshake(conn net.Conn, _ *bufio.Scanner) error { +func postgresHandshake(conn net.Conn, _ *bufio.Scanner, _ string) error { msg := make([]byte, 8) binary.BigEndian.PutUint32(msg[0:4], 8) binary.BigEndian.PutUint32(msg[4:8], 80877103) @@ -165,7 +235,7 @@ func postgresHandshake(conn net.Conn, _ *bufio.Scanner) error { } resp := make([]byte, 1) - if _, err := conn.Read(resp); err != nil { + if _, err := io.ReadFull(conn, resp); err != nil { return fmt.Errorf("failed to read PostgreSQL SSL response: %w", err) } @@ -176,8 +246,9 @@ func postgresHandshake(conn net.Conn, _ *bufio.Scanner) error { return nil } -func xmppHandshake(conn net.Conn, scanner *bufio.Scanner) error { - fmt.Fprintf(conn, "") +func xmppHandshake(conn net.Conn, _ *bufio.Scanner, serverName string) error { + escapedHost := html.EscapeString(serverName) + fmt.Fprintf(conn, "", escapedHost) buf := make([]byte, 4096) if _, err := io.ReadAtLeast(conn, buf, 1); err != nil { @@ -194,15 +265,20 @@ func xmppHandshake(conn net.Conn, scanner *bufio.Scanner) error { return nil } -func sieveHandshake(conn net.Conn, scanner *bufio.Scanner) error { - for scanner.Scan() { +func sieveHandshake(conn net.Conn, scanner *bufio.Scanner, _ string) error { + found := false + for i := 0; i < maxMultiLineLines && scanner.Scan(); i++ { line := scanner.Text() if line == "OK" || (len(line) >= 2 && line[:2] == "OK") { + found = true break } } - if err := scanner.Err(); err != nil { - return err + if !found { + if err := scanner.Err(); err != nil { + return err + } + return fmt.Errorf("did not receive Sieve OK greeting") } fmt.Fprintf(conn, "STARTTLS\r\n") @@ -220,7 +296,13 @@ func readLine(scanner *bufio.Scanner) error { } func readMultiLine(scanner *bufio.Scanner) error { - for scanner.Scan() { + for i := 0; i < maxMultiLineLines; i++ { + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return err + } + return fmt.Errorf("connection closed during multi-line read") + } line := scanner.Text() if len(line) >= 4 && line[3] == ' ' { return nil @@ -230,8 +312,5 @@ func readMultiLine(scanner *bufio.Scanner) error { } return nil } - if err := scanner.Err(); err != nil { - return err - } - return fmt.Errorf("connection closed during multi-line read") + return fmt.Errorf("too many continuation lines (max %d)", maxMultiLineLines) } diff --git a/internal/adapters/tls/crl.go b/internal/adapters/tls/crl.go index 275e9dc..afb4a6c 100644 --- a/internal/adapters/tls/crl.go +++ b/internal/adapters/tls/crl.go @@ -1,6 +1,7 @@ package tls import ( + "context" "crypto/x509" "io" "net/http" @@ -9,7 +10,7 @@ import ( "github.com/dyaa/tlsc/internal/domain" ) -func CheckCRL(leaf *x509.Certificate, issuer *x509.Certificate, timeout time.Duration) *domain.CRLStatus { +func CheckCRL(ctx context.Context, leaf *x509.Certificate, issuer *x509.Certificate, timeout time.Duration) *domain.CRLStatus { if len(leaf.CRLDistributionPoints) == 0 { return nil } @@ -18,10 +19,13 @@ func CheckCRL(leaf *x509.Certificate, issuer *x509.Certificate, timeout time.Dur return &domain.CRLStatus{Error: "no issuer certificate for CRL verification"} } - client := &http.Client{Timeout: timeout} + client := safeHTTPClient(timeout) for _, url := range leaf.CRLDistributionPoints { - if status, ok := checkSingleCRL(client, url, leaf, issuer); ok { + if err := isSafeURL(url); err != nil { + continue + } + if status, ok := checkSingleCRL(ctx, client, url, leaf, issuer); ok { return status } } @@ -29,14 +33,18 @@ func CheckCRL(leaf *x509.Certificate, issuer *x509.Certificate, timeout time.Dur return &domain.CRLStatus{Error: "could not fetch or verify any CRL"} } -func checkSingleCRL(client *http.Client, url string, leaf, issuer *x509.Certificate) (*domain.CRLStatus, bool) { - resp, err := client.Get(url) +func checkSingleCRL(ctx context.Context, client *http.Client, url string, leaf, issuer *x509.Certificate) (*domain.CRLStatus, bool) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, false + } + resp, err := client.Do(req) if err != nil { return nil, false } defer resp.Body.Close() - body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) if err != nil { return nil, false } diff --git a/internal/adapters/tls/ocsp.go b/internal/adapters/tls/ocsp.go index cef840b..81d429a 100644 --- a/internal/adapters/tls/ocsp.go +++ b/internal/adapters/tls/ocsp.go @@ -2,6 +2,7 @@ package tls import ( "bytes" + "context" "crypto/x509" "io" "net/http" @@ -10,7 +11,7 @@ import ( "github.com/dyaa/tlsc/internal/domain" ) -func CheckOCSP(leaf *x509.Certificate, issuer *x509.Certificate, stapledResponse []byte, timeout time.Duration) *domain.OCSPStatus { +func CheckOCSP(ctx context.Context, leaf *x509.Certificate, issuer *x509.Certificate, stapledResponse []byte, timeout time.Duration) *domain.OCSPStatus { if len(stapledResponse) > 0 { status, err := parseOCSPResponse(stapledResponse) if err == nil { @@ -35,6 +36,14 @@ func CheckOCSP(leaf *x509.Certificate, issuer *x509.Certificate, stapledResponse } } + ocspURL := leaf.OCSPServer[0] + if err := isSafeURL(ocspURL); err != nil { + return &domain.OCSPStatus{ + Status: "unknown", + Error: "blocked OCSP URL: " + err.Error(), + } + } + ocspReq, err := createOCSPRequest(leaf, issuer) if err != nil { return &domain.OCSPStatus{ @@ -43,8 +52,16 @@ func CheckOCSP(leaf *x509.Certificate, issuer *x509.Certificate, stapledResponse } } - client := &http.Client{Timeout: timeout} - resp, err := client.Post(leaf.OCSPServer[0], "application/ocsp-request", bytes.NewReader(ocspReq)) + client := safeHTTPClient(timeout) + req, err := http.NewRequestWithContext(ctx, "POST", ocspURL, bytes.NewReader(ocspReq)) + if err != nil { + return &domain.OCSPStatus{ + Status: "unknown", + Error: "failed to create OCSP request: " + err.Error(), + } + } + req.Header.Set("Content-Type", "application/ocsp-request") + resp, err := client.Do(req) if err != nil { return &domain.OCSPStatus{ Status: "unknown", diff --git a/internal/adapters/tls/safenet.go b/internal/adapters/tls/safenet.go new file mode 100644 index 0000000..d6d26c1 --- /dev/null +++ b/internal/adapters/tls/safenet.go @@ -0,0 +1,105 @@ +package tls + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// safeHTTPClient returns an http.Client that refuses connections to +// private, loopback, and link-local IP addresses, mitigating SSRF +// via certificate-embedded URLs (OCSP responders, CRL endpoints). +// It limits redirect following to 3 hops and blocks non-HTTP schemes. +func safeHTTPClient(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + DialContext: safeDialContext, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 3 { + return fmt.Errorf("too many redirects") + } + if err := validateURLScheme(req.URL); err != nil { + return err + } + return nil + }, + } +} + +// validateURLScheme verifies a URL uses http or https. +func validateURLScheme(u *url.URL) error { + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("blocked non-HTTP scheme: %s", scheme) + } + return nil +} + +// isSafeURL validates that a raw URL string uses an allowed scheme. +func isSafeURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + return validateURLScheme(u) +} + +func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ips, err := net.DefaultResolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + + var dialer net.Dialer + var lastErr error + allBlocked := true + for _, ipStr := range ips { + ip := net.ParseIP(ipStr) + if ip != nil && isPrivateIP(ip) { + continue + } + allBlocked = false + conn, dialErr := dialer.DialContext(ctx, network, net.JoinHostPort(ipStr, port)) + if dialErr == nil { + return conn, nil + } + lastErr = dialErr + } + + if allBlocked { + return nil, fmt.Errorf("blocked connection to private/reserved address for %s", host) + } + return nil, lastErr +} + +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() { + return true + } + if ip4 := ip.To4(); ip4 != nil { + // 100.64.0.0/10 — Carrier-grade NAT / shared address space (RFC 6598) + if ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127 { + return true + } + } + // 6to4 (2002::/16) — embeds IPv4 address; 2002:c0a8:01:: = 192.168.1.0 + if len(ip) == net.IPv6len && ip[0] == 0x20 && ip[1] == 0x02 { + return true + } + // Teredo (2001:0000::/32) — tunneling protocol that can reach private IPv4 + if len(ip) == net.IPv6len && ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 { + return true + } + return false +} diff --git a/internal/adapters/tls/safenet_test.go b/internal/adapters/tls/safenet_test.go new file mode 100644 index 0000000..64b2b42 --- /dev/null +++ b/internal/adapters/tls/safenet_test.go @@ -0,0 +1,129 @@ +package tls + +import ( + "net" + "net/url" + "testing" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + blocked bool + }{ + // IPv4 loopback + {"IPv4 loopback", "127.0.0.1", true}, + + // IPv4 private ranges (RFC 1918) + {"IPv4 private 10.x", "10.0.0.1", true}, + {"IPv4 private 172.16.x", "172.16.0.1", true}, + {"IPv4 private 192.168.x", "192.168.1.1", true}, + + // IPv4 link-local + {"IPv4 link-local", "169.254.1.1", true}, + + // IPv4 unspecified + {"IPv4 unspecified", "0.0.0.0", true}, + + // IPv4 CGN / shared address space (RFC 6598: 100.64.0.0/10) + {"IPv4 CGN low", "100.64.0.1", true}, + {"IPv4 CGN high", "100.127.255.255", true}, + + // IPv4 public + {"IPv4 public Google DNS", "8.8.8.8", false}, + {"IPv4 public Cloudflare", "1.1.1.1", false}, + + // IPv6 loopback + {"IPv6 loopback", "::1", true}, + + // IPv6 private (unique local) + {"IPv6 unique local fd00::", "fd00::1", true}, + + // IPv6 link-local + {"IPv6 link-local", "fe80::1", true}, + + // IPv6 6to4 (2002::/16) + {"IPv6 6to4", "2002:c0a8:0001::", true}, + + // IPv6 Teredo (2001:0000::/32) + {"IPv6 Teredo", "2001:0000:4136:e378::", true}, + + // IPv6 multicast + {"IPv6 multicast", "ff02::1", true}, + + // IPv6 public + {"IPv6 public Google", "2607:f8b0:4004:800::200e", false}, + + // IPv4-mapped IPv6 + {"IPv4-mapped IPv6 loopback", "::ffff:127.0.0.1", true}, + {"IPv4-mapped IPv6 private", "::ffff:10.0.0.1", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("failed to parse IP %q", tt.ip) + } + got := isPrivateIP(ip) + if got != tt.blocked { + t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, got, tt.blocked) + } + }) + } +} + +func TestIsSafeURL(t *testing.T) { + tests := []struct { + name string + rawURL string + wantError bool + }{ + {"http scheme allowed", "http://example.com", false}, + {"https scheme allowed", "https://example.com", false}, + {"ftp scheme blocked", "ftp://example.com", true}, + {"gopher scheme blocked", "gopher://example.com", true}, + {"javascript scheme blocked", "javascript:alert(1)", true}, + {"empty string blocked", "", true}, + {"no scheme blocked", "//example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isSafeURL(tt.rawURL) + if (err != nil) != tt.wantError { + t.Errorf("isSafeURL(%q) error = %v, wantError %v", tt.rawURL, err, tt.wantError) + } + }) + } +} + +func TestValidateURLScheme(t *testing.T) { + tests := []struct { + name string + rawURL string + wantError bool + }{ + {"http allowed", "http://example.com", false}, + {"https allowed", "https://example.com", false}, + {"ftp blocked", "ftp://example.com", true}, + {"gopher blocked", "gopher://example.com", true}, + {"javascript blocked", "javascript:alert(1)", true}, + {"empty scheme blocked", "example.com", true}, + {"no scheme blocked", "//example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.rawURL) + if err != nil { + t.Fatalf("url.Parse(%q) failed: %v", tt.rawURL, err) + } + err = validateURLScheme(u) + if (err != nil) != tt.wantError { + t.Errorf("validateURLScheme(%q) error = %v, wantError %v", tt.rawURL, err, tt.wantError) + } + }) + } +} diff --git a/internal/adapters/truststore/truststore.go b/internal/adapters/truststore/truststore.go index e846db9..e2e63c3 100644 --- a/internal/adapters/truststore/truststore.go +++ b/internal/adapters/truststore/truststore.go @@ -5,6 +5,8 @@ import ( "os" "os/exec" "runtime" + + "github.com/dyaa/tlsc/internal/fileutil" ) func Install(certPath string) error { @@ -44,7 +46,7 @@ func installDarwin(certPath string) error { } func installLinux(certPath string) error { - data, err := os.ReadFile(certPath) + data, err := fileutil.ReadLimited(certPath, fileutil.MaxCertFileSize) if err != nil { return err } diff --git a/internal/domain/errors.go b/internal/domain/errors.go new file mode 100644 index 0000000..90501d8 --- /dev/null +++ b/internal/domain/errors.go @@ -0,0 +1,92 @@ +package domain + +import "fmt" + +// ErrConnection represents a failure to establish a network connection. +type ErrConnection struct { + Host string + Port int + Err error +} + +func (e *ErrConnection) Error() string { + return fmt.Sprintf("connection to %s:%d failed: %v", e.Host, e.Port, e.Err) +} + +func (e *ErrConnection) Unwrap() error { return e.Err } + +// ErrTLSHandshake represents a failure during TLS negotiation. +type ErrTLSHandshake struct { + Host string + Err error +} + +func (e *ErrTLSHandshake) Error() string { + return fmt.Sprintf("TLS handshake with %s failed: %v", e.Host, e.Err) +} + +func (e *ErrTLSHandshake) Unwrap() error { return e.Err } + +// ErrCertificateInvalid indicates the certificate failed validation. +type ErrCertificateInvalid struct { + Host string + Reason string +} + +func (e *ErrCertificateInvalid) Error() string { + if e.Host != "" { + return fmt.Sprintf("certificate for %s is invalid: %s", e.Host, e.Reason) + } + return fmt.Sprintf("certificate is invalid: %s", e.Reason) +} + +// ErrCANotFound indicates no CA was found at the expected path. +type ErrCANotFound struct { + Path string +} + +func (e *ErrCANotFound) Error() string { + return fmt.Sprintf("CA not found at %s (run 'tlsc ca init' first)", e.Path) +} + +// ErrCAExists indicates a CA already exists at the path. +type ErrCAExists struct { + Path string +} + +func (e *ErrCAExists) Error() string { + return fmt.Sprintf("CA already exists at %s (use 'ca info' to view)", e.Path) +} + +// ErrFileRead represents a failure to read a file. +type ErrFileRead struct { + Path string + Err error +} + +func (e *ErrFileRead) Error() string { + return fmt.Sprintf("failed to read %s: %v", e.Path, e.Err) +} + +func (e *ErrFileRead) Unwrap() error { return e.Err } + +// ErrInvalidPEM indicates the file is not valid PEM format. +type ErrInvalidPEM struct { + Path string +} + +func (e *ErrInvalidPEM) Error() string { + if e.Path != "" { + return fmt.Sprintf("not a valid PEM file: %s", e.Path) + } + return "not a valid PEM file" +} + +// ErrUnsupportedProtocol indicates the STARTTLS protocol is not supported. +type ErrUnsupportedProtocol struct { + Protocol Protocol +} + +func (e *ErrUnsupportedProtocol) Error() string { + return fmt.Sprintf("unsupported STARTTLS protocol: %s", e.Protocol) +} diff --git a/internal/domain/models.go b/internal/domain/models.go index ca4802f..a3b1a17 100644 --- a/internal/domain/models.go +++ b/internal/domain/models.go @@ -6,21 +6,33 @@ import ( "time" ) +// Protocol represents a connection protocol used for TLS certificate checks. type Protocol string const ( - HTTPS Protocol = "https" - SMTP Protocol = "smtp" - IMAP Protocol = "imap" - POP3 Protocol = "pop3" - FTP Protocol = "ftp" - LDAP Protocol = "ldap" - MySQL Protocol = "mysql" + // HTTPS is the HTTPS protocol (default port 443). + HTTPS Protocol = "https" + // SMTP is the SMTP protocol with STARTTLS (default port 587). + SMTP Protocol = "smtp" + // IMAP is the IMAP protocol with STARTTLS (default port 143). + IMAP Protocol = "imap" + // POP3 is the POP3 protocol with STARTTLS (default port 110). + POP3 Protocol = "pop3" + // FTP is the FTP protocol with explicit TLS (default port 21). + FTP Protocol = "ftp" + // LDAP is the LDAP protocol with STARTTLS (default port 389). + LDAP Protocol = "ldap" + // MySQL is the MySQL protocol with TLS negotiation (default port 3306). + MySQL Protocol = "mysql" + // Postgres is the PostgreSQL protocol with TLS negotiation (default port 5432). Postgres Protocol = "postgres" - XMPP Protocol = "xmpp" - Sieve Protocol = "sieve" + // XMPP is the XMPP protocol with STARTTLS (default port 5222). + XMPP Protocol = "xmpp" + // Sieve is the ManageSieve protocol with STARTTLS (default port 4190). + Sieve Protocol = "sieve" ) +// DefaultPorts maps each Protocol to its standard port number. var DefaultPorts = map[Protocol]int{ HTTPS: 443, SMTP: 587, @@ -34,12 +46,17 @@ var DefaultPorts = map[Protocol]int{ Sieve: 4190, } +// Subject holds the distinguished name fields of a certificate subject or issuer. type Subject struct { + // CN is the Common Name. CN string `json:"CN"` - O string `json:"O,omitempty"` - C string `json:"C,omitempty"` + // O is the Organization. + O string `json:"O,omitempty"` + // C is the Country code. + C string `json:"C,omitempty"` } +// ChainCert represents a single certificate in the trust chain. type ChainCert struct { Subject Subject `json:"subject"` Issuer Subject `json:"issuer"` @@ -49,6 +66,7 @@ type ChainCert struct { SerialNumber string `json:"serialNumber"` } +// HSTS holds HTTP Strict Transport Security header information for a host. type HSTS struct { Enabled bool `json:"enabled"` MaxAge int `json:"maxAge,omitempty"` @@ -56,28 +74,37 @@ type HSTS struct { Preload bool `json:"preload,omitempty"` } +// Grade represents the overall TLS security grade and associated findings. type Grade struct { - Grade string `json:"grade"` - Protocols []string `json:"protocols"` - WeakCiphers bool `json:"weakCiphers"` - Reasons []string `json:"reasons"` + // Grade is the letter grade (e.g. "A+", "B", "F"). + Grade string `json:"grade"` + // Protocols lists the TLS protocol versions supported by the server. + Protocols []string `json:"protocols"` + // WeakCiphers indicates whether any weak cipher suites were detected. + WeakCiphers bool `json:"weakCiphers"` + // Reasons lists human-readable explanations for any grade deductions. + Reasons []string `json:"reasons"` } +// OCSPStatus holds the OCSP revocation check result for a certificate. type OCSPStatus struct { Status string `json:"status"` Stapled bool `json:"stapled"` Error string `json:"error,omitempty"` } +// CRLStatus holds the CRL revocation check result for a certificate. type CRLStatus struct { Revoked bool `json:"revoked"` Error string `json:"error,omitempty"` } +// SCTInfo holds Signed Certificate Timestamp information from Certificate Transparency logs. type SCTInfo struct { Count int `json:"count"` } +// CertDetails holds extended certificate metadata such as key usage, OCSP servers, and SANs. type CertDetails struct { SignatureAlgorithm string `json:"signatureAlgorithm"` KeyUsage []string `json:"keyUsage,omitempty"` @@ -89,6 +116,7 @@ type CertDetails struct { SANs *SANs `json:"sans,omitempty"` } +// SANs holds the Subject Alternative Names from a certificate. type SANs struct { DNSNames []string `json:"dnsNames,omitempty"` IPAddresses []string `json:"ipAddresses,omitempty"` @@ -96,9 +124,10 @@ type SANs struct { URIs []string `json:"uris,omitempty"` } +// Result holds the full inspection outcome for a single TLS certificate. type Result struct { Valid bool `json:"valid"` - ValidationError string `json:"validationError"` + ValidationError string `json:"validationError,omitempty"` ValidFrom string `json:"validFrom"` ValidTo string `json:"validTo"` DaysRemaining int `json:"daysRemaining"` @@ -121,20 +150,29 @@ type Result struct { Details *CertDetails `json:"details,omitempty"` } +// ResultOrError pairs a Result with an optional error, used in batch operations. type ResultOrError struct { Result *Result Err error } +// CheckOptions configures how a TLS certificate check is performed. type CheckOptions struct { - Port int - Protocol Protocol - Timeout time.Duration + // Port overrides the default port for the protocol. + Port int + // Protocol selects the connection protocol (e.g. HTTPS, SMTP). + Protocol Protocol + // Timeout is the maximum duration for the full connection (TCP dial + TLS handshake). + Timeout time.Duration + // ServerName overrides the SNI value sent during the TLS handshake. ServerName string - WarnDays *int - DoGrade bool + // WarnDays sets the threshold for flagging certificates as expiring soon. + WarnDays *int + // DoGrade enables TLS security grading in the result. + DoGrade bool } +// DefaultCheckOptions returns CheckOptions with sensible defaults (HTTPS, 10s timeout). func DefaultCheckOptions() CheckOptions { return CheckOptions{ Protocol: HTTPS, @@ -142,6 +180,7 @@ func DefaultCheckOptions() CheckOptions { } } +// EffectivePort returns the configured port, falling back to the protocol's default port. func (o CheckOptions) EffectivePort() int { if o.Port > 0 { return o.Port @@ -152,6 +191,7 @@ func (o CheckOptions) EffectivePort() int { return 443 } +// EffectiveServerName returns the configured SNI server name, falling back to the given host. func (o CheckOptions) EffectiveServerName(host string) string { if o.ServerName != "" { return o.ServerName @@ -159,14 +199,21 @@ func (o CheckOptions) EffectiveServerName(host string) string { return host } +// CAInitOptions configures the creation of a new Certificate Authority. type CAInitOptions struct { - CommonName string `json:"commonName"` + // CommonName is the CA certificate's CN field. + CommonName string `json:"commonName"` + // Organization is the CA certificate's O field. Organization string `json:"organization"` - Country string `json:"country,omitempty"` - Validity int `json:"validity"` - KeyType string `json:"keyType"` + // Country is the optional two-letter country code. + Country string `json:"country,omitempty"` + // Validity is the CA lifetime in years. + Validity int `json:"validity"` + // KeyType selects the key algorithm ("ecdsa" or "rsa"). + KeyType string `json:"keyType"` } +// DefaultCAInitOptions returns CAInitOptions with sensible defaults (ECDSA, 10-year validity). func DefaultCAInitOptions() CAInitOptions { return CAInitOptions{ CommonName: "tlsc CA", @@ -176,26 +223,44 @@ func DefaultCAInitOptions() CAInitOptions { } } +// GenerateOptions configures certificate or CSR generation. type GenerateOptions struct { - Days int - KeyType string - OutDir string - Server bool - Client bool - CAPath string + // Days is the certificate validity period in days. + Days int + // KeyType selects the key algorithm ("ecdsa" or "rsa"). + KeyType string + // OutDir is the directory where generated files are written. + OutDir string + // Server enables the TLS server extended key usage. + Server bool + // Client enables the TLS client extended key usage. + Client bool + // CAPath is the path to the CA used for signing; empty means self-signed. + CAPath string + // Organization sets the O field in the certificate subject. Organization string - Country string - OrgUnit string - Locality string - State string - Bundle bool + // Country sets the C field in the certificate subject. + Country string + // OrgUnit sets the OU field in the certificate subject. + OrgUnit string + // Locality sets the L field in the certificate subject. + Locality string + // State sets the ST field in the certificate subject. + State string + // Bundle appends the CA certificate to the output certificate file. + Bundle bool } +// DefaultCertsPath returns the default output directory for generated certificates. func DefaultCertsPath() string { - home, _ := os.UserHomeDir() + home, err := os.UserHomeDir() + if err != nil || home == "" { + return filepath.Join(os.TempDir(), ".tlsc", "certs") + } return filepath.Join(home, ".tlsc", "certs") } +// DefaultGenerateOptions returns GenerateOptions with sensible defaults (ECDSA, 825 days, server usage). func DefaultGenerateOptions() GenerateOptions { return GenerateOptions{ Days: 825, @@ -205,6 +270,7 @@ func DefaultGenerateOptions() GenerateOptions { } } +// GenerateResult holds the output paths and metadata from certificate generation. type GenerateResult struct { CertPath string KeyPath string @@ -212,17 +278,20 @@ type GenerateResult struct { ValidUntil time.Time } +// CSRResult holds the output paths from CSR generation. type CSRResult struct { CSRPath string KeyPath string Hosts []string } +// VerifyOptions configures certificate verification against a CA and optional private key. type VerifyOptions struct { CAPath string KeyPath string } +// VerifyResult holds the outcome of a certificate verification check. type VerifyResult struct { Valid bool `json:"valid"` Chain bool `json:"chain"` @@ -232,6 +301,7 @@ type VerifyResult struct { Issuer string `json:"issuer"` } +// CertSummary provides a brief overview of a certificate for listing purposes. type CertSummary struct { Path string `json:"path"` Subject string `json:"subject"` diff --git a/internal/fileutil/fileutil.go b/internal/fileutil/fileutil.go new file mode 100644 index 0000000..767d6f3 --- /dev/null +++ b/internal/fileutil/fileutil.go @@ -0,0 +1,30 @@ +package fileutil + +import ( + "fmt" + "io" + "os" +) + +// MaxCertFileSize is the upper bound for reading certificate, key, +// and config files. Prevents OOM from maliciously large files. +const MaxCertFileSize int64 = 1 << 20 // 1 MB + +// ReadLimited reads a file up to maxBytes. Returns an error if the +// file exceeds the limit. +func ReadLimited(path string, maxBytes int64) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + data, err := io.ReadAll(io.LimitReader(f, maxBytes+1)) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("file too large: exceeds %d bytes", maxBytes) + } + return data, nil +} diff --git a/internal/fileutil/fileutil_test.go b/internal/fileutil/fileutil_test.go new file mode 100644 index 0000000..cfa2da7 --- /dev/null +++ b/internal/fileutil/fileutil_test.go @@ -0,0 +1,92 @@ +package fileutil + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestReadLimited(t *testing.T) { + const limit int64 = MaxCertFileSize // 1 MB + + t.Run("normal file under limit", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "small.txt") + data := make([]byte, 100) + for i := range data { + data[i] = 'A' + } + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatal(err) + } + + got, err := ReadLimited(path, limit) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 100 { + t.Errorf("got %d bytes, want 100", len(got)) + } + }) + + t.Run("file at exactly the limit", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "exact.bin") + data := make([]byte, limit) + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatal(err) + } + + got, err := ReadLimited(path, limit) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if int64(len(got)) != limit { + t.Errorf("got %d bytes, want %d", len(got), limit) + } + }) + + t.Run("file over the limit", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "oversized.bin") + data := make([]byte, limit+1) + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatal(err) + } + + _, err := ReadLimited(path, limit) + if err == nil { + t.Fatal("expected error for oversized file, got nil") + } + if !strings.Contains(err.Error(), "too large") { + t.Errorf("error should contain 'too large', got: %v", err) + } + }) + + t.Run("non-existent file", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "does_not_exist.txt") + + _, err := ReadLimited(path, limit) + if err == nil { + t.Fatal("expected error for non-existent file, got nil") + } + }) + + t.Run("empty file", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.txt") + if err := os.WriteFile(path, []byte{}, 0o644); err != nil { + t.Fatal(err) + } + + got, err := ReadLimited(path, limit) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 0 { + t.Errorf("got %d bytes, want 0", len(got)) + } + }) +} diff --git a/internal/output/pretty.go b/internal/output/pretty.go index 4fddf4e..250866f 100644 --- a/internal/output/pretty.go +++ b/internal/output/pretty.go @@ -2,10 +2,22 @@ package output import ( "fmt" + "strings" "github.com/dyaa/tlsc/internal/domain" ) +// sanitize strips control characters that could be used for terminal +// escape sequence injection via untrusted certificate fields. +func sanitize(s string) string { + return strings.Map(func(r rune) rune { + if r < 0x20 || r == 0x7F { + return -1 + } + return r + }, s) +} + const ( reset = "\033[0m" bold = "\033[1m" @@ -20,19 +32,19 @@ func Pretty(host string, r *domain.Result) { status = red + "✗ Invalid" + reset } - fmt.Printf("\n SSL Certificate for %s%s%s\n\n", bold, host, reset) + fmt.Printf("\n SSL Certificate for %s%s%s\n\n", bold, sanitize(host), reset) fmt.Printf(" Status: %s\n", status) if r.ValidationError != "" { - fmt.Printf(" Error: %s%s%s\n", red, r.ValidationError, reset) + fmt.Printf(" Error: %s%s%s\n", red, sanitize(r.ValidationError), reset) } fmt.Printf(" Days Remaining: %d\n", r.DaysRemaining) fmt.Printf(" Valid From: %s\n", r.ValidFrom) fmt.Printf(" Valid To: %s\n", r.ValidTo) - fmt.Printf(" Subject: %s\n", r.Subject.CN) + fmt.Printf(" Subject: %s\n", sanitize(r.Subject.CN)) - issuerStr := r.Issuer.CN + issuerStr := sanitize(r.Issuer.CN) if r.Issuer.O != "" { - issuerStr += " (" + r.Issuer.O + ")" + issuerStr += " (" + sanitize(r.Issuer.O) + ")" } fmt.Printf(" Issuer: %s\n", issuerStr) if r.Protocol != "" { @@ -53,7 +65,7 @@ func Pretty(host string, r *domain.Result) { if i > 0 { fmt.Print(", ") } - fmt.Print(v) + fmt.Print(sanitize(v)) } fmt.Println() } @@ -63,11 +75,11 @@ func Pretty(host string, r *domain.Result) { chainStatus = yellow + "✗ Incomplete" + reset } fmt.Printf(" Chain: %s (%d certificates)\n", chainStatus, len(r.Chain)+1) - fmt.Printf(" %s\n", r.Subject.CN) + fmt.Printf(" %s\n", sanitize(r.Subject.CN)) for _, link := range r.Chain { - linkStr := link.Subject.CN + linkStr := sanitize(link.Subject.CN) if link.Subject.O != "" { - linkStr += " (" + link.Subject.O + ")" + linkStr += " (" + sanitize(link.Subject.O) + ")" } fmt.Printf(" └─ %s\n", linkStr) } @@ -99,13 +111,13 @@ func Pretty(host string, r *domain.Result) { } fmt.Printf(" OCSP: %s%s%s%s\n", ocspColor, r.OCSP.Status, reset, stapled) if r.OCSP.Error != "" { - fmt.Printf(" %s%s%s\n", yellow, r.OCSP.Error, reset) + fmt.Printf(" %s%s%s\n", yellow, sanitize(r.OCSP.Error), reset) } } if r.CRL != nil { if r.CRL.Error != "" { - fmt.Printf(" CRL: %s%s%s\n", yellow, r.CRL.Error, reset) + fmt.Printf(" CRL: %s%s%s\n", yellow, sanitize(r.CRL.Error), reset) } else if r.CRL.Revoked { fmt.Printf(" CRL: %srevoked%s\n", red, reset) } else { @@ -118,14 +130,14 @@ func Pretty(host string, r *domain.Result) { } if r.Details != nil { - fmt.Printf(" Signature: %s\n", r.Details.SignatureAlgorithm) + fmt.Printf(" Signature: %s\n", sanitize(r.Details.SignatureAlgorithm)) if len(r.Details.ExtKeyUsage) > 0 { fmt.Printf(" Key Usage: ") for i, u := range r.Details.ExtKeyUsage { if i > 0 { fmt.Print(", ") } - fmt.Print(u) + fmt.Print(sanitize(u)) } fmt.Println() } diff --git a/internal/output/sanitize_test.go b/internal/output/sanitize_test.go new file mode 100644 index 0000000..1caa879 --- /dev/null +++ b/internal/output/sanitize_test.go @@ -0,0 +1,76 @@ +package output + +import "testing" + +func TestSanitize(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "normal string unchanged", + input: "example.com", + want: "example.com", + }, + { + name: "ANSI escape sequences stripped", + input: "\x1b[31mREDTEXT\x1b[0m", + want: "[31mREDTEXT[0m", + }, + { + name: "null byte stripped", + input: "test\x00name", + want: "testname", + }, + { + name: "bell character stripped", + input: "test\x07name", + want: "testname", + }, + { + name: "newline stripped", + input: "line1\nline2", + want: "line1line2", + }, + { + name: "carriage return stripped", + input: "line1\rline2", + want: "line1line2", + }, + { + name: "tab stripped", + input: "col1\tcol2", + want: "col1col2", + }, + { + name: "DEL character stripped", + input: "test\x7fname", + want: "testname", + }, + { + name: "unicode preserved", + input: "日本語テスト", + want: "日本語テスト", + }, + { + name: "mixed control chars stripped", + input: "CN=\x1b[2Jtest\x00.com", + want: "CN=[2Jtest.com", + }, + { + name: "empty string unchanged", + input: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitize(tt.input) + if got != tt.want { + t.Errorf("sanitize(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/ports/checker.go b/internal/ports/checker.go index d3098ad..4b51baa 100644 --- a/internal/ports/checker.go +++ b/internal/ports/checker.go @@ -1,12 +1,16 @@ package ports -import "github.com/dyaa/tlsc/internal/domain" +import ( + "context" + + "github.com/dyaa/tlsc/internal/domain" +) type Checker interface { - Check(host string, opts domain.CheckOptions) (*domain.Result, error) - CheckBatch(hosts []string, opts domain.CheckOptions) map[string]*domain.ResultOrError + Check(ctx context.Context, host string, opts domain.CheckOptions) (*domain.Result, error) + CheckBatch(ctx context.Context, hosts []string, opts domain.CheckOptions) map[string]*domain.ResultOrError } type Inspector interface { - InspectFile(path string) (*domain.Result, error) + InspectFile(ctx context.Context, path string) (*domain.Result, error) } diff --git a/internal/services/ca/ca.go b/internal/services/ca/ca.go index e605bcd..5e23d43 100644 --- a/internal/services/ca/ca.go +++ b/internal/services/ca/ca.go @@ -1,6 +1,7 @@ package ca import ( + "context" "crypto" "crypto/rand" "crypto/sha256" @@ -17,29 +18,41 @@ import ( tlscrypto "github.com/dyaa/tlsc/internal/crypto" "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" ) +// CA represents a loaded Certificate Authority with its certificate, private key, and on-disk path. type CA struct { - Cert *x509.Certificate - Key crypto.Signer + // Cert is the parsed X.509 CA certificate. + Cert *x509.Certificate + // Key is the CA's private signing key. + Key crypto.Signer + // CertPEM is the PEM-encoded CA certificate bytes. CertPEM []byte - Path string - Config *domain.CAInitOptions + // Path is the directory on disk where the CA files are stored. + Path string + // Config holds the options used to create this CA, if available. + Config *domain.CAInitOptions } +// DefaultPath returns the default filesystem path for the Certificate Authority directory. func DefaultPath() string { - home, _ := os.UserHomeDir() + home, err := os.UserHomeDir() + if err != nil || home == "" { + return filepath.Join(os.TempDir(), ".tlsc", "ca") + } return filepath.Join(home, ".tlsc", "ca") } -func Init(caPath string, opts domain.CAInitOptions) (*CA, error) { +// Init creates a new root Certificate Authority at caPath, generating a key pair and self-signed certificate. +func Init(ctx context.Context, caPath string, opts domain.CAInitOptions) (*CA, error) { if caPath == "" { caPath = DefaultPath() } certPath := filepath.Join(caPath, "rootCA.pem") if _, err := os.Stat(certPath); err == nil { - return nil, fmt.Errorf("CA already exists at %s (use 'ca info' to view)", caPath) + return nil, &domain.ErrCAExists{Path: caPath} } if err := os.MkdirAll(caPath, 0700); err != nil { @@ -106,21 +119,22 @@ func Init(caPath string, opts domain.CAInitOptions) (*CA, error) { if err != nil { return nil, fmt.Errorf("failed to marshal CA config: %w", err) } - if err := os.WriteFile(filepath.Join(caPath, "ca-config.json"), configData, 0644); err != nil { + if err := os.WriteFile(filepath.Join(caPath, "ca-config.json"), configData, 0600); err != nil { return nil, fmt.Errorf("failed to write CA config: %w", err) } return &CA{Cert: cert, Key: key, CertPEM: certPEM, Path: caPath, Config: &opts}, nil } -func Load(caPath string) (*CA, error) { +// Load reads an existing Certificate Authority from caPath, parsing the certificate and private key. +func Load(ctx context.Context, caPath string) (*CA, error) { if caPath == "" { caPath = DefaultPath() } - certPEM, err := os.ReadFile(filepath.Join(caPath, "rootCA.pem")) + certPEM, err := fileutil.ReadLimited(filepath.Join(caPath, "rootCA.pem"), fileutil.MaxCertFileSize) if err != nil { - return nil, fmt.Errorf("CA not found at %s (run 'tlsc ca init' first)", caPath) + return nil, &domain.ErrCANotFound{Path: caPath} } block, _ := pem.Decode(certPEM) @@ -133,7 +147,7 @@ func Load(caPath string) (*CA, error) { return nil, fmt.Errorf("failed to parse CA certificate: %w", err) } - keyPEM, err := os.ReadFile(filepath.Join(caPath, "rootCA-key.pem")) + keyPEM, err := fileutil.ReadLimited(filepath.Join(caPath, "rootCA-key.pem"), fileutil.MaxCertFileSize) if err != nil { return nil, fmt.Errorf("failed to read CA key: %w", err) } @@ -154,7 +168,11 @@ func Load(caPath string) (*CA, error) { if parseErr != nil { return nil, fmt.Errorf("unsupported key type: %s", keyBlock.Type) } - key = pk.(crypto.Signer) + signer, ok := pk.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("PKCS#8 key does not implement crypto.Signer") + } + key = signer } if err != nil { return nil, fmt.Errorf("failed to parse CA key: %w", err) @@ -162,7 +180,7 @@ func Load(caPath string) (*CA, error) { result := &CA{Cert: cert, Key: key, CertPEM: certPEM, Path: caPath} - configData, err := os.ReadFile(filepath.Join(caPath, "ca-config.json")) + configData, err := fileutil.ReadLimited(filepath.Join(caPath, "ca-config.json"), fileutil.MaxCertFileSize) if err == nil { var config domain.CAInitOptions if json.Unmarshal(configData, &config) == nil { @@ -173,6 +191,7 @@ func Load(caPath string) (*CA, error) { return result, nil } +// Fingerprint returns the SHA-256 fingerprint of the CA certificate as a colon-separated hex string. func (c *CA) Fingerprint() string { sum := sha256.Sum256(c.Cert.Raw) parts := make([]string, len(sum)) @@ -182,6 +201,7 @@ func (c *CA) Fingerprint() string { return strings.Join(parts, ":") } +// CertPath returns the absolute path to the CA's root certificate PEM file. func (c *CA) CertPath() string { return filepath.Join(c.Path, "rootCA.pem") } diff --git a/internal/services/ca/ca_test.go b/internal/services/ca/ca_test.go index 2e09d2b..c300550 100644 --- a/internal/services/ca/ca_test.go +++ b/internal/services/ca/ca_test.go @@ -1,6 +1,7 @@ package ca import ( + "context" "encoding/json" "os" "path/filepath" @@ -12,7 +13,8 @@ import ( func TestInitAndLoad(t *testing.T) { dir := t.TempDir() - authority, err := Init(dir, domain.DefaultCAInitOptions()) + ctx := context.Background() + authority, err := Init(ctx, dir, domain.DefaultCAInitOptions()) if err != nil { t.Fatal(err) } @@ -53,7 +55,7 @@ func TestInitAndLoad(t *testing.T) { } // Load should return same CA - loaded, err := Load(dir) + loaded, err := Load(ctx, dir) if err != nil { t.Fatal(err) } @@ -82,7 +84,8 @@ func TestInitCustomOptions(t *testing.T) { KeyType: "ecdsa", } - authority, err := Init(dir, opts) + ctx := context.Background() + authority, err := Init(ctx, dir, opts) if err != nil { t.Fatal(err) } @@ -122,7 +125,8 @@ func TestInitKeyTypes(t *testing.T) { opts := domain.DefaultCAInitOptions() opts.KeyType = kt - authority, err := Init(dir, opts) + ctx := context.Background() + authority, err := Init(ctx, dir, opts) if err != nil { t.Fatalf("key type %s failed: %v", kt, err) } @@ -131,7 +135,7 @@ func TestInitKeyTypes(t *testing.T) { } // Verify it can be loaded back - loaded, err := Load(dir) + loaded, err := Load(ctx, dir) if err != nil { t.Fatalf("failed to load %s CA: %v", kt, err) } @@ -145,18 +149,19 @@ func TestInitKeyTypes(t *testing.T) { func TestInitAlreadyExists(t *testing.T) { dir := t.TempDir() - if _, err := Init(dir, domain.DefaultCAInitOptions()); err != nil { + ctx := context.Background() + if _, err := Init(ctx, dir, domain.DefaultCAInitOptions()); err != nil { t.Fatal(err) } - _, err := Init(dir, domain.DefaultCAInitOptions()) + _, err := Init(ctx, dir, domain.DefaultCAInitOptions()) if err == nil { t.Error("expected error for existing CA") } } func TestLoadNotFound(t *testing.T) { - _, err := Load(t.TempDir()) + _, err := Load(context.Background(), t.TempDir()) if err == nil { t.Error("expected error for missing CA") } @@ -164,7 +169,7 @@ func TestLoadNotFound(t *testing.T) { func TestFingerprint(t *testing.T) { dir := t.TempDir() - authority, _ := Init(dir, domain.DefaultCAInitOptions()) + authority, _ := Init(context.Background(), dir, domain.DefaultCAInitOptions()) fp := authority.Fingerprint() if len(fp) != 95 { @@ -174,7 +179,7 @@ func TestFingerprint(t *testing.T) { func TestCertPath(t *testing.T) { dir := t.TempDir() - authority, _ := Init(dir, domain.DefaultCAInitOptions()) + authority, _ := Init(context.Background(), dir, domain.DefaultCAInitOptions()) expected := filepath.Join(dir, "rootCA.pem") if got := authority.CertPath(); got != expected { diff --git a/internal/services/ca/renew.go b/internal/services/ca/renew.go index d3840f3..b179261 100644 --- a/internal/services/ca/renew.go +++ b/internal/services/ca/renew.go @@ -1,6 +1,7 @@ package ca import ( + "context" "crypto/rand" "crypto/x509" "encoding/pem" @@ -9,14 +10,17 @@ import ( "os" "path/filepath" "time" + + "github.com/dyaa/tlsc/internal/fileutil" ) -func Renew(caPath string) (*CA, error) { +// Renew re-issues the root CA certificate at caPath with a new validity period, keeping the existing private key. +func Renew(ctx context.Context, caPath string) (*CA, error) { if caPath == "" { caPath = DefaultPath() } - existing, err := Load(caPath) + existing, err := Load(ctx, caPath) if err != nil { return nil, fmt.Errorf("failed to load existing CA: %w", err) } @@ -57,7 +61,7 @@ func Renew(caPath string) (*CA, error) { // Backup old rootCA.pem certPath := filepath.Join(caPath, "rootCA.pem") backupPath := certPath + ".bak" - oldPEM, err := os.ReadFile(certPath) + oldPEM, err := fileutil.ReadLimited(certPath, fileutil.MaxCertFileSize) if err != nil { return nil, fmt.Errorf("failed to read old CA certificate for backup: %w", err) } diff --git a/internal/services/ca/renew_test.go b/internal/services/ca/renew_test.go index 653d864..e959902 100644 --- a/internal/services/ca/renew_test.go +++ b/internal/services/ca/renew_test.go @@ -1,6 +1,7 @@ package ca import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -15,8 +16,9 @@ import ( func TestRenew(t *testing.T) { dir := t.TempDir() + ctx := context.Background() - original, err := Init(dir, domain.DefaultCAInitOptions()) + original, err := Init(ctx, dir, domain.DefaultCAInitOptions()) if err != nil { t.Fatal(err) } @@ -24,7 +26,7 @@ func TestRenew(t *testing.T) { // Small sleep to ensure time difference time.Sleep(10 * time.Millisecond) - renewed, err := Renew(dir) + renewed, err := Renew(ctx, dir) if err != nil { t.Fatal(err) } @@ -62,13 +64,14 @@ func TestRenew(t *testing.T) { func TestRenewPreservesKey(t *testing.T) { dir := t.TempDir() + ctx := context.Background() - original, err := Init(dir, domain.DefaultCAInitOptions()) + original, err := Init(ctx, dir, domain.DefaultCAInitOptions()) if err != nil { t.Fatal(err) } - renewed, err := Renew(dir) + renewed, err := Renew(ctx, dir) if err != nil { t.Fatal(err) } @@ -111,8 +114,9 @@ func publicKeysEqual(a, b crypto.PublicKey) bool { func TestRenewBackup(t *testing.T) { dir := t.TempDir() + ctx := context.Background() - original, err := Init(dir, domain.DefaultCAInitOptions()) + original, err := Init(ctx, dir, domain.DefaultCAInitOptions()) if err != nil { t.Fatal(err) } @@ -120,7 +124,7 @@ func TestRenewBackup(t *testing.T) { originalPEM := make([]byte, len(original.CertPEM)) copy(originalPEM, original.CertPEM) - if _, err := Renew(dir); err != nil { + if _, err := Renew(ctx, dir); err != nil { t.Fatal(err) } @@ -149,7 +153,7 @@ func TestRenewBackup(t *testing.T) { func TestRenewNotFound(t *testing.T) { dir := t.TempDir() - _, err := Renew(dir) + _, err := Renew(context.Background(), dir) if err == nil { t.Error("expected error when renewing non-existent CA") } diff --git a/internal/services/checker/checker.go b/internal/services/checker/checker.go index aff23cb..d9c12e3 100644 --- a/internal/services/checker/checker.go +++ b/internal/services/checker/checker.go @@ -1,6 +1,7 @@ package checker import ( + "context" stdtls "crypto/tls" "sync" @@ -15,7 +16,7 @@ func New() *Service { return &Service{} } -func (s *Service) Check(host string, opts domain.CheckOptions) (*domain.Result, error) { +func (s *Service) Check(ctx context.Context, host string, opts domain.CheckOptions) (*domain.Result, error) { port := opts.EffectivePort() serverName := opts.EffectiveServerName(host) @@ -23,9 +24,9 @@ func (s *Service) Check(host string, opts domain.CheckOptions) (*domain.Result, var err error if opts.Protocol != domain.HTTPS { - tlsConn, err = tlsadapter.DialStartTLS(host, port, opts.Protocol, serverName, opts.Timeout) + tlsConn, err = tlsadapter.DialStartTLS(ctx, host, port, opts.Protocol, serverName, opts.Timeout) } else { - tlsConn, err = tlsadapter.Dial(host, port, serverName, opts.Timeout) + tlsConn, err = tlsadapter.Dial(ctx, host, port, serverName, opts.Timeout) } if err != nil { return nil, err @@ -40,16 +41,16 @@ func (s *Service) Check(host string, opts domain.CheckOptions) (*domain.Result, } if opts.Protocol == domain.HTTPS { - result.HSTS = httpadapter.FetchHSTS(host, port, opts.Timeout, serverName) + result.HSTS = httpadapter.FetchHSTS(ctx, host, port, opts.Timeout, serverName) } leaf := state.PeerCertificates[0] if len(state.PeerCertificates) > 1 { - result.OCSP = tlsadapter.CheckOCSP(leaf, state.PeerCertificates[1], state.OCSPResponse, opts.Timeout) - result.CRL = tlsadapter.CheckCRL(leaf, state.PeerCertificates[1], opts.Timeout) + result.OCSP = tlsadapter.CheckOCSP(ctx, leaf, state.PeerCertificates[1], state.OCSPResponse, opts.Timeout) + result.CRL = tlsadapter.CheckCRL(ctx, leaf, state.PeerCertificates[1], opts.Timeout) } else { - result.OCSP = tlsadapter.CheckOCSP(leaf, nil, state.OCSPResponse, opts.Timeout) + result.OCSP = tlsadapter.CheckOCSP(ctx, leaf, nil, state.OCSPResponse, opts.Timeout) } result.SCT = tlsadapter.ExtractSCT(state) @@ -61,28 +62,44 @@ func (s *Service) Check(host string, opts domain.CheckOptions) (*domain.Result, } if opts.DoGrade { - grade := ComputeGrade(host, port, opts.Timeout, result, serverName) + grade := ComputeGrade(ctx, host, port, opts.Timeout, result, serverName) result.Grade = &grade } return result, nil } -func (s *Service) CheckBatch(hosts []string, opts domain.CheckOptions) map[string]*domain.ResultOrError { +func (s *Service) CheckBatch(ctx context.Context, hosts []string, opts domain.CheckOptions) map[string]*domain.ResultOrError { const maxConcurrent = 20 + seen := make(map[string]struct{}, len(hosts)) + unique := make([]string, 0, len(hosts)) + for _, h := range hosts { + if _, ok := seen[h]; !ok { + seen[h] = struct{}{} + unique = append(unique, h) + } + } + sem := make(chan struct{}, maxConcurrent) var mu sync.Mutex var wg sync.WaitGroup - results := make(map[string]*domain.ResultOrError, len(hosts)) + results := make(map[string]*domain.ResultOrError, len(unique)) - for _, h := range hosts { + for _, h := range unique { wg.Add(1) go func(host string) { defer wg.Done() - sem <- struct{}{} - r, err := s.Check(host, opts) + select { + case sem <- struct{}{}: + case <-ctx.Done(): + mu.Lock() + results[host] = &domain.ResultOrError{Err: ctx.Err()} + mu.Unlock() + return + } + r, err := s.Check(ctx, host, opts) <-sem mu.Lock() results[host] = &domain.ResultOrError{Result: r, Err: err} diff --git a/internal/services/checker/checker_test.go b/internal/services/checker/checker_test.go index c47d46a..60ef829 100644 --- a/internal/services/checker/checker_test.go +++ b/internal/services/checker/checker_test.go @@ -1,6 +1,7 @@ package checker import ( + "context" "testing" "time" @@ -12,8 +13,9 @@ func TestCheckValidHost(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() - result, err := svc.Check("github.com", domain.DefaultCheckOptions()) + result, err := svc.Check(ctx, "github.com", domain.DefaultCheckOptions()) if err != nil { t.Fatal(err) } @@ -61,8 +63,9 @@ func TestCheckExpiredHost(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() - result, err := svc.Check("expired.badssl.com", domain.DefaultCheckOptions()) + result, err := svc.Check(ctx, "expired.badssl.com", domain.DefaultCheckOptions()) if err != nil { t.Fatal(err) } @@ -83,8 +86,9 @@ func TestCheckWithHSTS(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() - result, err := svc.Check("github.com", domain.DefaultCheckOptions()) + result, err := svc.Check(ctx, "github.com", domain.DefaultCheckOptions()) if err != nil { t.Fatal(err) } @@ -105,12 +109,13 @@ func TestCheckWarnDays(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() warnDays := 365 opts := domain.DefaultCheckOptions() opts.WarnDays = &warnDays - result, err := svc.Check("github.com", opts) + result, err := svc.Check(ctx, "github.com", opts) if err != nil { t.Fatal(err) } @@ -125,11 +130,12 @@ func TestCheckWithGrade(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() opts := domain.DefaultCheckOptions() opts.DoGrade = true - result, err := svc.Check("github.com", opts) + result, err := svc.Check(ctx, "github.com", opts) if err != nil { t.Fatal(err) } @@ -150,12 +156,13 @@ func TestCheckSMTP(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() opts := domain.DefaultCheckOptions() opts.Protocol = domain.SMTP opts.Timeout = 15 * time.Second - result, err := svc.Check("smtp.gmail.com", opts) + result, err := svc.Check(ctx, "smtp.gmail.com", opts) if err != nil { t.Fatal(err) } @@ -173,8 +180,9 @@ func TestCheckBatch(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() - results := svc.CheckBatch([]string{"github.com", "expired.badssl.com"}, domain.DefaultCheckOptions()) + results := svc.CheckBatch(ctx, []string{"github.com", "expired.badssl.com"}, domain.DefaultCheckOptions()) if len(results) != 2 { t.Errorf("expected 2 results, got %d", len(results)) @@ -202,11 +210,12 @@ func TestCheckUnreachable(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() opts := domain.DefaultCheckOptions() opts.Timeout = 3 * time.Second - _, err := svc.Check("this.does.not.exist.example", opts) + _, err := svc.Check(ctx, "this.does.not.exist.example", opts) if err == nil { t.Error("expected error for unreachable host") } diff --git a/internal/services/checker/grader.go b/internal/services/checker/grader.go index 6bb0af1..37e5a94 100644 --- a/internal/services/checker/grader.go +++ b/internal/services/checker/grader.go @@ -1,6 +1,7 @@ package checker import ( + "context" "crypto/tls" "fmt" "net" @@ -35,7 +36,7 @@ func downgrade(current, to string) string { return current } -func ComputeGrade(host string, port int, timeout time.Duration, result *domain.Result, serverName string) domain.Grade { +func ComputeGrade(ctx context.Context, host string, port int, timeout time.Duration, result *domain.Result, serverName string) domain.Grade { probeTimeout := timeout if probeTimeout > 5*time.Second { probeTimeout = 5 * time.Second @@ -47,21 +48,21 @@ func ComputeGrade(host string, port int, timeout time.Duration, result *domain.R go func() { defer wg.Done() - hasTLS10 = probeVersion(host, port, tls.VersionTLS10, probeTimeout, serverName) + hasTLS10 = probeVersion(ctx, host, port, tls.VersionTLS10, probeTimeout, serverName) }() go func() { defer wg.Done() - hasTLS11 = probeVersion(host, port, tls.VersionTLS11, probeTimeout, serverName) + hasTLS11 = probeVersion(ctx, host, port, tls.VersionTLS11, probeTimeout, serverName) }() go func() { defer wg.Done() - hasTLS12 = probeVersion(host, port, tls.VersionTLS12, probeTimeout, serverName) + hasTLS12 = probeVersion(ctx, host, port, tls.VersionTLS12, probeTimeout, serverName) }() go func() { defer wg.Done() - hasTLS13 = probeVersion(host, port, tls.VersionTLS13, probeTimeout, serverName) + hasTLS13 = probeVersion(ctx, host, port, tls.VersionTLS13, probeTimeout, serverName) }() - go func() { defer wg.Done(); hasWeakCiphers = probeWeak(host, port, probeTimeout, serverName) }() + go func() { defer wg.Done(); hasWeakCiphers = probeWeak(ctx, host, port, probeTimeout, serverName) }() wg.Wait() @@ -139,32 +140,50 @@ func ComputeGrade(host string, port int, timeout time.Duration, result *domain.R return domain.Grade{Grade: grade, Protocols: protocols, WeakCiphers: hasWeakCiphers, Reasons: reasons} } -func probeVersion(host string, port int, version uint16, timeout time.Duration, serverName string) bool { - dialer := &net.Dialer{Timeout: timeout} - conn, err := tls.DialWithDialer(dialer, "tcp", fmt.Sprintf("%s:%d", host, port), &tls.Config{ +func probeVersion(ctx context.Context, host string, port int, version uint16, timeout time.Duration, serverName string) bool { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + dialer := &net.Dialer{} + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return false + } + tlsConn := tls.Client(conn, &tls.Config{ ServerName: serverName, MinVersion: version, MaxVersion: version, InsecureSkipVerify: true, }) - if err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() return false } - conn.Close() + tlsConn.Close() return true } -func probeWeak(host string, port int, timeout time.Duration, serverName string) bool { - dialer := &net.Dialer{Timeout: timeout} - conn, err := tls.DialWithDialer(dialer, "tcp", fmt.Sprintf("%s:%d", host, port), &tls.Config{ +func probeWeak(ctx context.Context, host string, port int, timeout time.Duration, serverName string) bool { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + dialer := &net.Dialer{} + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return false + } + tlsConn := tls.Client(conn, &tls.Config{ ServerName: serverName, CipherSuites: weakCipherSuites, MaxVersion: tls.VersionTLS12, InsecureSkipVerify: true, }) - if err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { + conn.Close() return false } - conn.Close() + tlsConn.Close() return true } diff --git a/internal/services/checker/grader_test.go b/internal/services/checker/grader_test.go index 4f4fb7d..1ad21a4 100644 --- a/internal/services/checker/grader_test.go +++ b/internal/services/checker/grader_test.go @@ -1,6 +1,7 @@ package checker import ( + "context" "testing" "github.com/dyaa/tlsc/internal/domain" @@ -44,7 +45,8 @@ func TestComputeGradeInvalidCert(t *testing.T) { ValidationError: "CERT_HAS_EXPIRED", } - grade := ComputeGrade("example.com", 443, 5e9, result, "example.com") + ctx := context.Background() + grade := ComputeGrade(ctx, "example.com", 443, 5e9, result, "example.com") if grade.Grade != "F" { t.Errorf("expected F for invalid cert, got %s", grade.Grade) } @@ -59,7 +61,8 @@ func TestComputeGradeExpired(t *testing.T) { DaysRemaining: -10, } - grade := ComputeGrade("example.com", 443, 5e9, result, "example.com") + ctx := context.Background() + grade := ComputeGrade(ctx, "example.com", 443, 5e9, result, "example.com") if grade.Grade != "F" { t.Errorf("expected F for expired cert, got %s", grade.Grade) } @@ -70,14 +73,15 @@ func TestComputeGradeLive(t *testing.T) { t.Skip("skipping integration test") } + ctx := context.Background() svc := New() opts := domain.DefaultCheckOptions() - result, err := svc.Check("github.com", opts) + result, err := svc.Check(ctx, "github.com", opts) if err != nil { t.Fatal(err) } - grade := ComputeGrade("github.com", 443, opts.Timeout, result, "github.com") + grade := ComputeGrade(ctx, "github.com", 443, opts.Timeout, result, "github.com") if grade.Grade != "A+" && grade.Grade != "A" { t.Errorf("expected A+ or A for github.com, got %s: %v", grade.Grade, grade.Reasons) } diff --git a/internal/services/checker/inspect.go b/internal/services/checker/inspect.go index b451e8b..d80ef72 100644 --- a/internal/services/checker/inspect.go +++ b/internal/services/checker/inspect.go @@ -1,26 +1,27 @@ package checker import ( + "context" "crypto/x509" "encoding/pem" "fmt" "math" - "os" "time" tlsadapter "github.com/dyaa/tlsc/internal/adapters/tls" "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" ) -func (s *Service) InspectFile(path string) (*domain.Result, error) { - data, err := os.ReadFile(path) +func (s *Service) InspectFile(ctx context.Context, path string) (*domain.Result, error) { + data, err := fileutil.ReadLimited(path, fileutil.MaxCertFileSize) if err != nil { - return nil, fmt.Errorf("failed to read file: %w", err) + return nil, &domain.ErrFileRead{Path: path, Err: err} } block, _ := pem.Decode(data) if block == nil { - return nil, fmt.Errorf("not a valid PEM file") + return nil, &domain.ErrInvalidPEM{Path: path} } switch block.Type { diff --git a/internal/services/checker/inspect_test.go b/internal/services/checker/inspect_test.go index a85932b..fb56504 100644 --- a/internal/services/checker/inspect_test.go +++ b/internal/services/checker/inspect_test.go @@ -1,6 +1,7 @@ package checker import ( + "context" "os" "path/filepath" "testing" @@ -13,13 +14,13 @@ import ( func TestInspectFile(t *testing.T) { // Setup: create CA + generate a cert caDir := filepath.Join(t.TempDir(), "ca") - if _, err := ca.Init(caDir, domain.DefaultCAInitOptions()); err != nil { + if _, err := ca.Init(context.Background(),caDir, domain.DefaultCAInitOptions()); err != nil { t.Fatal(err) } outDir := t.TempDir() genSvc := generate.New() - genResult, err := genSvc.Generate([]string{"inspect.test", "192.168.1.100"}, domain.GenerateOptions{ + genResult, err := genSvc.Generate(context.Background(),[]string{"inspect.test", "192.168.1.100"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -31,7 +32,7 @@ func TestInspectFile(t *testing.T) { // Test inspect svc := New() - result, err := svc.InspectFile(genResult.CertPath) + result, err := svc.InspectFile(context.Background(),genResult.CertPath) if err != nil { t.Fatal(err) } @@ -67,7 +68,7 @@ func TestInspectFile(t *testing.T) { func TestInspectFileNotFound(t *testing.T) { svc := New() - _, err := svc.InspectFile("/nonexistent/path.pem") + _, err := svc.InspectFile(context.Background(),"/nonexistent/path.pem") if err == nil { t.Error("expected error for missing file") } @@ -80,7 +81,7 @@ func TestInspectFileInvalidPEM(t *testing.T) { } svc := New() - _, err := svc.InspectFile(f) + _, err := svc.InspectFile(context.Background(),f) if err == nil { t.Error("expected error for invalid PEM") } @@ -90,7 +91,7 @@ func TestInspectCSR(t *testing.T) { outDir := t.TempDir() genSvc := generate.New() - csrResult, err := genSvc.GenerateCSR([]string{"csr.test", "10.0.0.5"}, domain.GenerateOptions{ + csrResult, err := genSvc.GenerateCSR(context.Background(),[]string{"csr.test", "10.0.0.5"}, domain.GenerateOptions{ KeyType: "ecdsa", OutDir: outDir, }) @@ -99,7 +100,7 @@ func TestInspectCSR(t *testing.T) { } svc := New() - result, err := svc.InspectFile(csrResult.CSRPath) + result, err := svc.InspectFile(context.Background(),csrResult.CSRPath) if err != nil { t.Fatal(err) } @@ -139,7 +140,7 @@ func TestInspectUnsupportedPEMType(t *testing.T) { } svc := New() - _, err := svc.InspectFile(f) + _, err := svc.InspectFile(context.Background(),f) if err == nil { t.Error("expected error for unsupported PEM type") } diff --git a/internal/services/checker/list.go b/internal/services/checker/list.go index 61505eb..fc87696 100644 --- a/internal/services/checker/list.go +++ b/internal/services/checker/list.go @@ -1,6 +1,7 @@ package checker import ( + "context" "crypto/x509" "encoding/pem" "math" @@ -11,9 +12,10 @@ import ( "time" "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" ) -func (s *Service) ListCerts(certsDir string) ([]domain.CertSummary, error) { +func (s *Service) ListCerts(ctx context.Context, certsDir string) ([]domain.CertSummary, error) { if certsDir == "" { certsDir = domain.DefaultCertsPath() } @@ -39,7 +41,7 @@ func (s *Service) ListCerts(certsDir string) ([]domain.CertSummary, error) { fullPath := filepath.Join(certsDir, name) - data, err := os.ReadFile(fullPath) + data, err := fileutil.ReadLimited(fullPath, fileutil.MaxCertFileSize) if err != nil { continue } diff --git a/internal/services/checker/list_test.go b/internal/services/checker/list_test.go index 824223a..15aea32 100644 --- a/internal/services/checker/list_test.go +++ b/internal/services/checker/list_test.go @@ -1,6 +1,7 @@ package checker import ( + "context" "path/filepath" "testing" @@ -10,8 +11,9 @@ import ( ) func TestListCerts(t *testing.T) { + ctx := context.Background() caDir := filepath.Join(t.TempDir(), "ca") - if _, err := ca.Init(caDir, domain.DefaultCAInitOptions()); err != nil { + if _, err := ca.Init(ctx, caDir, domain.DefaultCAInitOptions()); err != nil { t.Fatal(err) } @@ -19,7 +21,7 @@ func TestListCerts(t *testing.T) { genSvc := generate.New() // Generate first cert with 365 days - _, err := genSvc.Generate([]string{"alpha.test", "10.0.0.1"}, domain.GenerateOptions{ + _, err := genSvc.Generate(ctx, []string{"alpha.test", "10.0.0.1"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -30,7 +32,7 @@ func TestListCerts(t *testing.T) { } // Generate second cert with 30 days (expires sooner) - _, err = genSvc.Generate([]string{"beta.test"}, domain.GenerateOptions{ + _, err = genSvc.Generate(ctx, []string{"beta.test"}, domain.GenerateOptions{ Days: 30, KeyType: "ecdsa", OutDir: outDir, @@ -41,7 +43,7 @@ func TestListCerts(t *testing.T) { } svc := New() - certs, err := svc.ListCerts(outDir) + certs, err := svc.ListCerts(ctx, outDir) if err != nil { t.Fatal(err) } @@ -82,7 +84,7 @@ func TestListCerts(t *testing.T) { func TestListCertsEmptyDir(t *testing.T) { svc := New() - certs, err := svc.ListCerts(t.TempDir()) + certs, err := svc.ListCerts(context.Background(), t.TempDir()) if err != nil { t.Fatal(err) } @@ -93,7 +95,7 @@ func TestListCertsEmptyDir(t *testing.T) { func TestListCertsNonexistentDir(t *testing.T) { svc := New() - _, err := svc.ListCerts(filepath.Join(t.TempDir(), "nonexistent")) + _, err := svc.ListCerts(context.Background(), filepath.Join(t.TempDir(), "nonexistent")) if err == nil { t.Error("expected error for nonexistent directory") } diff --git a/internal/services/checker/verify.go b/internal/services/checker/verify.go index 76ae1ee..c8fdbea 100644 --- a/internal/services/checker/verify.go +++ b/internal/services/checker/verify.go @@ -1,26 +1,27 @@ package checker import ( + "context" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/x509" "encoding/pem" "fmt" - "os" "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" ) -func (s *Service) VerifyFile(certPath string, opts domain.VerifyOptions) (*domain.VerifyResult, error) { - certPEM, err := os.ReadFile(certPath) +func (s *Service) VerifyFile(ctx context.Context, certPath string, opts domain.VerifyOptions) (*domain.VerifyResult, error) { + certPEM, err := fileutil.ReadLimited(certPath, fileutil.MaxCertFileSize) if err != nil { - return nil, fmt.Errorf("failed to read certificate: %w", err) + return nil, &domain.ErrFileRead{Path: certPath, Err: err} } block, _ := pem.Decode(certPEM) if block == nil { - return nil, fmt.Errorf("not a valid PEM file") + return nil, &domain.ErrInvalidPEM{Path: certPath} } cert, err := x509.ParseCertificate(block.Bytes) @@ -56,7 +57,7 @@ func (s *Service) VerifyFile(certPath string, opts domain.VerifyOptions) (*domai } func verifyChain(cert *x509.Certificate, caPath string) (bool, []string) { - caPEM, err := os.ReadFile(caPath) + caPEM, err := fileutil.ReadLimited(caPath, fileutil.MaxCertFileSize) if err != nil { return false, []string{fmt.Sprintf("failed to read CA certificate: %v", err)} } @@ -85,7 +86,7 @@ func verifyChain(cert *x509.Certificate, caPath string) (bool, []string) { } func verifyKeyMatch(cert *x509.Certificate, keyPath string) (bool, []string) { - keyPEM, err := os.ReadFile(keyPath) + keyPEM, err := fileutil.ReadLimited(keyPath, fileutil.MaxCertFileSize) if err != nil { return false, []string{fmt.Sprintf("failed to read key: %v", err)} } diff --git a/internal/services/checker/verify_test.go b/internal/services/checker/verify_test.go index ee32472..4437b8d 100644 --- a/internal/services/checker/verify_test.go +++ b/internal/services/checker/verify_test.go @@ -1,6 +1,7 @@ package checker import ( + "context" "path/filepath" "testing" @@ -13,7 +14,7 @@ func setupVerifyCA(t *testing.T) string { t.Helper() dir := t.TempDir() caDir := filepath.Join(dir, "ca") - if _, err := ca.Init(caDir, domain.DefaultCAInitOptions()); err != nil { + if _, err := ca.Init(context.Background(),caDir, domain.DefaultCAInitOptions()); err != nil { t.Fatal(err) } return caDir @@ -23,7 +24,7 @@ func generateCert(t *testing.T, caDir string, host string) *domain.GenerateResul t.Helper() outDir := t.TempDir() genSvc := generate.New() - result, err := genSvc.Generate([]string{host}, domain.GenerateOptions{ + result, err := genSvc.Generate(context.Background(),[]string{host}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -40,7 +41,7 @@ func TestVerifyAgainstCA(t *testing.T) { gen := generateCert(t, caDir, "verify.test") svc := New() - result, err := svc.VerifyFile(gen.CertPath, domain.VerifyOptions{ + result, err := svc.VerifyFile(context.Background(),gen.CertPath, domain.VerifyOptions{ CAPath: filepath.Join(caDir, "rootCA.pem"), }) if err != nil { @@ -66,7 +67,7 @@ func TestVerifyKeyMatch(t *testing.T) { gen := generateCert(t, caDir, "keymatch.test") svc := New() - result, err := svc.VerifyFile(gen.CertPath, domain.VerifyOptions{ + result, err := svc.VerifyFile(context.Background(),gen.CertPath, domain.VerifyOptions{ CAPath: filepath.Join(caDir, "rootCA.pem"), KeyPath: gen.KeyPath, }) @@ -92,7 +93,7 @@ func TestVerifyKeyMismatch(t *testing.T) { svc := New() // Use cert1's certificate with cert2's key -- they should not match - result, err := svc.VerifyFile(gen1.CertPath, domain.VerifyOptions{ + result, err := svc.VerifyFile(context.Background(),gen1.CertPath, domain.VerifyOptions{ CAPath: filepath.Join(caDir, "rootCA.pem"), KeyPath: gen2.KeyPath, }) @@ -118,7 +119,7 @@ func TestVerifyInvalidCA(t *testing.T) { svc := New() // Verify cert signed by CA1 against CA2 -- should fail - result, err := svc.VerifyFile(gen.CertPath, domain.VerifyOptions{ + result, err := svc.VerifyFile(context.Background(),gen.CertPath, domain.VerifyOptions{ CAPath: filepath.Join(caDir2, "rootCA.pem"), }) if err != nil { @@ -138,7 +139,7 @@ func TestVerifyInvalidCA(t *testing.T) { func TestVerifyFileNotFound(t *testing.T) { svc := New() - _, err := svc.VerifyFile("/nonexistent/cert.pem", domain.VerifyOptions{}) + _, err := svc.VerifyFile(context.Background(),"/nonexistent/cert.pem", domain.VerifyOptions{}) if err == nil { t.Error("expected error for missing file") } diff --git a/internal/services/convert/convert.go b/internal/services/convert/convert.go index 85ad804..76baa6f 100644 --- a/internal/services/convert/convert.go +++ b/internal/services/convert/convert.go @@ -1,22 +1,26 @@ package convert import ( + "context" "crypto/x509" "encoding/pem" "fmt" "os" "path/filepath" "strings" + + "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" ) type Service struct{} func New() *Service { return &Service{} } -func (s *Service) Convert(inputPath, outputPath, format string) error { - data, err := os.ReadFile(inputPath) +func (s *Service) Convert(ctx context.Context, inputPath, outputPath, format string) error { + data, err := fileutil.ReadLimited(inputPath, fileutil.MaxCertFileSize) if err != nil { - return fmt.Errorf("failed to read input file: %w", err) + return &domain.ErrFileRead{Path: inputPath, Err: err} } block, _ := pem.Decode(data) diff --git a/internal/services/convert/convert_test.go b/internal/services/convert/convert_test.go index bff00be..9da687d 100644 --- a/internal/services/convert/convert_test.go +++ b/internal/services/convert/convert_test.go @@ -1,6 +1,7 @@ package convert import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -66,7 +67,7 @@ func TestConvertPEMtoDER(t *testing.T) { t.Fatal(err) } - if err := svc.Convert(inputPath, outputPath, "der"); err != nil { + if err := svc.Convert(context.Background(),inputPath, outputPath, "der"); err != nil { t.Fatal(err) } @@ -96,7 +97,7 @@ func TestConvertDERtoPEM(t *testing.T) { t.Fatal(err) } - if err := svc.Convert(inputPath, outputPath, "pem"); err != nil { + if err := svc.Convert(context.Background(),inputPath, outputPath, "pem"); err != nil { t.Fatal(err) } @@ -136,7 +137,7 @@ func TestConvertAutoDetect(t *testing.T) { } // Empty format should auto-detect from output extension - if err := svc.Convert(inputPath, outputPath, ""); err != nil { + if err := svc.Convert(context.Background(),inputPath, outputPath, ""); err != nil { t.Fatal(err) } @@ -157,7 +158,7 @@ func TestConvertAutoDetect(t *testing.T) { t.Fatal(err) } - if err := svc.Convert(derInputPath, pemOutputPath, ""); err != nil { + if err := svc.Convert(context.Background(),derInputPath, pemOutputPath, ""); err != nil { t.Fatal(err) } @@ -183,13 +184,13 @@ func TestConvertInvalidInput(t *testing.T) { t.Fatal(err) } - err := svc.Convert(inputPath, outputPath, "pem") + err := svc.Convert(context.Background(),inputPath, outputPath, "pem") if err == nil { t.Error("expected error for garbage input") } // Non-existent file - err = svc.Convert(filepath.Join(dir, "nonexistent.pem"), outputPath, "der") + err = svc.Convert(context.Background(),filepath.Join(dir, "nonexistent.pem"), outputPath, "der") if err == nil { t.Error("expected error for non-existent file") } @@ -210,12 +211,12 @@ func TestConvertRoundTrip(t *testing.T) { } // PEM -> DER - if err := svc.Convert(pemPath, derPath, "der"); err != nil { + if err := svc.Convert(context.Background(),pemPath, derPath, "der"); err != nil { t.Fatalf("PEM->DER failed: %v", err) } // DER -> PEM - if err := svc.Convert(derPath, roundTripPath, "pem"); err != nil { + if err := svc.Convert(context.Background(),derPath, roundTripPath, "pem"); err != nil { t.Fatalf("DER->PEM failed: %v", err) } diff --git a/internal/services/generate/csr.go b/internal/services/generate/csr.go index 6004669..db4d8a7 100644 --- a/internal/services/generate/csr.go +++ b/internal/services/generate/csr.go @@ -1,6 +1,7 @@ package generate import ( + "context" "crypto/rand" "crypto/x509" "crypto/x509/pkix" @@ -14,10 +15,15 @@ import ( tlscrypto "github.com/dyaa/tlsc/internal/crypto" "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" "github.com/dyaa/tlsc/internal/services/ca" ) -func (s *Service) GenerateCSR(hosts []string, opts domain.GenerateOptions) (*domain.CSRResult, error) { +func (s *Service) GenerateCSR(ctx context.Context, hosts []string, opts domain.GenerateOptions) (*domain.CSRResult, error) { + if len(hosts) == 0 { + return nil, fmt.Errorf("at least one host is required") + } + key, err := tlscrypto.GenerateKey(opts.KeyType) if err != nil { return nil, err @@ -84,15 +90,15 @@ func (s *Service) GenerateCSR(hosts []string, opts domain.GenerateOptions) (*dom }, nil } -func (s *Service) SignCSR(csrPath string, opts domain.GenerateOptions) (*domain.GenerateResult, error) { - csrPEM, err := os.ReadFile(csrPath) +func (s *Service) SignCSR(ctx context.Context, csrPath string, opts domain.GenerateOptions) (*domain.GenerateResult, error) { + csrPEM, err := fileutil.ReadLimited(csrPath, fileutil.MaxCertFileSize) if err != nil { - return nil, fmt.Errorf("failed to read CSR file: %w", err) + return nil, &domain.ErrFileRead{Path: csrPath, Err: err} } block, _ := pem.Decode(csrPEM) if block == nil { - return nil, fmt.Errorf("failed to decode CSR PEM") + return nil, &domain.ErrInvalidPEM{Path: csrPath} } csr, err := x509.ParseCertificateRequest(block.Bytes) @@ -109,7 +115,7 @@ func (s *Service) SignCSR(csrPath string, opts domain.GenerateOptions) (*domain. caPath = ca.DefaultPath() } - authority, err := ca.Load(caPath) + authority, err := ca.Load(ctx, caPath) if err != nil { return nil, err } diff --git a/internal/services/generate/csr_test.go b/internal/services/generate/csr_test.go index c07b298..89ee640 100644 --- a/internal/services/generate/csr_test.go +++ b/internal/services/generate/csr_test.go @@ -1,6 +1,7 @@ package generate import ( + "context" "crypto/x509" "encoding/pem" "os" @@ -14,7 +15,7 @@ func TestGenerateCSR(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.GenerateCSR([]string{"test.local"}, domain.GenerateOptions{ + result, err := svc.GenerateCSR(context.Background(),[]string{"test.local"}, domain.GenerateOptions{ KeyType: "ecdsa", OutDir: outDir, }) @@ -64,7 +65,7 @@ func TestSignCSR(t *testing.T) { svc := New() // Generate a CSR first - csrResult, err := svc.GenerateCSR([]string{"sign-test.local", "192.168.1.10"}, domain.GenerateOptions{ + csrResult, err := svc.GenerateCSR(context.Background(),[]string{"sign-test.local", "192.168.1.10"}, domain.GenerateOptions{ KeyType: "ecdsa", OutDir: outDir, }) @@ -73,7 +74,7 @@ func TestSignCSR(t *testing.T) { } // Sign the CSR - certResult, err := svc.SignCSR(csrResult.CSRPath, domain.GenerateOptions{ + certResult, err := svc.SignCSR(context.Background(),csrResult.CSRPath, domain.GenerateOptions{ Days: 365, OutDir: outDir, CAPath: caDir, @@ -134,7 +135,7 @@ func TestSignCSRWithSubjectFields(t *testing.T) { svc := New() // Generate CSR with subject fields - csrResult, err := svc.GenerateCSR([]string{"org.local"}, domain.GenerateOptions{ + csrResult, err := svc.GenerateCSR(context.Background(),[]string{"org.local"}, domain.GenerateOptions{ KeyType: "ecdsa", OutDir: outDir, Organization: "TestOrg", @@ -148,7 +149,7 @@ func TestSignCSRWithSubjectFields(t *testing.T) { } // Sign the CSR - certResult, err := svc.SignCSR(csrResult.CSRPath, domain.GenerateOptions{ + certResult, err := svc.SignCSR(context.Background(),csrResult.CSRPath, domain.GenerateOptions{ Days: 365, OutDir: outDir, CAPath: caDir, @@ -193,7 +194,7 @@ func TestSignCSRFileNotFound(t *testing.T) { caDir := setupCA(t) svc := New() - _, err := svc.SignCSR(filepath.Join(t.TempDir(), "nonexistent.csr"), domain.GenerateOptions{ + _, err := svc.SignCSR(context.Background(),filepath.Join(t.TempDir(), "nonexistent.csr"), domain.GenerateOptions{ Days: 365, OutDir: t.TempDir(), CAPath: caDir, @@ -209,7 +210,7 @@ func TestGenerateCSRKeyTypes(t *testing.T) { t.Run(kt, func(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.GenerateCSR([]string{"test.local"}, domain.GenerateOptions{ + result, err := svc.GenerateCSR(context.Background(),[]string{"test.local"}, domain.GenerateOptions{ KeyType: kt, OutDir: outDir, }) diff --git a/internal/services/generate/generate.go b/internal/services/generate/generate.go index 59c6010..398bf1a 100644 --- a/internal/services/generate/generate.go +++ b/internal/services/generate/generate.go @@ -1,6 +1,7 @@ package generate import ( + "context" "crypto/rand" "crypto/x509" "crypto/x509/pkix" @@ -15,6 +16,7 @@ import ( tlscrypto "github.com/dyaa/tlsc/internal/crypto" "github.com/dyaa/tlsc/internal/domain" + "github.com/dyaa/tlsc/internal/fileutil" "github.com/dyaa/tlsc/internal/services/ca" ) @@ -24,13 +26,17 @@ func New() *Service { return &Service{} } -func (s *Service) Generate(hosts []string, opts domain.GenerateOptions) (*domain.GenerateResult, error) { +func (s *Service) Generate(ctx context.Context, hosts []string, opts domain.GenerateOptions) (*domain.GenerateResult, error) { + if len(hosts) == 0 { + return nil, fmt.Errorf("at least one host is required") + } + caPath := opts.CAPath if caPath == "" { caPath = ca.DefaultPath() } - authority, err := ca.Load(caPath) + authority, err := ca.Load(ctx, caPath) if err != nil { return nil, err } @@ -136,17 +142,17 @@ func (s *Service) Generate(hosts []string, opts domain.GenerateOptions) (*domain }, nil } -func (s *Service) Renew(certPath string, opts domain.GenerateOptions) (*domain.GenerateResult, error) { +func (s *Service) Renew(ctx context.Context, certPath string, opts domain.GenerateOptions) (*domain.GenerateResult, error) { certPath = resolveCertPath(certPath) - certData, err := os.ReadFile(certPath) + certData, err := fileutil.ReadLimited(certPath, fileutil.MaxCertFileSize) if err != nil { - return nil, fmt.Errorf("failed to read certificate: %w", err) + return nil, &domain.ErrFileRead{Path: certPath, Err: err} } block, _ := pem.Decode(certData) if block == nil { - return nil, fmt.Errorf("not a valid PEM file") + return nil, &domain.ErrInvalidPEM{Path: certPath} } oldCert, err := x509.ParseCertificate(block.Bytes) @@ -193,7 +199,7 @@ func (s *Service) Renew(certPath string, opts domain.GenerateOptions) (*domain.G opts.OutDir = filepath.Dir(certPath) - return s.Generate(hosts, opts) + return s.Generate(ctx, hosts, opts) } func resolveCertPath(input string) string { @@ -210,7 +216,8 @@ func resolveCertPath(input string) string { } func sanitizeName(host string) string { - name := strings.ReplaceAll(host, "*.", "wildcard.") + name := strings.ReplaceAll(host, "\x00", "") + name = strings.ReplaceAll(name, "*.", "wildcard.") name = strings.ReplaceAll(name, "@", "_at_") name = strings.ReplaceAll(name, "..", "_") name = strings.ReplaceAll(name, "/", "_") diff --git a/internal/services/generate/generate_test.go b/internal/services/generate/generate_test.go index 225921a..16e903c 100644 --- a/internal/services/generate/generate_test.go +++ b/internal/services/generate/generate_test.go @@ -1,6 +1,7 @@ package generate import ( + "context" "crypto/x509" "encoding/pem" "os" @@ -15,7 +16,7 @@ func setupCA(t *testing.T) string { t.Helper() dir := t.TempDir() caDir := filepath.Join(dir, "ca") - if _, err := ca.Init(caDir, domain.DefaultCAInitOptions()); err != nil { + if _, err := ca.Init(context.Background(), caDir, domain.DefaultCAInitOptions()); err != nil { t.Fatal(err) } return caDir @@ -26,7 +27,7 @@ func TestGenerateBasic(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"test.local"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"test.local"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -69,7 +70,7 @@ func TestGenerateWildcard(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"*.home.lab", "home.lab"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"*.home.lab", "home.lab"}, domain.GenerateOptions{ Days: 825, KeyType: "ecdsa", OutDir: outDir, @@ -93,7 +94,7 @@ func TestGenerateWithIP(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"myhost.local", "192.168.1.1"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"myhost.local", "192.168.1.1"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -120,7 +121,7 @@ func TestGenerateClientCert(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"admin@home.lab"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"admin@home.lab"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -154,7 +155,7 @@ func TestGenerateKeyTypes(t *testing.T) { t.Run(kt, func(t *testing.T) { outDir := t.TempDir() svc := New() - _, err := svc.Generate([]string{"test.local"}, domain.GenerateOptions{ + _, err := svc.Generate(context.Background(),[]string{"test.local"}, domain.GenerateOptions{ Days: 30, KeyType: kt, OutDir: outDir, @@ -170,7 +171,7 @@ func TestGenerateKeyTypes(t *testing.T) { func TestGenerateInvalidKeyType(t *testing.T) { caDir := setupCA(t) svc := New() - _, err := svc.Generate([]string{"test.local"}, domain.GenerateOptions{ + _, err := svc.Generate(context.Background(),[]string{"test.local"}, domain.GenerateOptions{ KeyType: "invalid", OutDir: t.TempDir(), CAPath: caDir, @@ -185,7 +186,7 @@ func TestGenerateServerAndClient(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"mtls.local"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"mtls.local"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -223,7 +224,7 @@ func TestGenerateWithSubjectFields(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"org.local"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"org.local"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -265,7 +266,7 @@ func TestGenerateBundle(t *testing.T) { outDir := t.TempDir() svc := New() - result, err := svc.Generate([]string{"bundle.local"}, domain.GenerateOptions{ + result, err := svc.Generate(context.Background(),[]string{"bundle.local"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -307,7 +308,7 @@ func TestRenewCert(t *testing.T) { outDir := t.TempDir() svc := New() - original, err := svc.Generate([]string{"renew.test", "192.168.1.50"}, domain.GenerateOptions{ + original, err := svc.Generate(context.Background(),[]string{"renew.test", "192.168.1.50"}, domain.GenerateOptions{ Days: 365, KeyType: "ecdsa", OutDir: outDir, @@ -322,7 +323,7 @@ func TestRenewCert(t *testing.T) { } // Renew the certificate - renewed, err := svc.Renew(original.CertPath, domain.GenerateOptions{ + renewed, err := svc.Renew(context.Background(),original.CertPath, domain.GenerateOptions{ Days: 730, KeyType: "ecdsa", CAPath: caDir, @@ -376,7 +377,7 @@ func TestRenewCert(t *testing.T) { func TestRenewFileNotFound(t *testing.T) { svc := New() - _, err := svc.Renew("/nonexistent/cert.pem", domain.DefaultGenerateOptions()) + _, err := svc.Renew(context.Background(),"/nonexistent/cert.pem", domain.DefaultGenerateOptions()) if err == nil { t.Error("expected error for missing file") } @@ -384,7 +385,7 @@ func TestRenewFileNotFound(t *testing.T) { func TestGenerateNoCA(t *testing.T) { svc := New() - _, err := svc.Generate([]string{"test.local"}, domain.GenerateOptions{ + _, err := svc.Generate(context.Background(),[]string{"test.local"}, domain.GenerateOptions{ OutDir: t.TempDir(), CAPath: t.TempDir(), }) diff --git a/pkg/tlsc/tlsc.go b/pkg/tlsc/tlsc.go index 9dec112..2295f45 100644 --- a/pkg/tlsc/tlsc.go +++ b/pkg/tlsc/tlsc.go @@ -4,12 +4,19 @@ // Sieve), managing a local Certificate Authority, generating and renewing // certificates, handling CSR workflows, and converting between PEM and DER. // +// Note: Check and CheckBatch connect with InsecureSkipVerify enabled so that +// certificates can be inspected regardless of validity. The returned Result +// contains the post-connection verification outcome — callers should not treat +// a successful Check call as proof that the connection was authenticated. +// // result, _ := tlsc.Check("example.com") // result, _ = tlsc.InspectFile("cert.pem") // gr, _ := tlsc.Generate([]string{"server.lan"}) package tlsc import ( + "context" + "github.com/dyaa/tlsc/internal/domain" "github.com/dyaa/tlsc/internal/services/ca" "github.com/dyaa/tlsc/internal/services/checker" @@ -17,37 +24,99 @@ import ( "github.com/dyaa/tlsc/internal/services/generate" ) +// Result is the full inspection outcome for a single TLS certificate. type Result = domain.Result + +// ResultOrError pairs a Result with an optional error, used in batch operations. type ResultOrError = domain.ResultOrError + +// CheckOptions configures how a TLS certificate check is performed. type CheckOptions = domain.CheckOptions + +// GenerateOptions configures certificate or CSR generation. type GenerateOptions = domain.GenerateOptions + +// GenerateResult holds the output paths and metadata from certificate generation. type GenerateResult = domain.GenerateResult + +// CSRResult holds the output paths from CSR generation. type CSRResult = domain.CSRResult + +// CAInitOptions configures the creation of a new Certificate Authority. type CAInitOptions = domain.CAInitOptions + +// VerifyOptions configures certificate verification against a CA and optional private key. type VerifyOptions = domain.VerifyOptions + +// VerifyResult holds the outcome of a certificate verification check. type VerifyResult = domain.VerifyResult + +// Grade represents the overall TLS security grade and associated findings. type Grade = domain.Grade + +// HSTS holds HTTP Strict Transport Security header information for a host. type HSTS = domain.HSTS + +// Subject holds the distinguished name fields of a certificate subject or issuer. type Subject = domain.Subject + +// Protocol represents a connection protocol used for TLS certificate checks. type Protocol = domain.Protocol +// ErrConnection indicates a TCP-level connection failure. +type ErrConnection = domain.ErrConnection + +// ErrTLSHandshake indicates a failure during the TLS handshake. +type ErrTLSHandshake = domain.ErrTLSHandshake + +// ErrCertificateInvalid indicates that the certificate failed validation. +type ErrCertificateInvalid = domain.ErrCertificateInvalid + +// ErrCANotFound indicates that the Certificate Authority was not found at the given path. +type ErrCANotFound = domain.ErrCANotFound + +// ErrCAExists indicates that a Certificate Authority already exists at the given path. +type ErrCAExists = domain.ErrCAExists + +// ErrFileRead indicates a failure reading a certificate or key file. +type ErrFileRead = domain.ErrFileRead + +// ErrInvalidPEM indicates that PEM decoding failed for the given input. +type ErrInvalidPEM = domain.ErrInvalidPEM + +// ErrUnsupportedProtocol indicates that the requested protocol is not supported. +type ErrUnsupportedProtocol = domain.ErrUnsupportedProtocol + const ( - HTTPS = domain.HTTPS - SMTP = domain.SMTP - IMAP = domain.IMAP - POP3 = domain.POP3 - FTP = domain.FTP - LDAP = domain.LDAP - MySQL = domain.MySQL + // HTTPS is the HTTPS protocol (default port 443). + HTTPS = domain.HTTPS + // SMTP is the SMTP protocol with STARTTLS (default port 587). + SMTP = domain.SMTP + // IMAP is the IMAP protocol with STARTTLS (default port 143). + IMAP = domain.IMAP + // POP3 is the POP3 protocol with STARTTLS (default port 110). + POP3 = domain.POP3 + // FTP is the FTP protocol with explicit TLS (default port 21). + FTP = domain.FTP + // LDAP is the LDAP protocol with STARTTLS (default port 389). + LDAP = domain.LDAP + // MySQL is the MySQL protocol with TLS negotiation (default port 3306). + MySQL = domain.MySQL + // Postgres is the PostgreSQL protocol with TLS negotiation (default port 5432). Postgres = domain.Postgres - XMPP = domain.XMPP - Sieve = domain.Sieve + // XMPP is the XMPP protocol with STARTTLS (default port 5222). + XMPP = domain.XMPP + // Sieve is the ManageSieve protocol with STARTTLS (default port 4190). + Sieve = domain.Sieve ) var ( - DefaultCheckOptions = domain.DefaultCheckOptions + // DefaultCheckOptions returns CheckOptions with sensible defaults. + DefaultCheckOptions = domain.DefaultCheckOptions + // DefaultGenerateOptions returns GenerateOptions with sensible defaults. DefaultGenerateOptions = domain.DefaultGenerateOptions - DefaultCAInitOptions = domain.DefaultCAInitOptions + // DefaultCAInitOptions returns CAInitOptions with sensible defaults. + DefaultCAInitOptions = domain.DefaultCAInitOptions ) var ( @@ -56,70 +125,81 @@ var ( convertSvc = convert.New() ) -func Check(host string, opts ...CheckOptions) (*Result, error) { +// Check connects to host, performs a TLS handshake, and returns the certificate inspection result. +func Check(ctx context.Context, host string, opts ...CheckOptions) (*Result, error) { o := domain.DefaultCheckOptions() if len(opts) > 0 { o = opts[0] } - return checkSvc.Check(host, o) + return checkSvc.Check(ctx, host, o) } -func CheckBatch(hosts []string, opts ...CheckOptions) map[string]*ResultOrError { +// CheckBatch checks multiple hosts concurrently and returns a map of host to result-or-error. +func CheckBatch(ctx context.Context, hosts []string, opts ...CheckOptions) map[string]*ResultOrError { o := domain.DefaultCheckOptions() if len(opts) > 0 { o = opts[0] } - return checkSvc.CheckBatch(hosts, o) + return checkSvc.CheckBatch(ctx, hosts, o) } -func InspectFile(path string) (*Result, error) { - return checkSvc.InspectFile(path) +// InspectFile parses a PEM-encoded certificate file and returns its inspection result. +func InspectFile(ctx context.Context, path string) (*Result, error) { + return checkSvc.InspectFile(ctx, path) } -func Verify(certPath string, opts VerifyOptions) (*VerifyResult, error) { - return checkSvc.VerifyFile(certPath, opts) +// Verify validates a certificate file against a CA and optionally checks key pairing. +func Verify(ctx context.Context, certPath string, opts VerifyOptions) (*VerifyResult, error) { + return checkSvc.VerifyFile(ctx, certPath, opts) } -func Generate(hosts []string, opts ...GenerateOptions) (*GenerateResult, error) { +// Generate creates a new TLS certificate and private key for the given hosts. +func Generate(ctx context.Context, hosts []string, opts ...GenerateOptions) (*GenerateResult, error) { o := domain.DefaultGenerateOptions() if len(opts) > 0 { o = opts[0] } - return genSvc.Generate(hosts, o) + return genSvc.Generate(ctx, hosts, o) } -func GenerateCSR(hosts []string, opts ...GenerateOptions) (*CSRResult, error) { +// GenerateCSR creates a Certificate Signing Request and private key for the given hosts. +func GenerateCSR(ctx context.Context, hosts []string, opts ...GenerateOptions) (*CSRResult, error) { o := domain.DefaultGenerateOptions() if len(opts) > 0 { o = opts[0] } - return genSvc.GenerateCSR(hosts, o) + return genSvc.GenerateCSR(ctx, hosts, o) } -func SignCSR(csrPath string, opts ...GenerateOptions) (*GenerateResult, error) { +// SignCSR signs a PEM-encoded CSR file using the configured CA and returns the issued certificate. +func SignCSR(ctx context.Context, csrPath string, opts ...GenerateOptions) (*GenerateResult, error) { o := domain.DefaultGenerateOptions() if len(opts) > 0 { o = opts[0] } - return genSvc.SignCSR(csrPath, o) + return genSvc.SignCSR(ctx, csrPath, o) } -func InitCA(caPath string, opts ...CAInitOptions) (*ca.CA, error) { +// InitCA creates a new root Certificate Authority at the given path. +func InitCA(ctx context.Context, caPath string, opts ...CAInitOptions) (*ca.CA, error) { o := domain.DefaultCAInitOptions() if len(opts) > 0 { o = opts[0] } - return ca.Init(caPath, o) + return ca.Init(ctx, caPath, o) } -func LoadCA(caPath string) (*ca.CA, error) { - return ca.Load(caPath) +// LoadCA reads an existing Certificate Authority from disk. +func LoadCA(ctx context.Context, caPath string) (*ca.CA, error) { + return ca.Load(ctx, caPath) } -func RenewCA(caPath string) (*ca.CA, error) { - return ca.Renew(caPath) +// RenewCA re-issues the root CA certificate at caPath, keeping the existing private key. +func RenewCA(ctx context.Context, caPath string) (*ca.CA, error) { + return ca.Renew(ctx, caPath) } -func Convert(inputPath, outputPath, format string) error { - return convertSvc.Convert(inputPath, outputPath, format) +// Convert transforms a certificate file between PEM and DER formats. +func Convert(ctx context.Context, inputPath, outputPath, format string) error { + return convertSvc.Convert(ctx, inputPath, outputPath, format) }