From 949d1595b421aaf13f5153e095649678c2dc6340 Mon Sep 17 00:00:00 2001 From: Ryan Fowler Date: Sat, 24 Jan 2026 21:09:01 -0800 Subject: [PATCH] Fix potential panic on invalid xml --- internal/format/xml.go | 23 ++++++ internal/format/xml_test.go | 147 ++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 internal/format/xml_test.go diff --git a/internal/format/xml.go b/internal/format/xml.go index bfc6dd8..7b5bdeb 100644 --- a/internal/format/xml.go +++ b/internal/format/xml.go @@ -51,6 +51,14 @@ func FormatXML(buf []byte, w *core.Printer) error { } stack = append(stack, false) case xml.EndElement: + if len(stack) == 0 { + // Malformed XML: more closing tags than opening tags. + // Skip indent adjustment, just write the closing tag. + w.WriteString("\n") + continue + } last := stack[len(stack)-1] stack = stack[:len(stack)-1] @@ -162,6 +170,21 @@ func writeXMLProcInst(p *core.Printer, inst []byte) { // Mostly taken from the Go encoding/xml package in the standard library: // https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/encoding/xml/xml.go;l=1964-1999 func escapeXMLString(p *core.Printer, s string) { + // Fast path: check if string needs escaping. 
+ needsEscape := false + for i := 0; i < len(s); i++ { + c := s[i] + if c == '"' || c == '\'' || c == '&' || c == '<' || c == '>' || + c == '\t' || c == '\n' || c == '\r' || c >= 0x80 { + needsEscape = true + break + } + } + if !needsEscape { + p.WriteString(s) + return + } + var esc string var last int for i := 0; i < len(s); { diff --git a/internal/format/xml_test.go b/internal/format/xml_test.go new file mode 100644 index 0000000..fb0efb9 --- /dev/null +++ b/internal/format/xml_test.go @@ -0,0 +1,147 @@ +package format + +import ( + "strings" + "testing" + + "github.com/ryanfowler/fetch/internal/core" +) + +func TestFormatXML(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "valid simple xml", + input: "text", + wantErr: false, + }, + { + name: "valid nested xml", + input: "text", + wantErr: false, + }, + { + name: "valid xml with attributes", + input: `text`, + wantErr: false, + }, + { + name: "malformed xml extra closing tag at start", + input: "", + wantErr: true, // XML decoder catches this + }, + { + name: "malformed xml extra closing tag at end", + input: "", + wantErr: true, // XML decoder catches this + }, + { + name: "malformed xml multiple extra closing tags", + input: "", + wantErr: true, // XML decoder catches this + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := core.NewHandle(core.ColorOff).Stderr() + err := FormatXML([]byte(tt.input), p) + if (err != nil) != tt.wantErr { + t.Errorf("FormatXML() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestFormatXMLOutput(t *testing.T) { + input := "text" + p := core.NewHandle(core.ColorOff).Stderr() + err := FormatXML([]byte(input), p) + if err != nil { + t.Fatalf("FormatXML() error = %v", err) + } + + output := string(p.Bytes()) + if !strings.Contains(output, "") { + t.Errorf("output should contain , got: %s", output) + } + if !strings.Contains(output, "") { + t.Errorf("output should contain , got: %s", 
output) + } + if !strings.Contains(output, "text") { + t.Errorf("output should contain text, got: %s", output) + } +} + +func TestEscapeXMLString(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "ascii no escape needed", + input: "hello world", + want: "hello world", + }, + { + name: "with ampersand", + input: "foo & bar", + want: "foo &amp; bar", + }, + { + name: "with less than", + input: "a < b", + want: "a &lt; b", + }, + { + name: "with greater than", + input: "a > b", + want: "a &gt; b", + }, + { + name: "with quotes", + input: `"quoted"`, + want: "&#34;quoted&#34;", + }, + { + name: "with single quotes", + input: "'quoted'", + want: "&#39;quoted&#39;", + }, + { + name: "with tab", + input: "a\tb", + want: "a&#x9;b", + }, + { + name: "with newline", + input: "a\nb", + want: "a&#xA;b", + }, + { + name: "with carriage return", + input: "a\rb", + want: "a&#xD;b", + }, + { + name: "unicode chars", + input: "日本語", + want: "日本語", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := core.NewHandle(core.ColorOff).Stderr() + escapeXMLString(p, tt.input) + got := string(p.Bytes()) + if got != tt.want { + t.Errorf("escapeXMLString() = %q, want %q", got, tt.want) + } + }) + } +}