-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtaskgroup.go
More file actions
170 lines (154 loc) · 3.77 KB
/
taskgroup.go
File metadata and controls
170 lines (154 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package taskgroup
import (
"context"
"errors"
"slices"
"sync"
)
// ErrUnknownTaskType is returned by Add when the task is not one of the
// supported function signatures: func() R, func(context.Context) R,
// func() (R, error) or func(context.Context) (R, error).
var ErrUnknownTaskType = errors.New("unknown task type")
// taskProcess is the normalized form castTask converts every task into: it
// receives the group's context and returns a value or an error.
type taskProcess[R any] func(context.Context) (R, error)
// TaskGroup runs a queue of tasks that each produce an R, processing them in
// chunks and delivering results in the order the tasks were added. Build one
// with NewTaskGroup, queue tasks with Add, run them with Start and check the
// outcome with Error.
type TaskGroup[R any] struct {
	// wg tracks the task goroutines of the chunk currently being processed.
	wg sync.WaitGroup

	// lock guards tasks (in Add); processChunks holds it for a whole run.
	lock sync.Mutex

	// results holds the results of the chunk currently being processed,
	// indexed by each task's position within that chunk.
	results []R

	// tasks are the tasks queued by Add, in insertion order.
	tasks []taskProcess[R]

	// cancelCtx cancels ctx; the first failing task calls it.
	cancelCtx context.CancelFunc

	// ctx is derived from AppCtx and passed to every task.
	ctx context.Context

	// errCh records a task error (NewTaskGroup gives it a buffer of one).
	errCh chan error

	// ConcurrencyLimit is the maximum number of tasks run at once, which is
	// also the chunk size. The default 0, or any negative value, means no
	// limit: all queued tasks run together.
	ConcurrencyLimit int

	// AppCtx is the parent context given to NewTaskGroup; Error reports its
	// error once it has been cancelled.
	AppCtx context.Context
}
// NewTaskGroup returns an empty TaskGroup. Its tasks receive a context
// derived from ctx, which the group cancels when a task fails; ctx itself is
// kept as AppCtx.
func NewTaskGroup[R any](ctx context.Context) *TaskGroup[R] {
	const initialCapacity = 10

	g := &TaskGroup[R]{
		AppCtx: ctx,
		errCh:  make(chan error, 1),
	}
	g.ctx, g.cancelCtx = context.WithCancel(ctx)
	g.tasks = make([]taskProcess[R], 0, initialCapacity)
	g.results = make([]R, 0, initialCapacity)
	return g
}
// Limit sets ConcurrencyLimit, the maximum number of tasks Start runs at the
// same time, and returns g so the call can be chained after NewTaskGroup.
// A limit of 0 or less means all queued tasks run at once.
func (g *TaskGroup[R]) Limit(concurrencyLimit int) *TaskGroup[R] {
	g.ConcurrencyLimit = concurrencyLimit

	return g
}
// castTask converts one of the supported task signatures into a taskProcess:
//
//	func(context.Context) (R, error)  used as is
//	func() (R, error)                 the group's context is ignored
//	func(context.Context) R           always reports success
//	func() R                          context ignored, always succeeds
//
// Any other value yields ErrUnknownTaskType.
func castTask[R any](task any) (taskProcess[R], error) {
	switch fn := task.(type) {
	case func(context.Context) (R, error):
		return fn, nil
	case func() (R, error):
		return func(context.Context) (R, error) {
			return fn()
		}, nil
	case func(context.Context) R:
		return func(ctx context.Context) (R, error) {
			return fn(ctx), nil
		}, nil
	case func() R:
		return func(context.Context) (R, error) {
			return fn(), nil
		}, nil
	}
	return nil, ErrUnknownTaskType
}
// Add queues a task for the next Start. If anyTask is not one of the
// signatures castTask accepts, Add returns ErrUnknownTaskType and queues
// nothing. Add takes the group's lock, so it waits while a run started by
// Start is in progress (that run holds the lock until it finishes).
func (g *TaskGroup[R]) Add(anyTask any) error {
	task, err := castTask[R](anyTask)
	if err != nil {
		return err
	}

	g.lock.Lock()
	defer g.lock.Unlock()
	g.tasks = append(g.tasks, task)
	return nil
}
// Clear drops all queued tasks and stored results, keeping the allocated
// capacity for reuse. The backing arrays are zeroed first so the group does
// not keep finished task closures or old result values reachable.
//
// Clear does not take g.lock, because processChunks calls it while already
// holding the lock. Callers must not run it concurrently with Add or Start.
func (g *TaskGroup[R]) Clear() {
	clear(g.tasks)
	// results is resized per chunk, so stale values can sit beyond its
	// current length; zero the whole backing array.
	clear(g.results[:cap(g.results)])

	g.tasks = g.tasks[:0]
	g.results = g.results[:0]
}
// startProcessing runs task with the group's context. It is started as a
// goroutine for slot i of the current chunk, after g.wg has been incremented
// for it. On success the value is stored at g.results[i]. On failure the
// error is recorded in errCh and the group's context is cancelled so the
// other tasks can stop. An error that arrives after the context has been
// cancelled may be dropped.
func (g *TaskGroup[R]) startProcessing(i int, task taskProcess[R]) {
	defer g.wg.Done()

	value, err := task(g.ctx)
	if err == nil {
		g.results[i] = value
		return
	}

	select {
	case g.errCh <- err:
		// First error wins: stop the remaining tasks.
		g.cancelCtx()
	case <-g.ctx.Done():
	}
}
// Start runs the queued tasks in the background and returns a channel that
// receives their results in the order the tasks were added.
//
// Tasks are processed in chunks of at most ConcurrencyLimit tasks, and all
// tasks of a chunk run concurrently. A ConcurrencyLimit of 0 or less puts
// every task in a single chunk. The channel's buffer holds one chunk, and
// the next chunk starts only after every result of the previous chunk has
// been sent on the channel.
//
// The first task error cancels the context seen by the other tasks. No
// results from the failing chunk are sent, and the channel is closed. The
// channel is also closed after the last result. Call Error once the channel
// is closed to learn whether the run failed; otherwise the error is ignored.
func (g *TaskGroup[R]) Start() <-chan R {
	// Discard an error left over from a previous run so that Error reports
	// only this run's outcome.
	_ = g.Error()

	// NOTE(review): len(g.tasks) is read without g.lock and so races with a
	// concurrent Add; locking here would make Start block behind a previous
	// run that is still in progress.
	concLimit := len(g.tasks)
	if g.ConcurrencyLimit > 0 && concLimit > g.ConcurrencyLimit {
		concLimit = g.ConcurrencyLimit
	}
	// Tasks may be added between this snapshot and processChunks taking the
	// lock; a chunk size below 1 would make slices.Chunk panic there.
	concLimit = max(concLimit, 1)

	resCh := make(chan R, concLimit)
	go g.processChunks(concLimit, resCh)
	return resCh
}
// processChunks runs g.tasks in consecutive chunks of at most n tasks, with
// each chunk's tasks running concurrently. It sends every chunk's results to
// resCh in the order the tasks were added. It holds g.lock for the whole run,
// clears the queued tasks and results when it returns, and always closes
// resCh.
//
// It stops early, without sending the current chunk's results, when a task
// fails or the group's context is cancelled. It also stops if the context is
// cancelled while it is blocked sending on resCh. That way an abandoned
// channel cannot leave this goroutine, and g.lock, stuck forever.
func (g *TaskGroup[R]) processChunks(n int, resCh chan R) {
	g.lock.Lock()
	defer g.lock.Unlock()
	defer close(resCh)
	defer g.Clear()

	if len(g.tasks) == 0 {
		return
	}
	// Start sizes n from an unlocked snapshot of g.tasks, so n can be 0 even
	// though tasks were added since; slices.Chunk panics on a size below 1.
	n = max(n, 1)

	// A failed run cancels g.ctx and nothing re-creates it, so every later
	// run would stop at once with no results and no error. Re-arm it once
	// the previous error has been consumed (errCh is empty) and the parent
	// context is still live. No task goroutine is running here: the previous
	// run waited for all of them before releasing g.lock.
	if g.ctx.Err() != nil && len(g.errCh) == 0 && g.AppCtx.Err() == nil {
		g.ctx, g.cancelCtx = context.WithCancel(g.AppCtx)
	}

	if n > cap(g.results) {
		g.results = make([]R, 0, n)
	}
	for chunk := range slices.Chunk(g.tasks, n) {
		// Chunks never exceed n, so this fits in the capacity reserved above.
		g.results = g.results[:len(chunk)]
		for i, task := range chunk {
			g.wg.Add(1)
			go g.startProcessing(i, task)
		}
		g.wg.Wait()

		// One failure (or a cancellation) discards the whole chunk.
		if len(g.errCh) != 0 || g.ctx.Err() != nil {
			return
		}
		for _, res := range g.results {
			select {
			case resCh <- res:
			case <-g.ctx.Done():
				return
			}
		}
	}
}
// Error reports the outcome of the most recent run started by Start. If a
// task error was recorded, Error returns it and consumes it, so a second
// call will not return it again. Otherwise Error returns the error of the
// parent context given to NewTaskGroup, which is nil while that context is
// live.
//
// Call Error after the channel returned by Start has been closed; by then
// any task error has already been recorded. If Error is never called, a
// task error is silently ignored.
func (g *TaskGroup[R]) Error() error {
	// A non-blocking receive rather than checking len(g.errCh) first. With a
	// separate check, two concurrent callers could both see one buffered
	// error, and one of them would then block forever on the receive. Start
	// draining stale errors while a user calls Error is one such pair.
	select {
	case err := <-g.errCh:
		return err
	default:
		return g.AppCtx.Err()
	}
}