Skip to content

Commit 59e58f8

Browse files
authored
Support day and week in duration parsing (#87)
1 parent 2c3a4e6 commit 59e58f8

File tree

4 files changed

+300
-7
lines changed

4 files changed

+300
-7
lines changed

internal/ssh/flags.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ import (
77
"io"
88
"strings"
99
"time"
10+
11+
"github.com/robherley/snips.sh/internal/timeutil"
1012
)
1113

1214
var (
1315
ErrFlagRequied = errors.New("flag required")
16+
ErrFlagParse = errors.New("parse error")
1417
)
1518

1619
type UploadFlags struct {
@@ -27,7 +30,7 @@ func (uf *UploadFlags) Parse(out io.Writer, args []string) error {
2730

2831
uf.BoolVar(&uf.Private, "private", false, "only accessible via creator or signed urls (optional)")
2932
uf.StringVar(&uf.Extension, "ext", "", "set the file extension (optional)")
30-
uf.DurationVar(&uf.TTL, "ttl", 0, "lifetime of the signed url (optional)")
33+
addDurationFlag(uf.FlagSet, &uf.TTL, "ttl", 0, "lifetime of the signed url (optional)")
3134

3235
if err := uf.FlagSet.Parse(args); err != nil {
3336
return err
@@ -52,7 +55,7 @@ func (sf *SignFlags) Parse(out io.Writer, args []string) error {
5255
sf.FlagSet = flag.NewFlagSet("", flag.ContinueOnError)
5356
sf.SetOutput(out)
5457

55-
sf.DurationVar(&sf.TTL, "ttl", 0, "lifetime of the signed url")
58+
addDurationFlag(sf.FlagSet, &sf.TTL, "ttl", 0, "lifetime of the signed url")
5659

5760
if err := sf.FlagSet.Parse(args); err != nil {
5861
return err
@@ -79,3 +82,32 @@ func (df *DeleteFlags) Parse(out io.Writer, args []string) error {
7982

8083
return df.FlagSet.Parse(args)
8184
}
85+
86+
// durationFlagValue is a wrapper around time.Duration that implements the flag.Value interface using a custom parser.
87+
type durationFlagValue time.Duration
88+
89+
// addDurationFlag adds a flag for a time.Duration to the given flag.FlagSet.
90+
func addDurationFlag(fs *flag.FlagSet, p *time.Duration, name string, value time.Duration, usage string) {
91+
*p = value
92+
fs.Var((*durationFlagValue)(p), name, usage)
93+
}
94+
95+
// Set implements the flag.Value interface.
96+
func (d *durationFlagValue) Set(s string) error {
97+
v, err := timeutil.ParseDuration(s)
98+
if err != nil {
99+
err = ErrFlagParse
100+
}
101+
*d = durationFlagValue(v)
102+
return err
103+
}
104+
105+
// Get implements the flag.Getter interface.
106+
func (d *durationFlagValue) Get() any {
107+
return time.Duration(*d)
108+
}
109+
110+
// String implements the flag.Value interface.
111+
func (d *durationFlagValue) String() string {
112+
return (*time.Duration)(d).String()
113+
}

internal/ssh/flags_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ func TestUploadFlags(t *testing.T) {
6363
},
6464
{
6565
name: "private and ttl",
66-
args: []string{"-private", "-ttl", "30s"},
66+
args: []string{"-private", "-ttl", "1w2d3m4s"},
6767
want: ssh.UploadFlags{
6868
Private: true,
6969
Extension: "",
70-
TTL: time.Duration(30),
70+
TTL: 1*7*24*time.Hour + 2*24*time.Hour + 3*time.Minute + 4*time.Second,
7171
},
7272
},
7373
{
@@ -76,7 +76,7 @@ func TestUploadFlags(t *testing.T) {
7676
want: ssh.UploadFlags{
7777
Private: false,
7878
Extension: "",
79-
TTL: time.Duration(30),
79+
TTL: 30 * time.Second,
8080
},
8181
err: ssh.ErrFlagRequied,
8282
},
@@ -111,9 +111,9 @@ func TestSignFlags(t *testing.T) {
111111
},
112112
{
113113
name: "ttl",
114-
args: []string{"-ttl", "1h"},
114+
args: []string{"-ttl", "1w2d3m4s"},
115115
want: ssh.SignFlags{
116-
TTL: 1 * time.Hour,
116+
TTL: 1*7*24*time.Hour + 2*24*time.Hour + 3*time.Minute + 4*time.Second,
117117
},
118118
},
119119
}

internal/timeutil/duration.go

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
package timeutil
2+
3+
import (
4+
"errors"
5+
"time"
6+
)
7+
8+
// This is verbatim from time.ParseDuration, with some additions to the unit map
9+
10+
// Copyright 2010 The Go Authors. All rights reserved.
11+
// Use of this source code is governed by a BSD-style
12+
// license that can be found in the LICENSE file.
13+
14+
var unitMap = map[string]uint64{
15+
"ns": uint64(time.Nanosecond),
16+
"us": uint64(time.Microsecond),
17+
"µs": uint64(time.Microsecond), // U+00B5 = micro symbol
18+
"μs": uint64(time.Microsecond), // U+03BC = Greek letter mu
19+
"ms": uint64(time.Millisecond),
20+
"s": uint64(time.Second),
21+
"m": uint64(time.Minute),
22+
"h": uint64(time.Hour),
23+
// additional units added below
24+
"d": uint64(time.Hour) * 24,
25+
"w": uint64(time.Hour) * 24 * 7,
26+
}
27+
28+
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
29+
// that package, since we can't take a dependency on either.
30+
const (
31+
lowerhex = "0123456789abcdef"
32+
runeSelf = 0x80
33+
runeError = '\uFFFD'
34+
)
35+
36+
func quote(s string) string {
37+
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
38+
buf[0] = '"'
39+
for i, c := range s {
40+
if c >= runeSelf || c < ' ' {
41+
// This means you are asking us to parse a time.Duration or
42+
// time.Location with unprintable or non-ASCII characters in it.
43+
// We don't expect to hit this case very often. We could try to
44+
// reproduce strconv.Quote's behavior with full fidelity but
45+
// given how rarely we expect to hit these edge cases, speed and
46+
// conciseness are better.
47+
var width int
48+
if c == runeError {
49+
width = 1
50+
if i+2 < len(s) && s[i:i+3] == string(runeError) {
51+
width = 3
52+
}
53+
} else {
54+
width = len(string(c))
55+
}
56+
for j := 0; j < width; j++ {
57+
buf = append(buf, `\x`...)
58+
buf = append(buf, lowerhex[s[i+j]>>4])
59+
buf = append(buf, lowerhex[s[i+j]&0xF])
60+
}
61+
} else {
62+
if c == '"' || c == '\\' {
63+
buf = append(buf, '\\')
64+
}
65+
buf = append(buf, string(c)...)
66+
}
67+
}
68+
buf = append(buf, '"')
69+
return string(buf)
70+
}
71+
72+
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
73+
74+
// leadingInt consumes the leading [0-9]* from s.
75+
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
76+
i := 0
77+
for ; i < len(s); i++ {
78+
c := s[i]
79+
if c < '0' || c > '9' {
80+
break
81+
}
82+
if x > 1<<63/10 {
83+
// overflow
84+
return 0, rem, errLeadingInt
85+
}
86+
x = x*10 + uint64(c) - '0'
87+
if x > 1<<63 {
88+
// overflow
89+
return 0, rem, errLeadingInt
90+
}
91+
}
92+
return x, s[i:], nil
93+
}
94+
95+
// leadingFraction consumes the leading [0-9]* from s.
96+
// It is used only for fractions, so does not return an error on overflow,
97+
// it just stops accumulating precision.
98+
func leadingFraction(s string) (x uint64, scale float64, rem string) {
99+
i := 0
100+
scale = 1
101+
overflow := false
102+
for ; i < len(s); i++ {
103+
c := s[i]
104+
if c < '0' || c > '9' {
105+
break
106+
}
107+
if overflow {
108+
continue
109+
}
110+
if x > (1<<63-1)/10 {
111+
// It's possible for overflow to give a positive number, so take care.
112+
overflow = true
113+
continue
114+
}
115+
y := x*10 + uint64(c) - '0'
116+
if y > 1<<63 {
117+
overflow = true
118+
continue
119+
}
120+
x = y
121+
scale *= 10
122+
}
123+
return x, scale, s[i:]
124+
}
125+
126+
// ParseDuration parses a duration string.
127+
// A duration string is a possibly signed sequence of
128+
// decimal numbers, each with optional fraction and a unit suffix,
129+
// such as "300ms", "-1.5h" or "2h45m".
130+
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
131+
func ParseDuration(s string) (time.Duration, error) {
132+
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
133+
orig := s
134+
var d uint64
135+
neg := false
136+
137+
// Consume [-+]?
138+
if s != "" {
139+
c := s[0]
140+
if c == '-' || c == '+' {
141+
neg = c == '-'
142+
s = s[1:]
143+
}
144+
}
145+
// Special case: if all that is left is "0", this is zero.
146+
if s == "0" {
147+
return 0, nil
148+
}
149+
if s == "" {
150+
return 0, errors.New("time: invalid duration " + quote(orig))
151+
}
152+
for s != "" {
153+
var (
154+
v, f uint64 // integers before, after decimal point
155+
scale float64 = 1 // value = v + f/scale
156+
)
157+
158+
var err error
159+
160+
// The next character must be [0-9.]
161+
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
162+
return 0, errors.New("time: invalid duration " + quote(orig))
163+
}
164+
// Consume [0-9]*
165+
pl := len(s)
166+
v, s, err = leadingInt(s)
167+
if err != nil {
168+
return 0, errors.New("time: invalid duration " + quote(orig))
169+
}
170+
pre := pl != len(s) // whether we consumed anything before a period
171+
172+
// Consume (\.[0-9]*)?
173+
post := false
174+
if s != "" && s[0] == '.' {
175+
s = s[1:]
176+
pl := len(s)
177+
f, scale, s = leadingFraction(s)
178+
post = pl != len(s)
179+
}
180+
if !pre && !post {
181+
// no digits (e.g. ".s" or "-.s")
182+
return 0, errors.New("time: invalid duration " + quote(orig))
183+
}
184+
185+
// Consume unit.
186+
i := 0
187+
for ; i < len(s); i++ {
188+
c := s[i]
189+
if c == '.' || '0' <= c && c <= '9' {
190+
break
191+
}
192+
}
193+
if i == 0 {
194+
return 0, errors.New("time: missing unit in duration " + quote(orig))
195+
}
196+
u := s[:i]
197+
s = s[i:]
198+
unit, ok := unitMap[u]
199+
if !ok {
200+
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
201+
}
202+
if v > 1<<63/unit {
203+
// overflow
204+
return 0, errors.New("time: invalid duration " + quote(orig))
205+
}
206+
v *= unit
207+
if f > 0 {
208+
// float64 is needed to be nanosecond accurate for fractions of hours.
209+
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
210+
v += uint64(float64(f) * (float64(unit) / scale))
211+
if v > 1<<63 {
212+
// overflow
213+
return 0, errors.New("time: invalid duration " + quote(orig))
214+
}
215+
}
216+
d += v
217+
if d > 1<<63 {
218+
return 0, errors.New("time: invalid duration " + quote(orig))
219+
}
220+
}
221+
if neg {
222+
return -time.Duration(d), nil
223+
}
224+
if d > 1<<63-1 {
225+
return 0, errors.New("time: invalid duration " + quote(orig))
226+
}
227+
return time.Duration(d), nil
228+
}

internal/timeutil/duration_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package timeutil_test
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/robherley/snips.sh/internal/timeutil"
8+
)
9+
10+
func TestParseDuration(t *testing.T) {
11+
cases := []struct {
12+
duration string
13+
want time.Duration
14+
}{
15+
{"1d", 24 * time.Hour},
16+
{"1d12h30m15s", 24*time.Hour + 12*time.Hour + 30*time.Minute + 15*time.Second},
17+
{"1w", 7 * 24 * time.Hour},
18+
{"1w1d12h", 8*24*time.Hour + 12*time.Hour},
19+
}
20+
21+
for _, tc := range cases {
22+
t.Run(tc.duration, func(t *testing.T) {
23+
got, err := timeutil.ParseDuration(tc.duration)
24+
if err != nil {
25+
t.Fatalf("unexpected error: %v", err)
26+
}
27+
28+
if got != tc.want {
29+
t.Errorf("got %s, want %s", got, tc.want)
30+
}
31+
})
32+
}
33+
}

0 commit comments

Comments
 (0)