diff --git a/cmd/wsh/main.go b/cmd/wsh/main.go index 4afca96..6768fd7 100644 --- a/cmd/wsh/main.go +++ b/cmd/wsh/main.go @@ -5,7 +5,9 @@ import ( "fmt" "log" "os" + "os/signal" "strings" + "syscall" "github.com/jessevdk/go-flags" "golang.org/x/term" @@ -19,6 +21,7 @@ import ( const ( defaultRealm = "wampshell" + nonceSize = 12 procedureInteractive = "wampshell.shell.interactive" procedureExec = "wampshell.shell.exec" procedureWebRTCOffer = "wampshell.webrtc.offer" @@ -27,8 +30,6 @@ const ( ) func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) error { - const nonceSize = 12 - fd := int(os.Stdin.Fd()) oldState, err := term.MakeRaw(fd) if err != nil { @@ -36,26 +37,56 @@ func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) erro } defer func() { _ = term.Restore(fd, oldState) }() - firstProgress := true + progressChan := make(chan *xconn.Progress, 32) - readAndEncrypt := func() (*xconn.Progress, error) { - buf := make([]byte, 1024) - n, err := os.Stdin.Read(buf) + sendSize := func() *xconn.Progress { + width, height, err := term.GetSize(fd) if err != nil { - return nil, fmt.Errorf("read error: %w", err) + return nil } - - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) + msg := fmt.Sprintf("SIZE:%d:%d", width, height) + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305([]byte(msg), keys.Send) if err != nil { - return nil, fmt.Errorf("encryption error: %w", err) + return nil } payload := append(nonce, ciphertext...) - return xconn.NewProgress(payload), nil + return xconn.NewProgress(payload) + } + + if p := sendSize(); p != nil { + progressChan <- p } + go func() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGWINCH) + for range sigChan { + if p := sendSize(); p != nil { + progressChan <- p + } + } + }() + + go func() { + buf := make([]byte, 1024) + for { + n, err := os.Stdin.Read(buf) + if err != nil { + close(progressChan) + return + } + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send) + if err != nil { + fmt.Fprintln(os.Stderr, "encryption error:", err) + continue + } + progressChan <- xconn.NewProgress(append(nonce, ciphertext...)) + } + }() + decryptAndWrite := func(encData []byte) error { if len(encData) < nonceSize { - return fmt.Errorf("invalid payload from server: too short") + return fmt.Errorf("invalid payload from server") } plain, err := berncrypt.DecryptChaCha20Poly1305(encData[nonceSize:], encData[:nonceSize], keys.Receive) if err != nil { @@ -67,16 +98,11 @@ func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) erro call := session.Call(procedureInteractive). ProgressSender(func(ctx context.Context) *xconn.Progress { - if firstProgress { - firstProgress = false - return xconn.NewProgress() - } - progress, err := readAndEncrypt() - if err != nil { - fmt.Fprintln(os.Stderr, err) + p, ok := <-progressChan + if !ok { return xconn.NewFinalProgress() } - return progress + return p }). ProgressReceiver(func(result *xconn.InvocationResult) { if len(result.Args) > 0 { @@ -115,7 +141,8 @@ func runCommand(session *xconn.Session, keys *wampshell.KeyPair, args []string) return fmt.Errorf("output parsing error: %w", err) } - plainOutput, err := berncrypt.DecryptChaCha20Poly1305(encryptedOutput[12:], encryptedOutput[:12], keys.Receive) + plainOutput, err := berncrypt.DecryptChaCha20Poly1305(encryptedOutput[nonceSize:], + encryptedOutput[:nonceSize], keys.Receive) if err != nil { return fmt.Errorf("decryption failed: %w", err) } diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index d9a9bec..f6bc6b4 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "log" + "math" "os" "os/exec" "os/signal" @@ -94,23 +95,16 @@ func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) fu inv *xconn.Invocation) *xconn.InvocationResult { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { caller := inv.Caller() - key, ok := e.Key(inv.Caller()) + + key, ok := e.Key(caller) if !ok { return xconn.NewInvocationError("wamp.error.unavailable", "unavailable") } p.Lock() - ptmx, ok := p.ptmx[caller] + ptmx, exists := p.ptmx[caller] p.Unlock() - if !ok { - _, err := p.startPtySession(inv, key.Send) - if err != nil { - return xconn.NewInvocationError("io.xconn.error", err.Error()) - } - return xconn.NewInvocationError(xconn.ErrNoResult) - } - if inv.Progress() { payload, err := inv.ArgBytes(0) if err != nil { @@ -123,28 +117,58 @@ func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) fu decrypted, err := berncrypt.DecryptChaCha20Poly1305(payload[12:], payload[:12], key.Receive) if err != nil { p.Lock() - if storedPtmx, exists := p.ptmx[caller]; exists { - storedPtmx.Close() + if stored, ok := p.ptmx[caller]; ok { + _ = stored.Close() delete(p.ptmx, caller) } p.Unlock() return xconn.NewInvocationError("io.xconn.error", err.Error()) } + if bytes.HasPrefix(decrypted, []byte("SIZE:")) { + var cols, rows int + n, _ := fmt.Sscanf(string(decrypted), "SIZE:%d:%d", &cols, &rows) + if n == 2 { + if cols < 0 || cols > math.MaxUint16 || rows < 0 || rows > math.MaxUint16 { + return xconn.NewInvocationError("wamp.error.invalid_argument", "invalid size") + } + if !exists { + newPt, err := p.startPtySession(inv, key.Send) + if err != nil { + return xconn.NewInvocationError("io.xconn.error", err.Error()) + } + ptmx = newPt + } + winsize := &pty.Winsize{ + Cols: uint16(cols), // #nosec G115 + Rows: uint16(rows), // #nosec G115 + } + _ = pty.Setsize(ptmx, winsize) + } + return xconn.NewInvocationError(xconn.ErrNoResult) + } + + if !exists { + newPt, err := p.startPtySession(inv, key.Send) + if err != nil { + return xconn.NewInvocationError("io.xconn.error", err.Error()) + } + ptmx = newPt + } + _, err = ptmx.Write(decrypted) if err != nil { - log.Printf("Failed to write to PTY for caller %d: %v", caller, err) return xconn.NewInvocationError("io.xconn.error", err.Error()) } return xconn.NewInvocationError(xconn.ErrNoResult) } p.Lock() - delete(p.ptmx, caller) - p.Unlock() - if ok { - ptmx.Close() + if stored, ok := p.ptmx[caller]; ok { + _ = stored.Close() + delete(p.ptmx, caller) } + p.Unlock() return xconn.NewInvocationResult() }