diff --git a/client.go b/client.go index 39dabfc..0e6e14d 100644 --- a/client.go +++ b/client.go @@ -90,3 +90,30 @@ func (c *Client) SetReadDeadline(t time.Time) error { func (c *Client) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) } + +// ReloadRegulatoryDatabase reloads the wireless regulatory database. +// +// This can be used if cfg80211 was built into the kernel and the wireless regulatory database +// was not available during early boot. +// +// See https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/wireless-regdb.html +func (c *Client) ReloadRegulatoryDatabase() error { + return c.c.ReloadRegulatoryDatabase() +} + +// GetRegulatoryRegion gets the system-wide regulatory domain. +// See https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/wireless-regdb.html +func (c *Client) GetRegulatoryDomain() (*RegulatoryDomain, error) { + return c.c.GetRegulatoryDomain() +} + +// SetRegulatoryRegion sets the system-wide regulatory region used by all nl80211 devices. +// You may need to call [Client.ReloadRegulatoryDatabase] first to ensure the region is updated correctly. +// +// region must be an ISO 3166-1 alpha-2 country code (e.g. "GB" or "US"). +// hint should almost always be [RegulatoryHintUser]. +// +// See https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/wireless-regdb.html +func (c *Client) SetRegulatoryRegion(region string, hint RegulatoryHint) error { + return c.c.SetRegulatoryRegion(region, hint) +} diff --git a/client_linux.go b/client_linux.go index 9be6a41..4d4619f 100644 --- a/client_linux.go +++ b/client_linux.go @@ -372,6 +372,76 @@ func (c *client) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) } +// ReloadRegulatoryDatabase reloads the wireless regulatory database. +// +// This can be used if cfg80211 was built into the kernel and the wireless regulatory database +// was not available during early boot. +// +// See https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/wireless-regdb.html +func (c *client) ReloadRegulatoryDatabase() error { + _, err := c.get( + unix.NL80211_CMD_RELOAD_REGDB, + netlink.Acknowledge, + nil, + nil, + ) + + return err +} + +// GetRegulatoryDomain gets the system-wide regulatory region used by all nl80211 devices. +// See +// - https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/wireless-regdb.html +// - https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/wireless/nl80211.c?h=a55f7f5f29b32c2c53cc291899cf9b0c25a07f7c#n9920 +func (c *client) GetRegulatoryDomain() (*RegulatoryDomain, error) { + msgs, err := c.get( + unix.NL80211_CMD_GET_REG, + netlink.Request, + nil, + nil, + ) + if err != nil { + return nil, err + } + + // We expect one message which represents the global regulatory domain. + if len(msgs) == 0 { + return nil, os.ErrNotExist + } + + attrs, err := netlink.UnmarshalAttributes(msgs[0].Data) + if err != nil { + return nil, err + } + + var domain RegulatoryDomain + if err := domain.parseAttributes(attrs); err != nil { + return nil, err + } + + return &domain, nil +} + +// SetRegulatoryRegion sets the system-wide regulatory region used by all nl80211 devices. +// You may need to call [client.ReloadRegulatoryDatabase] first to ensure the region is updated correctly. +// +// region must be an ISO 3166-1 alpha-2 country code (e.g. "GB" or "US"). +// +// See https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/wireless-regdb.html +func (c *client) SetRegulatoryRegion(region string, hint RegulatoryHint) error { + _, err := c.get( + unix.NL80211_CMD_REQ_SET_REG, + netlink.Acknowledge, + nil, + func(ae *netlink.AttributeEncoder) { + ae.String(unix.NL80211_ATTR_REG_ALPHA2, region) + ae.Uint32(unix.NL80211_ATTR_USER_REG_HINT_TYPE, uint32(hint)) + }, + ) + + return err +} + // get performs a request/response interaction with nl80211. func (c *client) get( cmd uint8, @@ -1063,3 +1133,15 @@ found: return (features[feature/8]&(1<<(feature%8)) != 0), nil } + +// parseAttributes parses netlink attributes into a RegulatoryDomain's fields. +func (d *RegulatoryDomain) parseAttributes(attrs []netlink.Attribute) error { + for _, a := range attrs { + switch a.Type { + case unix.NL80211_ATTR_REG_ALPHA2: + d.Region = nlenc.String(a.Data) + } + } + + return nil +} diff --git a/client_linux_integration_test.go b/client_linux_integration_test.go index 88c9ded..3991e10 100644 --- a/client_linux_integration_test.go +++ b/client_linux_integration_test.go @@ -150,3 +150,52 @@ func TestClient_AccessPoints(t *testing.T) { } } + +func TestClient_ReloadRegulatoryDatabase(t *testing.T) { + if os.Geteuid() != 0 { + t.Skipf("skipping, must be run as root") + } + + c, err := wifi.New() + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + err = c.ReloadRegulatoryDatabase() + if err != nil { + t.Fatalf("failed to reload the regulatory database") + } +} + +func TestClient_SetRegulatoryRegion(t *testing.T) { + if os.Geteuid() != 0 { + t.Skipf("skipping, must be run as root") + } + + c, err := wifi.New() + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + err = c.SetRegulatoryRegion("00", wifi.RegulatoryHintUser) + if err != nil { + t.Fatalf("failed to set the regulatory domain") + } +} + +func TestClient_GetRegulatoryDomain(t *testing.T) { + c, err := wifi.New() + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + domain, err := c.GetRegulatoryDomain() + if err != nil { + t.Fatalf("failed to retrieve regulatory domain") + } + + // At a minimum, the region will be 00. + if len(domain.Region) != 2 { + t.Fatalf("failed to retrieve regulatory domain region") + } +} diff --git a/client_linux_test.go b/client_linux_test.go index b2c83a9..920e637 100644 --- a/client_linux_test.go +++ b/client_linux_test.go @@ -563,6 +563,15 @@ func (s *SurveyInfo) attributes() []netlink.Attribute { } } +func (d *RegulatoryDomain) attributes() []netlink.Attribute { + return []netlink.Attribute{ + { + Type: unix.NL80211_ATTR_REG_ALPHA2, + Data: nlenc.Bytes(d.Region), + }, + } +} + func bitrateAttr(bitrate int) uint32 { return uint32(bitrate / 100 / 1000) } @@ -586,6 +595,9 @@ func mustMessages(t *testing.T, command uint8, want any) genltest.Func { for _, x := range xs { as = append(as, x) } + case *RegulatoryDomain: + as = append(as, xs) + default: t.Fatalf("cannot make messages for type: %T", xs) } @@ -1074,3 +1086,36 @@ func assertRSNError(t *testing.T, input []byte, expected *RSNInfo, errMsg string t.Errorf("decodeRSN() group cipher = %v, want %v", got.GroupCipher, expected.GroupCipher) } } + +func TestLinux_GetRegulatoryDomain_NoMessages(t *testing.T) { + c := testClient(t, func(_ genetlink.Message, _ netlink.Message) ([]genetlink.Message, error) { + // No messages + return nil, io.EOF + }) + + _, err := c.GetRegulatoryDomain() + if !os.IsNotExist(err) { + t.Fatalf("expected is not exist, got: %v", err) + } +} + +func TestLinux_GetRegulatoryDomain_OK(t *testing.T) { + want := &RegulatoryDomain{ + Region: "GB", + } + + const flags = netlink.Request + + c := testClient(t, genltest.CheckRequest(familyID, unix.NL80211_CMD_GET_REG, flags, + mustMessages(t, unix.NL80211_CMD_GET_REG, want), + )) + + got, err := c.GetRegulatoryDomain() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("unexpected region (-want +got):\n%s", diff) + } +} diff --git a/client_others.go b/client_others.go index 6f10434..912194b 100644 --- a/client_others.go +++ b/client_others.go @@ -18,16 +18,19 @@ type client struct{} func newClient() (*client, error) { return nil, errUnimplemented } -func (*client) Close() error { return errUnimplemented } -func (*client) Interfaces() ([]*Interface, error) { return nil, errUnimplemented } -func (*client) BSS(_ *Interface) (*BSS, error) { return nil, errUnimplemented } -func (client) AccessPoints(ifi *Interface) ([]*BSS, error) { return nil, errUnimplemented } -func (*client) StationInfo(_ *Interface) ([]*StationInfo, error) { return nil, errUnimplemented } -func (*client) SurveyInfo(_ *Interface) ([]*SurveyInfo, error) { return nil, errUnimplemented } -func (*client) Scan(ctx context.Context, ifi *Interface) error { return errUnimplemented } -func (*client) Connect(_ *Interface, _ string) error { return errUnimplemented } -func (*client) Disconnect(_ *Interface) error { return errUnimplemented } -func (*client) ConnectWPAPSK(_ *Interface, _, _ string) error { return errUnimplemented } -func (*client) SetDeadline(t time.Time) error { return errUnimplemented } -func (*client) SetReadDeadline(t time.Time) error { return errUnimplemented } -func (*client) SetWriteDeadline(t time.Time) error { return errUnimplemented } +func (*client) Close() error { return errUnimplemented } +func (*client) Interfaces() ([]*Interface, error) { return nil, errUnimplemented } +func (*client) BSS(_ *Interface) (*BSS, error) { return nil, errUnimplemented } +func (client) AccessPoints(ifi *Interface) ([]*BSS, error) { return nil, errUnimplemented } +func (*client) StationInfo(_ *Interface) ([]*StationInfo, error) { return nil, errUnimplemented } +func (*client) SurveyInfo(_ *Interface) ([]*SurveyInfo, error) { return nil, errUnimplemented } +func (*client) Scan(ctx context.Context, ifi *Interface) error { return errUnimplemented } +func (*client) Connect(_ *Interface, _ string) error { return errUnimplemented } +func (*client) Disconnect(_ *Interface) error { return errUnimplemented } +func (*client) ConnectWPAPSK(_ *Interface, _, _ string) error { return errUnimplemented } +func (*client) SetDeadline(t time.Time) error { return errUnimplemented } +func (*client) SetReadDeadline(t time.Time) error { return errUnimplemented } +func (*client) SetWriteDeadline(t time.Time) error { return errUnimplemented } +func (*client) GetRegulatoryDomain() (*RegulatoryDomain, error) { return nil, errUnimplemented } +func (*client) SetRegulatoryRegion(_ string, _ RegulatoryHint) error { return errUnimplemented } +func (*client) ReloadRegulatoryDatabase() error { return errUnimplemented } diff --git a/wifi.go b/wifi.go index c3225e9..059491b 100644 --- a/wifi.go +++ b/wifi.go @@ -615,3 +615,20 @@ func (r RSNInfo) String() string { "RSN v%d Group:%s Pairwise:%v AKM:%v", r.Version, r.GroupCipher.String(), pairwiseNames, akmNames) } + +// RegulatoryDomain contains information about the regulatory domain the device or system operates in. +type RegulatoryDomain struct { + Region string // ISO 3166-1 alpha-2 country code +} + +// RegulatoryHint describes why the regulatory region is being set. +// See: +// - https://wireless.docs.kernel.org/en/latest/en/developers/regulatory/processing_rules.html#cellular-base-station-regulatory-hints +// - https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/nl80211.h?h=a55f7f5f29b32c2c53cc291899cf9b0c25a07f7c#n4803 +type RegulatoryHint uint32 + +const ( + RegulatoryHintUser RegulatoryHint = 0x0 // Set by the userspace (default) + RegulatoryHintCellBase RegulatoryHint = 0x1 // Set based on cellular network information + RegulatoryHintIndoor RegulatoryHint = 0x2 // Set because the device is indoors +)