diff --git a/controlbox.go b/controlbox.go index c10b905..4521e7d 100644 --- a/controlbox.go +++ b/controlbox.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "os" + "path/filepath" "slices" "strconv" "sync" @@ -56,34 +57,54 @@ type controlbox struct { mutex sync.Mutex } -func (h *controlbox) run() { - var err error - var certificate tls.Certificate - - if len(os.Args) == 4 { - certificate, err = tls.LoadX509KeyPair(os.Args[2], os.Args[3]) - if err != nil { - usage() - log.Fatal(err) - } - } else { - certificate, err = cert.CreateCertificate("Demo", "Demo", "DE", "Demo-Unit-01") - if err != nil { - log.Fatal(err) +func loadOrCreateCertificate(dir string) (tls.Certificate, error) { + crtPath := filepath.Join(dir, "cb.crt") + keyPath := filepath.Join(dir, "cb.key") + + // If both files exist → load + if _, err := os.Stat(crtPath); err == nil { + if _, err := os.Stat(keyPath); err == nil { + return tls.LoadX509KeyPair(crtPath, keyPath) } + } - pemdata := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certificate.Certificate[0], - }) - fmt.Println(string(pemdata)) + // Otherwise create new certificate + certTLS, err := cert.CreateCertificate("Demo", "Demo", "DE", "Demo-Unit-01") + if err != nil { + return tls.Certificate{}, err + } - b, err := x509.MarshalECPrivateKey(certificate.PrivateKey.(*ecdsa.PrivateKey)) - if err != nil { - log.Fatal(err) - } - pemdata = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) - fmt.Println(string(pemdata)) + // Write certificate + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certTLS.Certificate[0], + }) + if err := os.WriteFile(crtPath, certPEM, 0644); err != nil { + return tls.Certificate{}, err + } + + // Write private key + privKey := certTLS.PrivateKey.(*ecdsa.PrivateKey) + keyBytes, err := x509.MarshalECPrivateKey(privKey) + if err != nil { + return tls.Certificate{}, err + } + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + }) + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return tls.Certificate{}, err + } + + return certTLS, nil +} + +func (h *controlbox) run() { + if len(os.Args) < 2 || len(os.Args) > 3 { + fmt.Println("Usage: controlbox [cert-directory]") + os.Exit(1) } port, err := strconv.Atoi(os.Args[1]) @@ -92,6 +113,21 @@ func (h *controlbox) run() { log.Fatal(err) } + certDir := "." + if len(os.Args) == 3 { + certDir = os.Args[2] + } + + certDir, err = filepath.Abs(certDir) + if err != nil { + log.Fatal(err) + } + + certificate, err := loadOrCreateCertificate(certDir) + if err != nil { + log.Fatal(err) + } + vendorCode := "Demo" deviceBrand := "Demo" deviceModel := "ControlBox"