diff --git a/internal/format/xml.go b/internal/format/xml.go
index bfc6dd8..7b5bdeb 100644
--- a/internal/format/xml.go
+++ b/internal/format/xml.go
@@ -51,6 +51,14 @@ func FormatXML(buf []byte, w *core.Printer) error {
}
stack = append(stack, false)
case xml.EndElement:
+ if len(stack) == 0 {
+ // Malformed XML: more closing tags than opening tags.
+ // Skip indent adjustment, just write the closing tag.
+ w.WriteString("")
+ writeXMLTagName(w, t.Name.Local)
+ w.WriteString(">\n")
+ continue
+ }
last := stack[len(stack)-1]
stack = stack[:len(stack)-1]
@@ -162,6 +170,21 @@ func writeXMLProcInst(p *core.Printer, inst []byte) {
// Mostly taken from the Go encoding/xml package in the standard library:
// https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/encoding/xml/xml.go;l=1964-1999
func escapeXMLString(p *core.Printer, s string) {
+ // Fast path: check if string needs escaping.
+ needsEscape := false
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if c == '"' || c == '\'' || c == '&' || c == '<' || c == '>' ||
+ c == '\t' || c == '\n' || c == '\r' || c >= 0x80 {
+ needsEscape = true
+ break
+ }
+ }
+ if !needsEscape {
+ p.WriteString(s)
+ return
+ }
+
var esc string
var last int
for i := 0; i < len(s); {
diff --git a/internal/format/xml_test.go b/internal/format/xml_test.go
new file mode 100644
index 0000000..fb0efb9
--- /dev/null
+++ b/internal/format/xml_test.go
@@ -0,0 +1,147 @@
+package format
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/ryanfowler/fetch/internal/core"
+)
+
+// TestFormatXML checks that FormatXML accepts well-formed documents and
+// reports an error for documents with unmatched closing tags.
+func TestFormatXML(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		wantErr bool
+	}{
+		{
+			name:  "valid simple xml",
+			input: "<root>text</root>",
+		},
+		{
+			name:  "valid nested xml",
+			input: "<root><child>text</child></root>",
+		},
+		{
+			name:  "valid xml with attributes",
+			input: `<root attr="value">text</root>`,
+		},
+		{
+			name:    "malformed xml extra closing tag at start",
+			input:   "</extra><root>text</root>",
+			wantErr: true, // XML decoder catches this
+		},
+		{
+			name:    "malformed xml extra closing tag at end",
+			input:   "<root>text</root></extra>",
+			wantErr: true, // XML decoder catches this
+		},
+		{
+			name:    "malformed xml multiple extra closing tags",
+			input:   "<root>text</root></a></b>",
+			wantErr: true, // XML decoder catches this
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Only the error result is under test; the formatted output is not inspected.
+			p := core.NewHandle(core.ColorOff).Stderr()
+			err := FormatXML([]byte(tt.input), p)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("FormatXML() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+// TestFormatXMLOutput checks that the element tags and the text content of
+// the input are all present in the formatted output.
+func TestFormatXMLOutput(t *testing.T) {
+	input := "<root><child>text</child></root>"
+	p := core.NewHandle(core.ColorOff).Stderr()
+	err := FormatXML([]byte(input), p)
+	if err != nil {
+		t.Fatalf("FormatXML() error = %v", err)
+	}
+
+	// An empty needle would make strings.Contains trivially true, so every
+	// expected fragment below is non-empty.
+	output := string(p.Bytes())
+	for _, want := range []string{"<root>", "<child>", "text"} {
+		if !strings.Contains(output, want) {
+			t.Errorf("output should contain %s, got: %s", want, output)
+		}
+	}
+}
+
+// TestEscapeXMLString checks escapeXMLString against the entity forms used by
+// encoding/xml (the documented basis of the function), plus unescaped text.
+func TestEscapeXMLString(t *testing.T) {
+	tests := []struct {
+		name        string
+		input, want string
+	}{
+		{
+			name:  "ascii no escape needed",
+			input: "hello world",
+			want:  "hello world",
+		},
+		{
+			name:  "with ampersand",
+			input: "foo & bar",
+			want:  "foo &amp; bar",
+		},
+		{
+			name:  "with less than",
+			input: "a < b",
+			want:  "a &lt; b",
+		},
+		{
+			name:  "with greater than",
+			input: "a > b",
+			want:  "a &gt; b",
+		},
+		{
+			name:  "with quotes",
+			input: `"quoted"`,
+			want:  "&#34;quoted&#34;",
+		},
+		{
+			name:  "with single quotes",
+			input: "'quoted'",
+			want:  "&#39;quoted&#39;",
+		},
+		{
+			name:  "with tab",
+			input: "a\tb",
+			want:  "a&#x9;b",
+		},
+		{
+			name:  "with newline",
+			input: "a\nb",
+			want:  "a&#xA;b",
+		},
+		{
+			name:  "with carriage return",
+			input: "a\rb",
+			want:  "a&#xD;b",
+		},
+		{
+			name:  "unicode chars",
+			input: "日本語",
+			want:  "日本語",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			p := core.NewHandle(core.ColorOff).Stderr()
+			escapeXMLString(p, tt.input)
+			if got := string(p.Bytes()); got != tt.want {
+				t.Errorf("escapeXMLString() = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}