Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dist/examples/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ tls = false
[etcd3]
endpoints = [ "http://127.0.0.1:2379" ]

# MemoryDB configuration
[memdb]
enabled = false

# Lock configuration, base reboot group
[lock]
default_group_name = "default"
Expand Down
7 changes: 5 additions & 2 deletions internal/cli/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ func verbosityLevel(verbCount int) logrus.Level {

// validateSettings sanity-checks all settings
func validateSettings(cfg config.Settings) error {
if len(cfg.EtcdEndpoints) == 0 {
return errors.New("no etcd3 endpoints configured")
if len(cfg.EtcdEndpoints) == 0 && !cfg.MemDBEnabled {
return errors.New("no etcd3 endpoints configured and MemDB is not enabled")
}
if len(cfg.EtcdEndpoints) > 0 && cfg.MemDBEnabled {
return errors.New("both etcd3 endpoints and MemDB are configured, choose one")
}
if len(cfg.LockGroups) == 0 {
return errors.New("no lock-groups configured")
Expand Down
2 changes: 1 addition & 1 deletion internal/cli/ex-get-slots.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func runGetSlots(cmd *cobra.Command, cmdArgs []string) error {
ctx, cancel := context.WithTimeout(context.Background(), runSettings.EtcdTxnTimeout)
defer cancel()

manager, err := lock.NewManager(ctx, runSettings.EtcdEndpoints, runSettings.ClientCertPubPath, runSettings.ClientCertKeyPath, runSettings.EtcdTxnTimeout, group, maxSlots)
manager, err := lock.NewManager(ctx, runSettings, group, maxSlots)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions internal/config/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type Settings struct {
ClientCertKeyPath string
EtcdTxnTimeout time.Duration

MemDBEnabled bool

LockGroups map[string]uint64
}

Expand Down Expand Up @@ -55,6 +57,8 @@ func defaultSettings() Settings {
EtcdEndpoints: []string{},
EtcdTxnTimeout: time.Duration(3) * time.Second,

MemDBEnabled: false,

LockGroups: make(map[string]uint64),
}
}
19 changes: 19 additions & 0 deletions internal/config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type tomlConfig struct {
Service *serviceSection `toml:"service"`
Status *statusSection `toml:"status"`
Etcd3 *etcd3Section `toml:"etcd3"`
MemDB *MemDBSection `toml:"memdb"`
Lock *lockSection `toml:"lock"`
}

Expand All @@ -37,6 +38,11 @@ type etcd3Section struct {
ClientCertKeyPath string `toml:"client_cert_key_path"`
}

// MemDBSection holds the optional `memdb` fragment
type MemDBSection struct {
Enabled *bool `toml:"enabled"`
}

// lockSection holds the optional `lock` fragment
type lockSection struct {
DefaultGroupName *string `toml:"default_group_name"`
Expand Down Expand Up @@ -77,6 +83,9 @@ func mergeToml(settings *Settings, cfg tomlConfig) {
if cfg.Etcd3 != nil {
mergeEtcd(settings, *cfg.Etcd3)
}
if cfg.MemDB != nil {
mergeMemDB(settings, *cfg.MemDB)
}
if cfg.Lock != nil {
mergeLock(settings, *cfg.Lock)
}
Expand Down Expand Up @@ -136,6 +145,16 @@ func mergeEtcd(settings *Settings, cfg etcd3Section) {
}
}

func mergeMemDB(settings *Settings, cfg MemDBSection) {
if settings == nil {
return
}

if cfg.Enabled != nil {
settings.MemDBEnabled = *cfg.Enabled
}
}

func mergeLock(settings *Settings, cfg lockSection) {
if settings == nil {
return
Expand Down
213 changes: 213 additions & 0 deletions internal/lock/etcdlock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package lock

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"time"

transport "go.etcd.io/etcd/client/pkg/v3/transport"
clientv3 "go.etcd.io/etcd/client/v3"
)

const (
keyTemplate = "com.coreos.airlock/groups/%s/v1/semaphore"
)

var (
// ErrNilEtcdManager is returned on nil manager
ErrNilEtcdManager = errors.New("nil EtcdManager")
)

// EtcdManager takes care of locking for clients
type EtcdManager struct {
client *clientv3.Client
keyPath string
}

// NewEtcdManager returns a new lock manager, ensuring the underlying semaphore is initialized.
func NewEtcdManager(ctx context.Context, etcdURLs []string, certPubPath string, certKeyPath string, txnTimeoutMs time.Duration, group string, slots uint64) (*EtcdManager, error) {
tlsInfo := transport.TLSInfo{
CertFile: certPubPath,
KeyFile: certKeyPath,
}

tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
return nil, err
}

client, err := clientv3.New(clientv3.Config{
Endpoints: etcdURLs,
DialTimeout: time.Duration(txnTimeoutMs) * time.Millisecond,
TLS: tlsConfig,
})
if err != nil {
return nil, err
}

keyPath := fmt.Sprintf(keyTemplate, url.QueryEscape(group))
manager := EtcdManager{client, keyPath}

if err := manager.ensureInit(ctx, slots); err != nil {
return nil, err
}

return &manager, nil
}

// RecursiveLock adds this lock `id` as a holder of the semaphore
//
// It will return an error if there is a problem getting or setting the
// semaphore, or if the maximum number of holders has been reached.
func (m *EtcdManager) RecursiveLock(ctx context.Context, id string) (*Semaphore, error) {
sem, version, err := m.get(ctx)
if err != nil {
return nil, err
}

held, err := sem.RecursiveLock(id)
if err != nil {
return nil, err
}
if held {
return sem, nil
}

if err := m.set(ctx, sem, version); err != nil {
return nil, err
}

return sem, nil
}

// UnlockIfHeld removes this lock `id` as a holder of the semaphore
//
// It returns an error if there is a problem getting or setting the semaphore.
func (m *EtcdManager) UnlockIfHeld(ctx context.Context, id string) (*Semaphore, error) {
sem, version, err := m.get(ctx)
if err != nil {
return nil, err
}

if err := sem.UnlockIfHeld(id); err != nil {
return nil, err
}

if err := m.set(ctx, sem, version); err != nil {
return nil, err
}

return sem, nil
}

// FetchSemaphore fetches current semaphore version
func (m *EtcdManager) FetchSemaphore(ctx context.Context) (*Semaphore, error) {
semaphore, _, err := m.get(ctx)
if err != nil {
return nil, err
}

return semaphore, nil
}

// Close reaps all running goroutines
func (m *EtcdManager) Close() {
if m == nil {
return
}

m.client.Close()
}

// ensureInit initialize the semaphore in etcd, if it does not exist yet
func (m *EtcdManager) ensureInit(ctx context.Context, slots uint64) error {
if m == nil {
return ErrNilEtcdManager
}

sem := NewSemaphore(slots)
semValue, err := sem.String()
if err != nil {
return err
}

_, err = m.client.Txn(ctx).If(
// version=0 means that the key does not exist.
clientv3.Compare(clientv3.Version(m.keyPath), "=", 0),
).Then(
clientv3.OpPut(m.keyPath, semValue),
).Commit()

if err != nil {
return err
}
return nil
}

// get returns the current semaphore value and version, or an error
func (m *EtcdManager) get(ctx context.Context) (*Semaphore, int64, error) {
resp, err := m.client.Get(ctx, m.keyPath)
if err != nil {
return nil, 0, err
}
if resp.Count != 1 {
return nil, 0, fmt.Errorf("unexpected number of results: %d", resp.Count)
}

var data []byte
var version int64
for _, kv := range resp.Kvs {
data = kv.Value
version = kv.Version
break
}
if version == 0 {
return nil, 0, errors.New("key at version 0")
}
if len(data) == 0 {
return nil, 0, errors.New("empty semaphore value")
}

sem := &Semaphore{}
err = json.Unmarshal(data, sem)
if err != nil {
return nil, 0, err
}

return sem, version, nil
}

// set updates the semaphore in etcd, if `version` matches the one previously observed
func (m *EtcdManager) set(ctx context.Context, sem *Semaphore, version int64) error {
if m == nil {
return ErrNilEtcdManager
}
if sem == nil {
return ErrNilSemaphore
}

data, err := json.Marshal(sem)
if err != nil {
return err
}

// Conditionally Put if version in etcd is still the same we observed.
// If the condition is not met, the transaction will return as "not succeeding".
resp, err := m.client.Txn(ctx).If(
clientv3.Compare(clientv3.Version(m.keyPath), "=", version),
).Then(
clientv3.OpPut(m.keyPath, string(data)),
).Commit()

if err != nil {
return err
}
if !resp.Succeeded {
return errors.New("conflict on semaphore detected, aborting")
}

return nil
}
Loading
Loading