diff --git a/internal/coremidi/client.go b/internal/coremidi/client.go index e5c0254..bac1511 100644 --- a/internal/coremidi/client.go +++ b/internal/coremidi/client.go @@ -7,26 +7,90 @@ package coremidi #cgo LDFLAGS: -framework CoreMIDI -framework CoreFoundation -framework CoreServices #include #include + +// goMidiNotify is the exported Go callback; declared here so notifyBridgeFn can call it. +extern void goMidiNotify(int msgID, void *handle); + +// notifyBridgeFn is the C-side bridge registered with MIDIClientCreate. +static void notifyBridgeFn(const MIDINotification *msg, void *refCon) { + goMidiNotify((int)msg->messageID, refCon); +} + +// newMIDIClientWithNotify wraps MIDIClientCreate with our bridge callback. +static OSStatus newMIDIClientWithNotify(CFStringRef name, void *refCon, MIDIClientRef *outClient) { + return MIDIClientCreate(name, notifyBridgeFn, refCon, outClient); +} */ import "C" -import "fmt" +import ( + "fmt" + "runtime/cgo" + "unsafe" +) +// Client wraps a CoreMIDI client reference and an optional notification channel. type Client struct { - client C.MIDIClientRef + client C.MIDIClientRef + NotifyCh <-chan int32 // receives MIDINotificationMessageID values; nil if unsupported + notifyCh chan int32 + notifyHdl cgo.Handle } +// goMidiNotify is called from the C notifyBridgeFn when CoreMIDI fires a +// setup-change notification. It forwards the message ID to the Go channel +// stored in the cgo.Handle. +// +//export goMidiNotify +func goMidiNotify(msgID C.int, handle unsafe.Pointer) { + if handle == nil { + return + } + h := cgo.Handle(uintptr(handle)) + ch, _ := h.Value().(chan int32) + if ch == nil { + return + } + select { + case ch <- int32(msgID): + default: + // drop if consumer is slow — notification channel is best-effort + } +} + +// NewClient creates a CoreMIDI client with the given display name and registers +// a notification callback so that device-change events are forwarded to the +// returned Client.NotifyCh channel. The cgo.Handle embedded in Client must be +// released by calling Dispose() when the client is no longer needed. func NewClient(name string) (client Client, err error) { - var clientRef C.MIDIClientRef + notifyCh := make(chan int32, 16) + h := cgo.NewHandle(notifyCh) stringToCFString(name, func(cfName C.CFStringRef) { - osStatus := C.MIDIClientCreate(cfName, nil, nil, &clientRef) - + var clientRef C.MIDIClientRef + osStatus := C.newMIDIClientWithNotify(cfName, unsafe.Pointer(uintptr(h)), &clientRef) if osStatus != C.noErr { err = fmt.Errorf("%d: failed to create a client", int(osStatus)) } else { - client = Client{clientRef} + client = Client{ + client: clientRef, + NotifyCh: notifyCh, + notifyCh: notifyCh, + notifyHdl: h, + } } }) + if err != nil { + h.Delete() + } return } + +// Dispose releases the cgo.Handle associated with the notification channel. +// It must be called when the client is no longer needed to avoid a memory leak. +func (c *Client) Dispose() { + if c.notifyHdl != 0 { + c.notifyHdl.Delete() + c.notifyHdl = 0 + } +} diff --git a/internal/midi/mididarwin/client_darwin.go b/internal/midi/mididarwin/client_darwin.go index 2e8d987..b9f25ea 100644 --- a/internal/midi/mididarwin/client_darwin.go +++ b/internal/midi/mididarwin/client_darwin.go @@ -27,6 +27,10 @@ type internalPortConnection interface { Disconnect() } +// ClientMid is the macOS CoreMIDI implementation of contracts.ClientMIDI. +// It wraps a CoreMIDI client + input port and routes packets to a buffered +// output channel. A separate notification channel (coremidi.Client.NotifyCh) +// is used by WatchDevices to detect device hot-plug events. type ClientMid struct { logger contracts.Logger eventChannel atomic.Value @@ -43,6 +47,10 @@ type ClientMid struct { closeChOnce sync.Once } +// NewMIDIClient creates a CoreMIDI client and returns a ClientMIDI backed by +// the CoreMIDI framework. options.CoreMIDIConfig.ClientName is used as the +// CoreMIDI client name displayed in Audio MIDI Setup; it defaults to +// "GO MIDI Client" when nil. func NewMIDIClient(options *contracts.ClientOptions) (contracts.ClientMIDI, error) { if options.CoreMIDIConfig == nil { options.CoreMIDIConfig = &contracts.CoreMIDIConfig{ClientName: "GO MIDI Client"} @@ -61,6 +69,7 @@ func NewMIDIClient(options *contracts.ClientOptions) (contracts.ClientMIDI, erro }, nil } +// ListDevices returns all CoreMIDI sources currently visible to the system. func (m *ClientMid) ListDevices() ([]contracts.DeviceInfo, error) { sources, err := coremidi.AllSources() if err != nil { @@ -82,6 +91,8 @@ func (m *ClientMid) ListDevices() ([]contracts.DeviceInfo, error) { return devices, nil } +// SelectDevice opens an input port connected to the CoreMIDI source at index +// deviceID. Any previous port connection is disconnected first. func (m *ClientMid) SelectDevice(deviceID int) error { m.mu.Lock() defer m.mu.Unlock() @@ -117,6 +128,9 @@ func (m *ClientMid) SelectDevice(deviceID int) error { return nil } +// handleMIDIMessage is the CoreMIDI input-port callback. It converts a raw +// MIDI packet into a contracts.MIDI event, applies the event filter, and sends +// the event to the output channel. Packets with fewer than 3 bytes are dropped. func (m *ClientMid) handleMIDIMessage(source coremidi.Source, packet coremidi.Packet) { m.wg.Add(1) defer m.wg.Done() @@ -159,6 +173,46 @@ func (m *ClientMid) stopLocked() { m.logger.Info("MIDI capture stopped") } +// WatchDevices returns a channel that emits a DeviceEvent each time a MIDI +// device is connected or disconnected. It uses CoreMIDI's native notification +// mechanism (kMIDIMsgObjectAdded / kMIDIMsgObjectRemoved / kMIDIMsgSetupChanged) +// so events are delivered with minimal latency. +func (m *ClientMid) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + evCh := make(chan contracts.DeviceEvent, 16) + + notifyCh := m.client.NotifyCh + if notifyCh == nil { + close(evCh) + return evCh, nil + } + + prev, _ := m.ListDevices() + + go func() { + defer close(evCh) + for { + select { + case <-ctx.Done(): + return + case msgID, ok := <-notifyCh: + if !ok { + return + } + // React to setup changes (1), object added (2), object removed (3). + if msgID < 1 || msgID > 3 { + continue + } + curr, _ := m.ListDevices() + diffDevices(prev, curr, evCh) + prev = curr + } + } + }() + + return evCh, nil +} + +// closeOutCh closes the output channel exactly once. func (m *ClientMid) closeOutCh() { m.closeChOnce.Do(func() { if m.outCh != nil { @@ -167,6 +221,8 @@ func (m *ClientMid) closeOutCh() { }) } +// Stop halts MIDI capture, disconnects the port connection, drains in-flight +// callbacks via wg.Wait, and calls Dispose to release the CoreMIDI cgo.Handle. func (m *ClientMid) Stop() error { m.mu.Lock() if !m.capturing { @@ -179,9 +235,14 @@ func (m *ClientMid) Stop() error { m.wg.Wait() m.closeOutCh() + m.client.Dispose() return nil } +// StartCapture begins streaming MIDI events from the selected device into the +// returned channel. The channel is closed when ctx is cancelled or Stop is +// called. Calling StartCapture while already capturing implicitly calls Stop +// first to reset state. func (m *ClientMid) StartCapture(ctx context.Context) (<-chan contracts.MIDI, error) { if err := m.Stop(); err != nil { return nil, err @@ -209,3 +270,33 @@ func (m *ClientMid) StartCapture(ctx context.Context) (<-chan contracts.MIDI, er return ch, nil } + +// diffDevices compares two device lists and sends DeviceAdded / DeviceRemoved +// events to evCh for each difference. +func diffDevices(prev, curr []contracts.DeviceInfo, evCh chan<- contracts.DeviceEvent) { + for _, d := range curr { + if !containsDevice(prev, d) { + select { + case evCh <- contracts.DeviceEvent{Type: contracts.DeviceAdded, Device: d}: + default: + } + } + } + for _, d := range prev { + if !containsDevice(curr, d) { + select { + case evCh <- contracts.DeviceEvent{Type: contracts.DeviceRemoved, Device: d}: + default: + } + } + } +} + +func containsDevice(list []contracts.DeviceInfo, d contracts.DeviceInfo) bool { + for _, item := range list { + if item.Name == d.Name && item.Manufacturer == d.Manufacturer { + return true + } + } + return false +} diff --git a/internal/midi/mididarwin/client_dummy.go b/internal/midi/mididarwin/client_dummy.go index 54c1d5c..54bc661 100644 --- a/internal/midi/mididarwin/client_dummy.go +++ b/internal/midi/mididarwin/client_dummy.go @@ -10,6 +10,8 @@ import ( "github.com/leandrodaf/midi/v2/sdk/contracts" ) +// dummyMIDIClient is the no-op ClientMIDI used on non-macOS systems when the +// mididarwin package is selected by the build system. type dummyMIDIClient struct { logger contracts.Logger } @@ -38,3 +40,11 @@ func (m *dummyMIDIClient) Stop() error { m.logger.Warn("Stop called on dummy MIDI client") return nil } + +// WatchDevices returns a channel that is closed when ctx is cancelled. +// No device events are ever emitted — this is a no-op stub. +func (m *dummyMIDIClient) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + ch := make(chan contracts.DeviceEvent) + go func() { <-ctx.Done(); close(ch) }() + return ch, nil +} diff --git a/internal/midi/mididarwin/client_dummy_test.go b/internal/midi/mididarwin/client_dummy_test.go index c94b3bc..581867b 100644 --- a/internal/midi/mididarwin/client_dummy_test.go +++ b/internal/midi/mididarwin/client_dummy_test.go @@ -7,6 +7,7 @@ import ( "context" "io" "testing" + "time" "github.com/leandrodaf/midi/v2/internal/logger" "github.com/leandrodaf/midi/v2/internal/midi/mididarwin" @@ -64,3 +65,44 @@ func TestDummyClient_Stop_ReturnsNil(t *testing.T) { t.Errorf("expected Stop to return nil, got %v", err) } } + +func TestDummyDarwinClient_WatchDevices_ClosesOnCancel(t *testing.T) { + client := newDummyDarwinClient(t) + ctx, cancel := context.WithCancel(context.Background()) + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Errorf("expected channel to be closed after cancel") + } + case <-time.After(time.Second): + t.Fatal("timed out: channel not closed after context cancel") + } +} + +func TestDummyDarwinClient_WatchDevices_NoEventsEmitted(t *testing.T) { + client := newDummyDarwinClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + select { + case ev, ok := <-ch: + if ok { + t.Errorf("expected no events from stub, got %+v", ev) + } + case <-time.After(20 * time.Millisecond): + // good — no events emitted + } +} diff --git a/internal/midi/midilinux/client_dummy.go b/internal/midi/midilinux/client_dummy.go index 5646db7..41ca754 100644 --- a/internal/midi/midilinux/client_dummy.go +++ b/internal/midi/midilinux/client_dummy.go @@ -10,6 +10,8 @@ import ( "github.com/leandrodaf/midi/v2/sdk/contracts" ) +// dummyMIDIClient is the no-op ClientMIDI used on non-Linux systems when the +// midilinux package is selected by the build system. type dummyMIDIClient struct { logger contracts.Logger } @@ -38,3 +40,11 @@ func (m *dummyMIDIClient) Stop() error { m.logger.Warn("Stop called on dummy MIDI client") return nil } + +// WatchDevices returns a channel that is closed when ctx is cancelled. +// No device events are ever emitted — this is a no-op stub. +func (m *dummyMIDIClient) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + ch := make(chan contracts.DeviceEvent) + go func() { <-ctx.Done(); close(ch) }() + return ch, nil +} diff --git a/internal/midi/midilinux/client_dummy_test.go b/internal/midi/midilinux/client_dummy_test.go index 6c4834f..53840c6 100644 --- a/internal/midi/midilinux/client_dummy_test.go +++ b/internal/midi/midilinux/client_dummy_test.go @@ -7,6 +7,7 @@ import ( "context" "io" "testing" + "time" "github.com/leandrodaf/midi/v2/internal/logger" "github.com/leandrodaf/midi/v2/internal/midi/midilinux" @@ -59,3 +60,44 @@ func TestDummyLinuxClient_Stop_ReturnsNil(t *testing.T) { t.Errorf("expected Stop to return nil, got %v", err) } } + +func TestDummyLinuxClient_WatchDevices_ClosesOnCancel(t *testing.T) { + client := newDummyLinuxClient(t) + ctx, cancel := context.WithCancel(context.Background()) + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Errorf("expected channel to be closed after cancel") + } + case <-time.After(time.Second): + t.Fatal("timed out: channel not closed after context cancel") + } +} + +func TestDummyLinuxClient_WatchDevices_NoEventsEmitted(t *testing.T) { + client := newDummyLinuxClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + select { + case ev, ok := <-ch: + if ok { + t.Errorf("expected no events from stub, got %+v", ev) + } + case <-time.After(20 * time.Millisecond): + // good — no events emitted + } +} diff --git a/internal/midi/midilinux/client_linux_cgo.go b/internal/midi/midilinux/client_linux_cgo.go index 77db4b5..fba8acb 100644 --- a/internal/midi/midilinux/client_linux_cgo.go +++ b/internal/midi/midilinux/client_linux_cgo.go @@ -42,6 +42,9 @@ type ClientMid struct { cancelPipeW int } +// NewMIDIClient creates a Linux ALSA raw-MIDI client. No hardware is opened +// at construction time; call SelectDevice to choose an input device and +// StartCapture to begin receiving events. func NewMIDIClient(options *contracts.ClientOptions) (contracts.ClientMIDI, error) { options.Logger.Info("MIDI client created for Linux (ALSA)") return &ClientMid{ @@ -53,6 +56,8 @@ func NewMIDIClient(options *contracts.ClientOptions) (contracts.ClientMIDI, erro }, nil } +// ListDevices enumerates ALSA raw-MIDI input devices and caches the result +// so that SelectDevice can map a stable integer index to a hardware address. func (m *ClientMid) ListDevices() ([]contracts.DeviceInfo, error) { devs, err := alsa.EnumerateInputs() if err != nil { @@ -78,6 +83,8 @@ func (m *ClientMid) ListDevices() ([]contracts.DeviceInfo, error) { return result, nil } +// SelectDevice records the hardware address of the ALSA device at index +// deviceID. If capture is already running it is stopped first. func (m *ClientMid) SelectDevice(deviceID int) error { m.mu.Lock() defer m.mu.Unlock() @@ -111,6 +118,7 @@ func (m *ClientMid) SelectDevice(deviceID int) error { return nil } +// closeOutCh closes the output channel exactly once. func (m *ClientMid) closeOutCh() { m.closeChOnce.Do(func() { if m.outCh != nil { @@ -143,6 +151,10 @@ func (m *ClientMid) Stop() error { return nil } +// StartCapture opens the previously selected ALSA device, creates a cancel +// pipe pair for cooperative shutdown, and launches a goroutine that calls +// readLoop to parse incoming MIDI bytes. The returned channel is closed when +// ctx is cancelled or Stop is called. func (m *ClientMid) StartCapture(ctx context.Context) (<-chan contracts.MIDI, error) { if err := m.Stop(); err != nil { return nil, err @@ -273,13 +285,72 @@ func (m *ClientMid) readLoop(raw *alsa.RawMIDI, midifd, cancelfd int) { } } -// midiParser is a byte-level MIDI stream parser that supports running status. +// WatchDevices returns a channel that emits a DeviceEvent whenever a MIDI +// device is connected or disconnected. On Linux, this is implemented by +// polling ListDevices every 2 seconds and diffing against the previous list. +func (m *ClientMid) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + evCh := make(chan contracts.DeviceEvent, 16) + + prev, _ := m.ListDevices() + + go func() { + defer close(evCh) + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + curr, _ := m.ListDevices() + diffDevices(prev, curr, evCh) + prev = curr + } + } + }() + + return evCh, nil +} + +// diffDevices compares two device lists and sends DeviceAdded / DeviceRemoved +// events to evCh for each difference. +func diffDevices(prev, curr []contracts.DeviceInfo, evCh chan<- contracts.DeviceEvent) { + for _, d := range curr { + if !containsDevice(prev, d) { + select { + case evCh <- contracts.DeviceEvent{Type: contracts.DeviceAdded, Device: d}: + default: + } + } + } + for _, d := range prev { + if !containsDevice(curr, d) { + select { + case evCh <- contracts.DeviceEvent{Type: contracts.DeviceRemoved, Device: d}: + default: + } + } + } +} + +func containsDevice(list []contracts.DeviceInfo, d contracts.DeviceInfo) bool { + for _, item := range list { + if item.Name == d.Name && item.Manufacturer == d.Manufacturer { + return true + } + } + return false +} +// midiParser reassembles raw MIDI bytes into 3-byte channel-voice messages +// using the running-status rule. SysEx and real-time messages are discarded. type midiParser struct { status byte data [2]byte dataPos int } +// feed processes one raw MIDI byte. It returns (cmd, note, vel, true) when a +// complete 3-byte channel-voice message has been assembled, otherwise false. func (p *midiParser) feed(b byte) (cmd, note, vel byte, ok bool) { if b >= 0x80 { switch { diff --git a/internal/midi/midilinux/client_linux_cgo_test.go b/internal/midi/midilinux/client_linux_cgo_test.go new file mode 100644 index 0000000..abac88b --- /dev/null +++ b/internal/midi/midilinux/client_linux_cgo_test.go @@ -0,0 +1,115 @@ +//go:build linux && cgo +// +build linux,cgo + +package midilinux_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/leandrodaf/midi/v2/internal/logger" + "github.com/leandrodaf/midi/v2/internal/midi/midilinux" + "github.com/leandrodaf/midi/v2/sdk/contracts" +) + +func newLinuxClient(t *testing.T) contracts.ClientMIDI { + t.Helper() + client, err := midilinux.NewMIDIClient(&contracts.ClientOptions{ + Logger: logger.NewLoggerWithWriter(io.Discard), + ChannelBufferSize: 16, + }) + if err != nil { + t.Fatalf("NewMIDIClient returned unexpected error: %v", err) + } + return client +} + +func TestLinuxClient_WatchDevices_ClosesOnCancel(t *testing.T) { + client := newLinuxClient(t) + ctx, cancel := context.WithCancel(context.Background()) + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Errorf("expected channel to be closed after context cancel") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out: channel not closed after context cancel") + } +} + +func TestLinuxClient_WatchDevices_ReturnsChannel(t *testing.T) { + client := newLinuxClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch == nil { + t.Fatal("expected non-nil channel") + } +} + +func TestDiffDevices_AddsNewDevice(t *testing.T) { + prev := []contracts.DeviceInfo{} + curr := []contracts.DeviceInfo{{Name: "Piano", Manufacturer: "Yamaha"}} + evCh := make(chan contracts.DeviceEvent, 8) + + midilinux.DiffDevicesExported(prev, curr, evCh) + + select { + case ev := <-evCh: + if ev.Type != contracts.DeviceAdded { + t.Errorf("expected DeviceAdded, got %v", ev.Type) + } + if ev.Device.Name != "Piano" { + t.Errorf("unexpected device name: %s", ev.Device.Name) + } + default: + t.Fatal("expected a DeviceAdded event") + } +} + +func TestDiffDevices_RemovesGoneDevice(t *testing.T) { + prev := []contracts.DeviceInfo{{Name: "Piano", Manufacturer: "Yamaha"}} + curr := []contracts.DeviceInfo{} + evCh := make(chan contracts.DeviceEvent, 8) + + midilinux.DiffDevicesExported(prev, curr, evCh) + + select { + case ev := <-evCh: + if ev.Type != contracts.DeviceRemoved { + t.Errorf("expected DeviceRemoved, got %v", ev.Type) + } + default: + t.Fatal("expected a DeviceRemoved event") + } +} + +func TestDiffDevices_NoChangeNoEvents(t *testing.T) { + dev := contracts.DeviceInfo{Name: "Piano", Manufacturer: "Yamaha"} + prev := []contracts.DeviceInfo{dev} + curr := []contracts.DeviceInfo{dev} + evCh := make(chan contracts.DeviceEvent, 8) + + midilinux.DiffDevicesExported(prev, curr, evCh) + + select { + case ev := <-evCh: + t.Errorf("expected no events, got %+v", ev) + default: + // good + } +} diff --git a/internal/midi/midilinux/client_linux_nocgo.go b/internal/midi/midilinux/client_linux_nocgo.go index d85d3f2..577e930 100644 --- a/internal/midi/midilinux/client_linux_nocgo.go +++ b/internal/midi/midilinux/client_linux_nocgo.go @@ -14,6 +14,8 @@ import ( // ErrCGORequired is returned on Linux when CGo is disabled at build time. var ErrCGORequired = errors.New("Linux MIDI requires CGo: rebuild with CGO_ENABLED=1") +// stubClient is the no-op ClientMIDI used on Linux when CGo is disabled at +// build time (CGO_ENABLED=0). All operations return ErrCGORequired. type stubClient struct { logger contracts.Logger } @@ -36,3 +38,11 @@ func (s *stubClient) StartCapture(_ context.Context) (<-chan contracts.MIDI, err } func (s *stubClient) Stop() error { return nil } + +// WatchDevices returns a channel that is closed when ctx is cancelled. +// No device events are ever emitted because CGo is required for ALSA. +func (s *stubClient) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + ch := make(chan contracts.DeviceEvent) + go func() { <-ctx.Done(); close(ch) }() + return ch, nil +} diff --git a/internal/midi/midilinux/export_test.go b/internal/midi/midilinux/export_test.go new file mode 100644 index 0000000..74a422e --- /dev/null +++ b/internal/midi/midilinux/export_test.go @@ -0,0 +1,12 @@ +//go:build linux && cgo +// +build linux,cgo + +// Package midilinux — test-only exports for white-box testing of internal helpers. +package midilinux + +import "github.com/leandrodaf/midi/v2/sdk/contracts" + +// DiffDevicesExported is a test-only shim that calls the unexported diffDevices helper. +func DiffDevicesExported(prev, curr []contracts.DeviceInfo, evCh chan<- contracts.DeviceEvent) { + diffDevices(prev, curr, evCh) +} diff --git a/internal/midi/midiwindows/client_dummy.go b/internal/midi/midiwindows/client_dummy.go index 373d2b1..662d871 100644 --- a/internal/midi/midiwindows/client_dummy.go +++ b/internal/midi/midiwindows/client_dummy.go @@ -10,6 +10,8 @@ import ( "github.com/leandrodaf/midi/v2/sdk/contracts" ) +// dummyMIDIClient is the no-op ClientMIDI used on non-Windows systems when the +// midiwindows package is selected by the build system. type dummyMIDIClient struct { logger contracts.Logger } @@ -38,3 +40,11 @@ func (m *dummyMIDIClient) Stop() error { m.logger.Warn("Stop called on dummy MIDI client") return nil } + +// WatchDevices returns a channel that is closed when ctx is cancelled. +// No device events are ever emitted — this is a no-op stub. +func (m *dummyMIDIClient) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + ch := make(chan contracts.DeviceEvent) + go func() { <-ctx.Done(); close(ch) }() + return ch, nil +} diff --git a/internal/midi/midiwindows/client_dummy_test.go b/internal/midi/midiwindows/client_dummy_test.go index 836fd36..0d961ed 100644 --- a/internal/midi/midiwindows/client_dummy_test.go +++ b/internal/midi/midiwindows/client_dummy_test.go @@ -7,6 +7,7 @@ import ( "context" "io" "testing" + "time" "github.com/leandrodaf/midi/v2/internal/logger" "github.com/leandrodaf/midi/v2/internal/midi/midiwindows" @@ -64,3 +65,44 @@ func TestDummyClient_Stop_ReturnsNil(t *testing.T) { t.Errorf("expected Stop to return nil, got %v", err) } } + +func TestDummyWindowsClient_WatchDevices_ClosesOnCancel(t *testing.T) { + client := newDummyWindowsClient(t) + ctx, cancel := context.WithCancel(context.Background()) + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Errorf("expected channel to be closed after cancel") + } + case <-time.After(time.Second): + t.Fatal("timed out: channel not closed after context cancel") + } +} + +func TestDummyWindowsClient_WatchDevices_NoEventsEmitted(t *testing.T) { + client := newDummyWindowsClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + select { + case ev, ok := <-ch: + if ok { + t.Errorf("expected no events from stub, got %+v", ev) + } + case <-time.After(20 * time.Millisecond): + // good — no events emitted + } +} diff --git a/internal/midi/midiwindows/client_windows.go b/internal/midi/midiwindows/client_windows.go index 8773309..34bbd04 100644 --- a/internal/midi/midiwindows/client_windows.go +++ b/internal/midi/midiwindows/client_windows.go @@ -25,6 +25,7 @@ var ( ErrCloseDeviceFailed = errors.New("failed to close MIDI device") ) +// HMIDIIN is a Windows MIDI input device handle (wraps windows.Handle). type HMIDIIN windows.Handle const ( @@ -41,6 +42,7 @@ const ( MIM_MOREDATA = 0x3CC ) +// midiInCaps mirrors the Windows MIDIINCAPSW structure used by midiInGetDevCapsW. type midiInCaps struct { wMid uint16 wPid uint16 @@ -49,6 +51,10 @@ type midiInCaps struct { dwSupport uint32 } +// ClientMid is the Windows winmm implementation of contracts.ClientMIDI. +// It uses midiInOpen/midiInStart/midiInStop/midiInClose from winmm.dll and +// receives MIDI events via a windows.NewCallback function registered at +// SelectDevice time. type ClientMid struct { logger contracts.Logger eventChannel atomic.Value @@ -73,6 +79,8 @@ var ( procMidiInClose = winmm.NewProc("midiInClose") ) +// NewMIDIClient creates a Windows winmm MIDI client. No hardware handle is +// opened at construction time; call SelectDevice to open a device. func NewMIDIClient(options *contracts.ClientOptions) (contracts.ClientMIDI, error) { options.Logger.Info("MIDI client created for Windows") return &ClientMid{ @@ -83,6 +91,7 @@ func NewMIDIClient(options *contracts.ClientOptions) (contracts.ClientMIDI, erro }, nil } +// ListDevices enumerates MIDI input devices via midiInGetNumDevs / midiInGetDevCapsW. func (m *ClientMid) ListDevices() ([]contracts.DeviceInfo, error) { r0, _, _ := procMidiInGetNumDevs.Call() if uint32(r0) == 0 { @@ -108,6 +117,9 @@ func (m *ClientMid) ListDevices() ([]contracts.DeviceInfo, error) { return devices, nil } +// SelectDevice opens the winmm MIDI input device at index deviceID and +// registers midiInCallback as the low-level callback. Any previous device +// handle is closed first. func (m *ClientMid) SelectDevice(deviceID int) error { m.mu.Lock() defer m.mu.Unlock() @@ -134,6 +146,7 @@ func (m *ClientMid) SelectDevice(deviceID int) error { return nil } +// closeOutCh closes the output channel exactly once. func (m *ClientMid) closeOutCh() { m.closeChOnce.Do(func() { if m.outCh != nil { @@ -142,6 +155,9 @@ func (m *ClientMid) closeOutCh() { }) } +// StartCapture calls midiInStart on the open device handle and returns a +// buffered channel that receives MIDI events via midiInCallback. The channel +// is closed when ctx is cancelled or Stop is called. func (m *ClientMid) StartCapture(ctx context.Context) (<-chan contracts.MIDI, error) { m.mu.Lock() defer m.mu.Unlock() @@ -178,6 +194,10 @@ func (m *ClientMid) StartCapture(ctx context.Context) (<-chan contracts.MIDI, er return ch, nil } +// midiInCallback is the low-level winmm callback registered by SelectDevice. +// It runs on a dedicated OS thread created by the winmm driver; it must not +// call any blocking winmm functions. MIM_DATA events are decoded and forwarded +// to the output channel. func midiInCallback(hMidiIn uintptr, wMsg uint32, dwInstance uintptr, dwParam1 uintptr, dwParam2 uintptr) uintptr { m := (*ClientMid)(unsafe.Pointer(dwInstance)) @@ -232,6 +252,8 @@ func midiInCallback(hMidiIn uintptr, wMsg uint32, dwInstance uintptr, dwParam1 u return 0 } +// Stop calls midiInStop and midiInClose, closes the output channel, and resets +// the device handle. Safe to call concurrently and when no device is selected. func (m *ClientMid) Stop() error { m.mu.Lock() defer m.mu.Unlock() @@ -247,6 +269,65 @@ func (m *ClientMid) Stop() error { return nil } +// WatchDevices returns a channel that emits a DeviceEvent whenever a MIDI +// device is connected or disconnected. On Windows, this is implemented by +// polling midiInGetNumDevs every 2 seconds and diffing against the previous list. +func (m *ClientMid) WatchDevices(ctx context.Context) (<-chan contracts.DeviceEvent, error) { + evCh := make(chan contracts.DeviceEvent, 16) + + prev, _ := m.ListDevices() + + go func() { + defer close(evCh) + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + curr, _ := m.ListDevices() + diffDevices(prev, curr, evCh) + prev = curr + } + } + }() + + return evCh, nil +} + +// diffDevices compares two device lists and sends DeviceAdded / DeviceRemoved +// events to evCh for each difference. +func diffDevices(prev, curr []contracts.DeviceInfo, evCh chan<- contracts.DeviceEvent) { + for _, d := range curr { + if !containsDevice(prev, d) { + select { + case evCh <- contracts.DeviceEvent{Type: contracts.DeviceAdded, Device: d}: + default: + } + } + } + for _, d := range prev { + if !containsDevice(curr, d) { + select { + case evCh <- contracts.DeviceEvent{Type: contracts.DeviceRemoved, Device: d}: + default: + } + } + } +} + +func containsDevice(list []contracts.DeviceInfo, d contracts.DeviceInfo) bool { + for _, item := range list { + if item.Name == d.Name && item.Manufacturer == d.Manufacturer { + return true + } + } + return false +} + +// stopCapture stops MIDI input and closes the device handle. +// Must be called with m.mu held. func (m *ClientMid) stopCapture() error { if m.handle == 0 { return ErrInvalidDeviceHandle diff --git a/sdk/contracts/midi.go b/sdk/contracts/midi.go index 85a4094..4adae54 100644 --- a/sdk/contracts/midi.go +++ b/sdk/contracts/midi.go @@ -10,6 +10,22 @@ type MIDI struct { Velocity byte } +// DeviceEventType describes what happened to a MIDI device. +type DeviceEventType int + +const ( + // DeviceAdded is sent when a new MIDI device becomes available. + DeviceAdded DeviceEventType = iota + // DeviceRemoved is sent when a MIDI device is disconnected or deactivated. + DeviceRemoved +) + +// DeviceEvent is emitted by WatchDevices when the set of MIDI devices changes. +type DeviceEvent struct { + Type DeviceEventType + Device DeviceInfo +} + // ClientMIDI defines the interface for MIDI client operations. type ClientMIDI interface { // Stop halts MIDI event capture and releases resources. @@ -22,4 +38,9 @@ type ClientMIDI interface { // that receives events. The channel is closed when the context is cancelled // or Stop() is called. The channel buffer size is controlled by WithChannelBufferSize. StartCapture(ctx context.Context) (<-chan MIDI, error) + // WatchDevices returns a channel that emits DeviceEvent values whenever a + // MIDI device is connected or disconnected. The channel is closed when ctx + // is cancelled. Implementations may use OS-level notifications (macOS + // CoreMIDI) or periodic polling (Linux, Windows). + WatchDevices(ctx context.Context) (<-chan DeviceEvent, error) } diff --git a/sdk/contracts/mock.go b/sdk/contracts/mock.go index 98c6cfd..58de94d 100644 --- a/sdk/contracts/mock.go +++ b/sdk/contracts/mock.go @@ -3,16 +3,24 @@ package contracts import "context" // MockMIDIClient is a configurable ClientMIDI mock for tests. +// Set the *Func fields to override behaviour; the *Calls fields count invocations. +// Zero-value Func fields fall back to safe no-op defaults (WatchDevices closes +// the channel when ctx is cancelled; all others return nil/zero values). type MockMIDIClient struct { StartCaptureFunc func(ctx context.Context) (<-chan MIDI, error) StopFunc func() error ListDevicesFunc func() ([]DeviceInfo, error) SelectDeviceFunc func(deviceID int) error + // WatchDevicesFunc is called by WatchDevices. When nil, the default + // implementation returns a channel that is closed when ctx is cancelled. + WatchDevicesFunc func(ctx context.Context) (<-chan DeviceEvent, error) StartCaptureCalls int StopCalls int ListDevicesCalls int SelectDeviceCalls int + // WatchDevicesCalls is incremented on every call to WatchDevices. + WatchDevicesCalls int } func (m *MockMIDIClient) StartCapture(ctx context.Context) (<-chan MIDI, error) { @@ -46,3 +54,19 @@ func (m *MockMIDIClient) SelectDevice(deviceID int) error { } return nil } + +// WatchDevices delegates to WatchDevicesFunc when set; otherwise it returns a +// channel that is closed when ctx is cancelled, matching the contract that +// callers must range over the channel until it is closed. +func (m *MockMIDIClient) WatchDevices(ctx context.Context) (<-chan DeviceEvent, error) { + m.WatchDevicesCalls++ + if m.WatchDevicesFunc != nil { + return m.WatchDevicesFunc(ctx) + } + ch := make(chan DeviceEvent) + go func() { + <-ctx.Done() + close(ch) + }() + return ch, nil +} diff --git a/sdk/contracts/mock_test.go b/sdk/contracts/mock_test.go index 82ecc77..47e53ab 100644 --- a/sdk/contracts/mock_test.go +++ b/sdk/contracts/mock_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "github.com/leandrodaf/midi/v2/sdk/contracts" ) @@ -139,3 +140,78 @@ func TestMockMIDIClient_CallCountersIncrement(t *testing.T) { t.Errorf("expected SelectDeviceCalls to be 3, got %d", client.SelectDeviceCalls) } } + +func TestMockMIDIClient_WatchDevices_UsesFunc(t *testing.T) { + expected := make(chan contracts.DeviceEvent, 1) + client := &contracts.MockMIDIClient{ + WatchDevicesFunc: func(_ context.Context) (<-chan contracts.DeviceEvent, error) { + return expected, nil + }, + } + + ch, err := client.WatchDevices(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch != expected { + t.Errorf("expected WatchDevicesFunc channel to be returned") + } + if client.WatchDevicesCalls != 1 { + t.Errorf("expected WatchDevicesCalls=1, got %d", client.WatchDevicesCalls) + } +} + +func TestMockMIDIClient_WatchDevices_FuncError(t *testing.T) { + want := errors.New("watch error") + client := &contracts.MockMIDIClient{ + WatchDevicesFunc: func(_ context.Context) (<-chan contracts.DeviceEvent, error) { + return nil, want + }, + } + + ch, err := client.WatchDevices(context.Background()) + if ch != nil { + t.Errorf("expected nil channel on error") + } + if err != want { + t.Errorf("expected configured error, got %v", err) + } +} + +func TestMockMIDIClient_WatchDevices_DefaultClosesOnCancel(t *testing.T) { + client := &contracts.MockMIDIClient{} + ctx, cancel := context.WithCancel(context.Background()) + + ch, err := client.WatchDevices(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Errorf("expected channel to be closed after context cancel") + } + case <-time.After(time.Second): + t.Fatal("timed out: channel not closed after context cancel") + } + if client.WatchDevicesCalls != 1 { + t.Errorf("expected WatchDevicesCalls=1, got %d", client.WatchDevicesCalls) + } +} + +func TestMockMIDIClient_WatchDevices_IncrementsCalls(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + client := &contracts.MockMIDIClient{} + + for i := range 3 { + _, _ = client.WatchDevices(ctx) + _ = i + } + if client.WatchDevicesCalls != 3 { + t.Errorf("expected WatchDevicesCalls=3, got %d", client.WatchDevicesCalls) + } +}