Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions internal/format/xml.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func FormatXML(buf []byte, w *core.Printer) error {
}
stack = append(stack, false)
case xml.EndElement:
if len(stack) == 0 {
// Malformed XML: more closing tags than opening tags.
// Skip indent adjustment, just write the closing tag.
w.WriteString("</")
writeXMLTagName(w, t.Name.Local)
w.WriteString(">\n")
continue
}
last := stack[len(stack)-1]
stack = stack[:len(stack)-1]

Expand Down Expand Up @@ -162,6 +170,21 @@ func writeXMLProcInst(p *core.Printer, inst []byte) {
// Mostly taken from the Go encoding/xml package in the standard library:
// https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/encoding/xml/xml.go;l=1964-1999
func escapeXMLString(p *core.Printer, s string) {
// Fast path: check if string needs escaping.
needsEscape := false
for i := 0; i < len(s); i++ {
c := s[i]
if c == '"' || c == '\'' || c == '&' || c == '<' || c == '>' ||
c == '\t' || c == '\n' || c == '\r' || c >= 0x80 {
needsEscape = true
break
}
}
if !needsEscape {
p.WriteString(s)
return
}

var esc string
var last int
for i := 0; i < len(s); {
Expand Down
147 changes: 147 additions & 0 deletions internal/format/xml_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package format

import (
"strings"
"testing"

"github.com/ryanfowler/fetch/internal/core"
)

func TestFormatXML(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "valid simple xml",
input: "<root><child>text</child></root>",
wantErr: false,
},
{
name: "valid nested xml",
input: "<a><b><c>text</c></b></a>",
wantErr: false,
},
{
name: "valid xml with attributes",
input: `<root attr="value"><child id="1">text</child></root>`,
wantErr: false,
},
{
name: "malformed xml extra closing tag at start",
input: "</foo><bar></bar>",
wantErr: true, // XML decoder catches this
},
{
name: "malformed xml extra closing tag at end",
input: "<a></a></a>",
wantErr: true, // XML decoder catches this
},
{
name: "malformed xml multiple extra closing tags",
input: "</x></y></z>",
wantErr: true, // XML decoder catches this
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := core.NewHandle(core.ColorOff).Stderr()
err := FormatXML([]byte(tt.input), p)
if (err != nil) != tt.wantErr {
t.Errorf("FormatXML() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func TestFormatXMLOutput(t *testing.T) {
input := "<root><child>text</child></root>"
p := core.NewHandle(core.ColorOff).Stderr()
err := FormatXML([]byte(input), p)
if err != nil {
t.Fatalf("FormatXML() error = %v", err)
}

output := string(p.Bytes())
if !strings.Contains(output, "<root>") {
t.Errorf("output should contain <root>, got: %s", output)
}
if !strings.Contains(output, "</root>") {
t.Errorf("output should contain </root>, got: %s", output)
}
if !strings.Contains(output, "text") {
t.Errorf("output should contain text, got: %s", output)
}
}

func TestEscapeXMLString(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "ascii no escape needed",
input: "hello world",
want: "hello world",
},
{
name: "with ampersand",
input: "foo & bar",
want: "foo &amp; bar",
},
{
name: "with less than",
input: "a < b",
want: "a &lt; b",
},
{
name: "with greater than",
input: "a > b",
want: "a &gt; b",
},
{
name: "with quotes",
input: `"quoted"`,
want: "&quot;quoted&quot;",
},
{
name: "with single quotes",
input: "'quoted'",
want: "&apos;quoted&apos;",
},
{
name: "with tab",
input: "a\tb",
want: "a&#x9;b",
},
{
name: "with newline",
input: "a\nb",
want: "a&#xA;b",
},
{
name: "with carriage return",
input: "a\rb",
want: "a&#xD;b",
},
{
name: "unicode chars",
input: "日本語",
want: "日本語",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := core.NewHandle(core.ColorOff).Stderr()
escapeXMLString(p, tt.input)
got := string(p.Bytes())
if got != tt.want {
t.Errorf("escapeXMLString() = %q, want %q", got, tt.want)
}
})
}
}