Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions cmd/wsh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"

"github.com/jessevdk/go-flags"
"golang.org/x/term"
Expand All @@ -19,6 +21,7 @@ import (

const (
defaultRealm = "wampshell"
nonceSize = 12
procedureInteractive = "wampshell.shell.interactive"
procedureExec = "wampshell.shell.exec"
procedureWebRTCOffer = "wampshell.webrtc.offer"
Expand All @@ -27,35 +30,63 @@ const (
)

func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) error {
const nonceSize = 12

fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return fmt.Errorf("failed to set raw mode: %w", err)
}
defer func() { _ = term.Restore(fd, oldState) }()

firstProgress := true
progressChan := make(chan *xconn.Progress, 32)

readAndEncrypt := func() (*xconn.Progress, error) {
buf := make([]byte, 1024)
n, err := os.Stdin.Read(buf)
sendSize := func() *xconn.Progress {
width, height, err := term.GetSize(fd)
if err != nil {
return nil, fmt.Errorf("read error: %w", err)
return nil
}

ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send)
msg := fmt.Sprintf("SIZE:%d:%d", width, height)
ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305([]byte(msg), keys.Send)
if err != nil {
return nil, fmt.Errorf("encryption error: %w", err)
return nil
}
payload := append(nonce, ciphertext...)
return xconn.NewProgress(payload), nil
return xconn.NewProgress(payload)
}

if p := sendSize(); p != nil {
progressChan <- p
}

go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGWINCH)
for range sigChan {
if p := sendSize(); p != nil {
progressChan <- p
}
}
}()

go func() {
buf := make([]byte, 1024)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
close(progressChan)
return
}
ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.Send)
if err != nil {
fmt.Fprintln(os.Stderr, "encryption error:", err)
continue
}
progressChan <- xconn.NewProgress(append(nonce, ciphertext...))
}
}()

decryptAndWrite := func(encData []byte) error {
if len(encData) < nonceSize {
return fmt.Errorf("invalid payload from server: too short")
return fmt.Errorf("invalid payload from server")
}
plain, err := berncrypt.DecryptChaCha20Poly1305(encData[nonceSize:], encData[:nonceSize], keys.Receive)
if err != nil {
Expand All @@ -67,16 +98,11 @@ func startInteractiveShell(session *xconn.Session, keys *wampshell.KeyPair) erro

call := session.Call(procedureInteractive).
ProgressSender(func(ctx context.Context) *xconn.Progress {
if firstProgress {
firstProgress = false
return xconn.NewProgress()
}
progress, err := readAndEncrypt()
if err != nil {
fmt.Fprintln(os.Stderr, err)
p, ok := <-progressChan
if !ok {
return xconn.NewFinalProgress()
}
return progress
return p
}).
ProgressReceiver(func(result *xconn.InvocationResult) {
if len(result.Args) > 0 {
Expand Down Expand Up @@ -115,7 +141,8 @@ func runCommand(session *xconn.Session, keys *wampshell.KeyPair, args []string)
return fmt.Errorf("output parsing error: %w", err)
}

plainOutput, err := berncrypt.DecryptChaCha20Poly1305(encryptedOutput[12:], encryptedOutput[:12], keys.Receive)
plainOutput, err := berncrypt.DecryptChaCha20Poly1305(encryptedOutput[nonceSize:],
encryptedOutput[:nonceSize], keys.Receive)
if err != nil {
return fmt.Errorf("decryption failed: %w", err)
}
Expand Down
58 changes: 41 additions & 17 deletions cmd/wshd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"log"
"math"
"os"
"os/exec"
"os/signal"
Expand Down Expand Up @@ -94,23 +95,16 @@ func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) fu
inv *xconn.Invocation) *xconn.InvocationResult {
return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult {
caller := inv.Caller()
key, ok := e.Key(inv.Caller())

key, ok := e.Key(caller)
if !ok {
return xconn.NewInvocationError("wamp.error.unavailable", "unavailable")
}

p.Lock()
ptmx, ok := p.ptmx[caller]
ptmx, exists := p.ptmx[caller]
p.Unlock()

if !ok {
_, err := p.startPtySession(inv, key.Send)
if err != nil {
return xconn.NewInvocationError("io.xconn.error", err.Error())
}
return xconn.NewInvocationError(xconn.ErrNoResult)
}

if inv.Progress() {
payload, err := inv.ArgBytes(0)
if err != nil {
Expand All @@ -123,28 +117,58 @@ func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) fu
decrypted, err := berncrypt.DecryptChaCha20Poly1305(payload[12:], payload[:12], key.Receive)
if err != nil {
p.Lock()
if storedPtmx, exists := p.ptmx[caller]; exists {
storedPtmx.Close()
if stored, ok := p.ptmx[caller]; ok {
_ = stored.Close()
delete(p.ptmx, caller)
}
p.Unlock()
return xconn.NewInvocationError("io.xconn.error", err.Error())
}

if bytes.HasPrefix(decrypted, []byte("SIZE:")) {
var cols, rows int
n, _ := fmt.Sscanf(string(decrypted), "SIZE:%d:%d", &cols, &rows)
if n == 2 {
if cols < 0 || cols > math.MaxUint16 || rows < 0 || rows > math.MaxUint16 {
return xconn.NewInvocationError("wamp.error.invalid_argument", "invalid size")
}
if !exists {
newPt, err := p.startPtySession(inv, key.Send)
if err != nil {
return xconn.NewInvocationError("io.xconn.error", err.Error())
}
ptmx = newPt
}
winsize := &pty.Winsize{
Cols: uint16(cols), // #nosec G115
Rows: uint16(rows), // #nosec G115
}
_ = pty.Setsize(ptmx, winsize)
}
return xconn.NewInvocationError(xconn.ErrNoResult)
}

if !exists {
newPt, err := p.startPtySession(inv, key.Send)
if err != nil {
return xconn.NewInvocationError("io.xconn.error", err.Error())
}
ptmx = newPt
}

_, err = ptmx.Write(decrypted)
if err != nil {
log.Printf("Failed to write to PTY for caller %d: %v", caller, err)
return xconn.NewInvocationError("io.xconn.error", err.Error())
}
return xconn.NewInvocationError(xconn.ErrNoResult)
}

p.Lock()
delete(p.ptmx, caller)
p.Unlock()
if ok {
ptmx.Close()
if stored, ok := p.ptmx[caller]; ok {
_ = stored.Close()
delete(p.ptmx, caller)
}
p.Unlock()

return xconn.NewInvocationResult()
}
Expand Down
Loading