Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
run: go mod download

- name: "Run tests"
run: go test -count 10 -v ./...
run: go test -race -count 10 -v ./...

- name: "Build"
run: go build -o ./bin/krun .
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ import (

func main() {
queue := krun.New(&krun.Config{
Size: 5, // number of workers
WaitSleep: time.Microsecond,
Size: 5, // number of workers
})

job := func(ctx context.Context) (interface{}, error) {
Expand Down
3 changes: 1 addition & 2 deletions examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import (

func main() {
queue := krun.New(&krun.Config{
Size: 5, // number of workers
WaitSleep: time.Microsecond,
Size: 5, // number of workers
})

job := func(ctx context.Context) (interface{}, error) {
Expand Down
101 changes: 75 additions & 26 deletions krun.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,31 @@ import (
"context"
"errors"
"sync"
"time"
)

// ErrPoolClosed it is closed (hahah)
// ErrPoolClosed is returned when operations are attempted on a closed pool.
var ErrPoolClosed = errors.New("pool's closed")

// Result represents the result of a job execution.
type Result struct {
Data interface{}
Error error
}

// Job represents a function that can be executed by the worker pool.
// It receives a context and returns a result and an error.
type Job func(ctx context.Context) (interface{}, error)

// Krun is the interface for a worker pool that can execute jobs concurrently.
type Krun interface {
// Run executes a job and returns a channel that will receive the result.
Run(ctx context.Context, f Job) <-chan *Result
// Wait blocks until all running jobs complete or the context is cancelled.
Wait(ctx context.Context)
// Size returns the number of workers in the pool.
Size() int
// Close shuts down the pool, waiting for all running jobs to complete.
// Returns ErrPoolClosed if called multiple times.
Close() error
}

Expand All @@ -37,22 +46,28 @@ type worker struct {
result chan *Result
}

// Config configures a new Krun instance.
type Config struct {
Size int
WaitSleep time.Duration
Size int
}

// New creates a new Krun worker pool with the given configuration.
func New(cfg *Config) Krun {
size := 1
if cfg != nil && cfg.Size > 0 {
size = cfg.Size
}

k := &krun{
poolSize: cfg.Size,
poolSize: size,
closed: false,

workers: make(chan *worker, cfg.Size),
workers: make(chan *worker, size),
wg: sync.WaitGroup{},
mu: sync.RWMutex{},
}

for i := 0; i < cfg.Size; i++ {
for i := 0; i < size; i++ {
k.push(&worker{})
}

Expand All @@ -67,18 +82,46 @@ func (k *krun) Size() int {
}

func (k *krun) Run(ctx context.Context, f Job) <-chan *Result {
// get worker from the channel
w := k.pop()
k.wg.Add(1)

// assign Job to the worker and Run it
cr := make(chan *Result, 1)
w.job = f
w.result = cr
go k.work(ctx, w)

// return channel to the caller
return cr
// Check if context is already cancelled before trying to get a worker
if ctx.Err() != nil {
cr <- &Result{Error: ctx.Err()}
return cr
}

// get worker from the channel
select {
case <-ctx.Done():
cr <- &Result{Error: ctx.Err()}
return cr
case w, ok := <-k.workers:
if !ok {
// Channel was closed
cr <- &Result{Error: ErrPoolClosed}
return cr
}

// Check if pool was closed after getting worker
k.mu.RLock()
closed := k.closed
k.mu.RUnlock()

if closed {
// Pool was closed, discard worker
cr <- &Result{Error: ErrPoolClosed}
return cr
}

k.wg.Add(1)

// assign Job to the worker and Run it
w.job = f
w.result = cr
go k.work(ctx, w)

return cr
}
}

func (k *krun) Wait(ctx context.Context) {
Expand Down Expand Up @@ -106,12 +149,12 @@ func (k *krun) Close() error {
k.closed = true
k.mu.Unlock()

// Close worker channel first to unblock any waiting Run() calls
close(k.workers)

// Wait for all work to complete
k.wg.Wait()

// Close worker channel
close(k.workers)

return nil
}

Expand All @@ -121,15 +164,21 @@ func (k *krun) work(ctx context.Context, w *worker) {

// send Result into the caller channel
w.result <- &Result{d, err}
k.wg.Done()

// return worker to Krun
// return worker to Krun if pool is still open
k.push(w)
k.wg.Done()
}

func (k *krun) push(w *worker) {
k.workers <- w
}
k.mu.RLock()
closed := k.closed
k.mu.RUnlock()

if closed {
// Pool is closed, discard worker
return
}

func (k *krun) pop() *worker {
return <-k.workers
k.workers <- w
}
161 changes: 158 additions & 3 deletions krun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,36 @@ func TestNew(t *testing.T) {
t.Fatalf("Expected *krun, got %T", v)
}
})

t.Run("handles nil Config", func(t *testing.T) {
k := New(nil)
if k == nil {
t.Fatalf("Expected Krun, got nil")
}
if k.Size() != 1 {
t.Fatalf("Expected default size 1, got %d", k.Size())
}
})

t.Run("handles Size 0", func(t *testing.T) {
k := New(&Config{Size: 0})
if k == nil {
t.Fatalf("Expected Krun, got nil")
}
if k.Size() != 1 {
t.Fatalf("Expected default size 1, got %d", k.Size())
}
})

t.Run("handles negative Size", func(t *testing.T) {
k := New(&Config{Size: -5})
if k == nil {
t.Fatalf("Expected Krun, got nil")
}
if k.Size() != 1 {
t.Fatalf("Expected default size 1, got %d", k.Size())
}
})
}

func TestKrun_Size(t *testing.T) {
Expand Down Expand Up @@ -68,8 +98,6 @@ func TestRun(t *testing.T) {
if tp != "my-string" {
t.Fatalf("expected \"my-string\", received: %s", tp)
}

break
default:
t.Fatalf("expected string, got %t", tp)
}
Expand Down Expand Up @@ -111,7 +139,7 @@ func TestRun(t *testing.T) {

select {
case e := <-errChan:
t.Fatalf(e.Error())
t.Fatalf("Expected nil, got %v", e)
case <-time.After(time.Millisecond):
return
}
Expand All @@ -135,6 +163,96 @@ func TestRun(t *testing.T) {
t.Fatalf("Expected nil, got %v", d.Data)
}
})

t.Run("returns error if context already cancelled", func(t *testing.T) {
k := New(&Config{Size: 1})

ctx, cancel := context.WithCancel(context.Background())
cancel()

r := <-k.Run(ctx, func(ctx context.Context) (interface{}, error) {
return "should not run", nil
})

if r.Error == nil {
t.Fatalf("Expected context cancelled error, got nil")
}
if r.Error != ctx.Err() {
t.Fatalf("Expected context cancelled error, got %v", r.Error)
}
})

t.Run("returns error if pool is closed", func(t *testing.T) {
k := New(&Config{Size: 1})

// Close the pool
if err := k.Close(); err != nil {
t.Fatalf("Expected nil on first close, got %v", err)
}

// Try to run a job after close
r := <-k.Run(context.Background(), func(ctx context.Context) (interface{}, error) {
return "should not run", nil
})

if r.Error == nil {
t.Fatalf("Expected ErrPoolClosed, got nil")
}
if !errors.Is(r.Error, ErrPoolClosed) {
t.Fatalf("Expected ErrPoolClosed, got %v", r.Error)
}
})

t.Run("handles Run() after Close() while job waiting", func(t *testing.T) {
// t.SkipNow()
// return
k := New(&Config{Size: 1})

ctx := context.Background()

// Start a job that takes time
started := make(chan struct{})
jobDone := make(chan struct{})
_ = k.Run(ctx, func(ctx context.Context) (interface{}, error) {
started <- struct{}{}
<-jobDone
return "done", nil
})

// Wait for job to start
<-started

// Close in another goroutine
closeDone := make(chan error)
go func() {
closeDone <- k.Close()
}()

// Try to run another job while closing
r := <-k.Run(ctx, func(ctx context.Context) (interface{}, error) {
return "should not run", nil
})

if r.Error == nil {
t.Fatalf("Expected ErrPoolClosed, got nil")
}
if !errors.Is(r.Error, ErrPoolClosed) {
t.Fatalf("Expected ErrPoolClosed, got %v", r.Error)
}

// Finish the running job
close(jobDone)

// Wait for close to complete
select {
case err := <-closeDone:
if err != nil {
t.Fatalf("Expected nil on close, got %v", err)
}
case <-time.After(50 * time.Millisecond):
t.Fatalf("Close did not complete in time")
}
})
}

func TestKrun_Wait(t *testing.T) {
Expand Down Expand Up @@ -296,4 +414,41 @@ func TestKrun_Close(t *testing.T) {
t.Fatalf("Close did not complete in time")
}
})

t.Run("concurrent Close() calls", func(t *testing.T) {
k := New(&Config{Size: 5})

results := make(chan error, 10)
for i := 0; i < 10; i++ {
go func() {
results <- k.Close()
}()
}

// One should succeed, rest should get ErrPoolClosed
successCount := 0
errorCount := 0

for i := 0; i < 10; i++ {
select {
case err := <-results:
if err == nil {
successCount++
} else if errors.Is(err, ErrPoolClosed) {
errorCount++
} else {
t.Fatalf("Unexpected error: %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("Close did not complete in time")
}
}

if successCount != 1 {
t.Fatalf("Expected exactly 1 successful close, got %d", successCount)
}
if errorCount != 9 {
t.Fatalf("Expected 9 ErrPoolClosed, got %d", errorCount)
}
})
}