mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-09 02:09:55 +00:00
Compare commits
7 Commits
feat/byod-
...
windows-dn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f8b88471f | ||
|
|
f42b8aed90 | ||
|
|
0415137acd | ||
|
|
7fd16666e3 | ||
|
|
0571eeaba0 | ||
|
|
6a201d12b5 | ||
|
|
4810e79a00 |
@@ -92,6 +92,9 @@ linters:
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
@@ -584,9 +583,6 @@ func isSensitiveEnvVar(key string) bool {
|
||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil {
|
||||
configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String()))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface))
|
||||
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort))
|
||||
if g.internalConfig.NetworkMonitor != nil {
|
||||
@@ -611,12 +607,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||
}
|
||||
if g.internalConfig.DisableSSHAuth != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
||||
}
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
@@ -643,7 +633,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addProf() (err error) {
|
||||
|
||||
@@ -5,21 +5,16 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -476,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
@@ -494,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
@@ -771,127 +766,3 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
assert.Contains(t, anonNftables, "chain input {")
|
||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||
}
|
||||
|
||||
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
||||
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
||||
// excluded. When a new field is added to Config, this test fails until the
|
||||
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
||||
// the excluded set with a justification.
|
||||
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
excluded := map[string]string{
|
||||
"PrivateKey": "sensitive: WireGuard private key",
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
aURL, _ := url.Parse("https://admin.example.com:443")
|
||||
bTrue := true
|
||||
iVal := 42
|
||||
cfg := &profilemanager.Config{
|
||||
PrivateKey: "priv",
|
||||
PreSharedKey: "psk",
|
||||
ManagementURL: mURL,
|
||||
AdminURL: aURL,
|
||||
WgIface: "wt0",
|
||||
WgPort: 51820,
|
||||
NetworkMonitor: &bTrue,
|
||||
IFaceBlackList: []string{"eth0"},
|
||||
DisableIPv6Discovery: true,
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
EnableSSHRemotePortForwarding: &bTrue,
|
||||
DisableSSHAuth: &bTrue,
|
||||
SSHJWTCacheTTL: &iVal,
|
||||
DisableClientRoutes: true,
|
||||
DisableServerRoutes: true,
|
||||
DisableDNS: true,
|
||||
DisableFirewall: true,
|
||||
BlockLANAccess: true,
|
||||
BlockInbound: true,
|
||||
DisableNotifications: &bTrue,
|
||||
DNSLabels: domain.List{},
|
||||
SSHKey: "sshkey",
|
||||
NATExternalIPs: []string{"1.2.3.4"},
|
||||
CustomDNSAddress: "1.1.1.1:53",
|
||||
DisableAutoConnect: true,
|
||||
DNSRouteInterval: 5 * time.Second,
|
||||
ClientCertPath: "/tmp/cert",
|
||||
ClientCertKeyPath: "/tmp/key",
|
||||
LazyConnectionEnabled: true,
|
||||
MTU: 1280,
|
||||
}
|
||||
|
||||
for _, anonymize := range []bool{false, true} {
|
||||
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: newAnonymizerForTest(),
|
||||
internalConfig: cfg,
|
||||
anonymize: anonymize,
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
g.addCommonConfigFields(&sb)
|
||||
rendered := sb.String() + renderAddConfigSpecific(g)
|
||||
|
||||
val := reflect.ValueOf(cfg).Elem()
|
||||
typ := val.Type()
|
||||
var missing []string
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
name := typ.Field(i).Name
|
||||
if _, ok := excluded[name]; ok {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(rendered, name+":") {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
||||
"Either render the field in addCommonConfigFields/addConfig, "+
|
||||
"or add it to the excluded map with a justification.", missing)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
||||
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
||||
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
||||
// production shape without needing to write an actual zip.
|
||||
func renderAddConfigSpecific(g *BundleGenerator) string {
|
||||
var sb strings.Builder
|
||||
if g.anonymize {
|
||||
if g.internalConfig.ManagementURL != nil {
|
||||
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
||||
}
|
||||
if g.internalConfig.AdminURL != nil {
|
||||
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
||||
}
|
||||
sb.WriteString("NATExternalIPs: x\n")
|
||||
if g.internalConfig.CustomDNSAddress != "" {
|
||||
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
||||
}
|
||||
} else {
|
||||
if g.internalConfig.ManagementURL != nil {
|
||||
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
||||
}
|
||||
if g.internalConfig.AdminURL != nil {
|
||||
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
||||
}
|
||||
sb.WriteString("NATExternalIPs: x\n")
|
||||
if g.internalConfig.CustomDNSAddress != "" {
|
||||
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func newAnonymizerForTest() *anonymize.Anonymizer {
|
||||
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
}
|
||||
|
||||
63
client/internal/dns/dnsfw/config.go
Normal file
63
client/internal/dns/dnsfw/config.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
|
||||
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
|
||||
// EnvPorts overrides the comma-separated list of remote ports to block.
|
||||
// Empty disables the firewall.
|
||||
EnvPorts = "NB_DNS_FIREWALL_PORTS"
|
||||
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
|
||||
// and the netbird daemon. Default mode also permits anything on the
|
||||
// netbird tunnel interface, which is safer if NRPT is silently ignored
|
||||
// by Windows but lets apps reach custom DNS servers via the tunnel.
|
||||
EnvStrict = "NB_DNS_FIREWALL_STRICT"
|
||||
)
|
||||
|
||||
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
|
||||
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
|
||||
var defaultBlockedPorts = []uint16{53, 853}
|
||||
|
||||
// blockedPorts returns the effective port list, honoring env overrides.
|
||||
// A nil return means the firewall should not be installed.
|
||||
func blockedPorts() []uint16 {
|
||||
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
|
||||
log.Infof("dns firewall disabled via %s", EnvDisable)
|
||||
return nil
|
||||
}
|
||||
|
||||
override, ok := os.LookupEnv(EnvPorts)
|
||||
if !ok {
|
||||
return defaultBlockedPorts
|
||||
}
|
||||
|
||||
var ports []uint16
|
||||
for _, raw := range strings.Split(override, ",") {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.ParseUint(raw, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
|
||||
continue
|
||||
}
|
||||
if port == 0 {
|
||||
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
|
||||
continue
|
||||
}
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
|
||||
return nil
|
||||
}
|
||||
return ports
|
||||
}
|
||||
39
client/internal/dns/dnsfw/config_test.go
Normal file
39
client/internal/dns/dnsfw/config_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBlockedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
disable string
|
||||
ports string
|
||||
setPorts bool
|
||||
want []uint16
|
||||
}{
|
||||
{name: "default", want: defaultBlockedPorts},
|
||||
{name: "disabled", disable: "true", want: nil},
|
||||
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
|
||||
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
|
||||
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
|
||||
{name: "override empty disables", ports: "", setPorts: true, want: nil},
|
||||
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvDisable, tc.disable)
|
||||
if tc.setPorts {
|
||||
t.Setenv(EnvPorts, tc.ports)
|
||||
}
|
||||
got := blockedPorts()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
client/internal/dns/dnsfw/dnsfw.go
Normal file
16
client/internal/dns/dnsfw/dnsfw.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
|
||||
// managing the host's DNS, so that resolvers running on apps or libraries
|
||||
// outside netbird cannot bypass the configured DNS path.
|
||||
//
|
||||
// Implementation is Windows-only (uses WFP). On other platforms New returns
|
||||
// a no-op manager.
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
|
||||
// to call multiple times.
|
||||
type Manager interface {
|
||||
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
|
||||
Disable() error
|
||||
}
|
||||
15
client/internal/dns/dnsfw/dnsfw_other.go
Normal file
15
client/internal/dns/dnsfw/dnsfw_other.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type noopManager struct{}
|
||||
|
||||
func (noopManager) Enable(string, netip.Addr) error { return nil }
|
||||
func (noopManager) Disable() error { return nil }
|
||||
|
||||
// New returns a no-op manager on non-Windows platforms.
|
||||
func New() Manager {
|
||||
return noopManager{}
|
||||
}
|
||||
144
client/internal/dns/dnsfw/dnsfw_windows.go
Normal file
144
client/internal/dns/dnsfw/dnsfw_windows.go
Normal file
@@ -0,0 +1,144 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
|
||||
)
|
||||
|
||||
type windowsManager struct {
|
||||
mu sync.Mutex
|
||||
// session is the WFP engine handle. Zero when disabled.
|
||||
session uintptr
|
||||
}
|
||||
|
||||
// Enable installs the dns firewall. Strict mode propagates failures;
|
||||
// non-strict mode logs and returns nil so partial protection is preserved.
|
||||
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
ports := blockedPorts()
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.session != 0 {
|
||||
if err := m.disableLocked(); err != nil {
|
||||
return fmt.Errorf("reset existing dns firewall session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
strict := strictMode()
|
||||
|
||||
luid, err := luidFromGUID(ifaceGUID)
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
|
||||
}
|
||||
|
||||
cfg := installConfig{
|
||||
tunLUID: luid,
|
||||
daemonExe: exe,
|
||||
blockedPorts: ports,
|
||||
strict: strict,
|
||||
virtualDNSIP: virtualDNSIP,
|
||||
}
|
||||
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
|
||||
session, installErr := installFilters(cfg)
|
||||
if session == 0 {
|
||||
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
|
||||
}
|
||||
|
||||
if installErr != nil && strict {
|
||||
_ = closeSession(session)
|
||||
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
|
||||
}
|
||||
|
||||
m.session = session
|
||||
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
|
||||
ifaceGUID, exe, ports, strict, virtualDNSIP)
|
||||
if installErr != nil {
|
||||
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *windowsManager) Disable() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.disableLocked()
|
||||
}
|
||||
|
||||
func (m *windowsManager) disableLocked() error {
|
||||
if m.session == 0 {
|
||||
return nil
|
||||
}
|
||||
session := m.session
|
||||
m.session = 0
|
||||
if err := closeSession(session); err != nil {
|
||||
return fmt.Errorf("close wfp session: %w", err)
|
||||
}
|
||||
log.Info("dns firewall removed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// failOrLog returns err unchanged in strict mode. In non-strict mode the
|
||||
// error is logged and nil is returned.
|
||||
func (m *windowsManager) failOrLog(strict bool, err error) error {
|
||||
if strict {
|
||||
return err
|
||||
}
|
||||
log.Errorf("dns firewall: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a Windows DNS firewall manager backed by WFP.
|
||||
func New() Manager {
|
||||
return &windowsManager{}
|
||||
}
|
||||
|
||||
// strictMode reports whether strict mode is enabled via env.
|
||||
func strictMode() bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
|
||||
return v
|
||||
}
|
||||
|
||||
// luidFromGUID converts a Windows interface GUID string to its LUID.
|
||||
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in luidFromGUID: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
guid, err := windows.GUIDFromString(ifaceGUID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse guid: %w", err)
|
||||
}
|
||||
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
if rc != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
|
||||
}
|
||||
return luid, nil
|
||||
}
|
||||
72
client/internal/dns/dnsfw/dnsfw_windows_test.go
Normal file
72
client/internal/dns/dnsfw/dnsfw_windows_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrictMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
set bool
|
||||
want bool
|
||||
}{
|
||||
{name: "unset", want: false},
|
||||
{name: "true", val: "true", set: true, want: true},
|
||||
{name: "1", val: "1", set: true, want: true},
|
||||
{name: "false", val: "false", set: true, want: false},
|
||||
{name: "invalid is false", val: "garbage", set: true, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvStrict, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(EnvStrict)
|
||||
}
|
||||
if got := strictMode(); got != tc.want {
|
||||
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerDisableIdempotent(t *testing.T) {
|
||||
m := &windowsManager{}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("first Disable on fresh manager: %v", err)
|
||||
}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("second Disable on fresh manager: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session should remain zero, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
|
||||
t.Setenv(EnvDisable, "true")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
|
||||
t.Setenv(EnvPorts, "")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
|
||||
}
|
||||
}
|
||||
53
client/internal/dns/dnsfw/helpers_windows.go
Normal file
53
client/internal/dns/dnsfw/helpers_windows.go
Normal file
@@ -0,0 +1,53 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
descriptionPtr, err := windows.UTF16PtrFromString(description)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
return &wtFwpmDisplayData0{
|
||||
name: namePtr,
|
||||
description: descriptionPtr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func filterWeight(weight uint8) wtFwpValue0 {
|
||||
return wtFwpValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(weight),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapErr(err error) error {
|
||||
var errno syscall.Errno
|
||||
if !errors.As(err, &errno) {
|
||||
return err
|
||||
}
|
||||
_, file, line, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
return fmt.Errorf("wfp error at unknown location: %w", err)
|
||||
}
|
||||
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
|
||||
}
|
||||
249
client/internal/dns/dnsfw/rules_windows.go
Normal file
249
client/internal/dns/dnsfw/rules_windows.go
Normal file
@@ -0,0 +1,249 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
|
||||
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
|
||||
* uses for its kill-switch DNS leak protection. We extend it with a
|
||||
* configurable port set so we also cover :853 (DoT) and any future ports.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
|
||||
// follow the authorized outbound flow.
|
||||
|
||||
// permitTunInterface installs a permit filter for any traffic whose local
|
||||
// interface is the netbird tunnel.
|
||||
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT64,
|
||||
value: uintptr(unsafe.Pointer(&ifLUID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
|
||||
}
|
||||
|
||||
// permitDaemonByAppID installs a permit filter matching the netbird daemon
|
||||
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
|
||||
// dedicated binary.
|
||||
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
|
||||
appID, err := daemonAppID(daemonExe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
|
||||
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_BLOB_TYPE,
|
||||
value: uintptr(unsafe.Pointer(appID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird daemon")
|
||||
}
|
||||
|
||||
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
|
||||
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
|
||||
// permitTunInterface.
|
||||
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
|
||||
if !ip.IsValid() {
|
||||
return fmt.Errorf("invalid address")
|
||||
}
|
||||
|
||||
var addrCond wtFwpmFilterCondition0
|
||||
var layer windows.GUID
|
||||
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
|
||||
var v6 wtFwpByteArray16
|
||||
|
||||
if ip.Is4() {
|
||||
v4 := ip.As4()
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT32,
|
||||
value: uintptr(binary.BigEndian.Uint32(v4[:])),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
|
||||
} else {
|
||||
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_ARRAY16_TYPE,
|
||||
value: uintptr(unsafe.Pointer(&v6)),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
|
||||
}
|
||||
|
||||
conditions := [2]wtFwpmFilterCondition0{
|
||||
addrCond,
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
_ = v6
|
||||
return nil
|
||||
}
|
||||
|
||||
// blockDNSPorts installs a deny filter for outbound traffic to each of the
|
||||
// given remote ports over UDP or TCP. Per-port and per-layer failures are
|
||||
// accumulated; partial coverage is preferred over zero coverage.
|
||||
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := blockDNSPort(session, base, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
|
||||
conditions := [3]wtFwpmFilterCondition0{
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_UDP),
|
||||
},
|
||||
},
|
||||
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_TCP),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
|
||||
}
|
||||
|
||||
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
|
||||
// connect layers. v4 and v6 are installed independently: failure on one
|
||||
// layer does not abort the other, and the accumulated errors are returned.
|
||||
// Partial coverage is preferred over zero coverage.
|
||||
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
|
||||
layers := [...]struct {
|
||||
layer windows.GUID
|
||||
label string
|
||||
}{
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, l := range layers {
|
||||
display, err := createWtFwpmDisplayData0(l.label, "")
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
continue
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = l.layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
177
client/internal/dns/dnsfw/session_windows.go
Normal file
177
client/internal/dns/dnsfw/session_windows.go
Normal file
@@ -0,0 +1,177 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Session lifecycle and the high-level Install/Close entry points adapted
|
||||
* from wireguard-windows tunnel/firewall.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// installConfig is the input to installFilters.
|
||||
type installConfig struct {
|
||||
tunLUID uint64
|
||||
daemonExe string
|
||||
blockedPorts []uint16
|
||||
// strict, when true, narrows the carve-out from "anything on tun" to
|
||||
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
|
||||
strict bool
|
||||
virtualDNSIP netip.Addr
|
||||
}
|
||||
|
||||
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
|
||||
// for our session. Both are randomly generated per session.
|
||||
type baseObjects struct {
|
||||
provider windows.GUID
|
||||
filters windows.GUID
|
||||
}
|
||||
|
||||
// installFilters opens a dynamic WFP session and installs the netbird DNS
|
||||
// firewall filters. Returns a zero session on hard failure (session create,
|
||||
// base objects); a non-zero session with a non-nil error is a partial install
|
||||
// (some per-filter installs failed) and is safe to close.
|
||||
func installFilters(cfg installConfig) (session uintptr, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Dynamic session: kernel will clean up on process exit even
|
||||
// if we leave the handle dangling here.
|
||||
err = fmt.Errorf("panic in installFilters: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(cfg.blockedPorts) == 0 {
|
||||
return 0, errors.New("dns firewall: no blocked ports configured")
|
||||
}
|
||||
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
|
||||
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
|
||||
}
|
||||
|
||||
session, err = createSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
base, err := registerBaseObjects(session)
|
||||
if err != nil {
|
||||
_ = fwpmEngineClose0(session)
|
||||
return 0, fmt.Errorf("register base objects: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if cfg.strict {
|
||||
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
|
||||
}
|
||||
} else {
|
||||
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
|
||||
}
|
||||
}
|
||||
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
|
||||
}
|
||||
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
|
||||
}
|
||||
|
||||
return session, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// closeSession tears down a WFP session previously opened by installFilters.
|
||||
// All filters owned by the session are removed.
|
||||
func closeSession(session uintptr) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in closeSession: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if session == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := fwpmEngineClose0(session); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession() (uintptr, error) {
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
|
||||
if err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
session := wtFwpmSession0{
|
||||
displayData: *displayData,
|
||||
flags: cFWPM_SESSION_FLAG_DYNAMIC,
|
||||
txnWaitTimeoutInMSec: windows.INFINITE,
|
||||
}
|
||||
var handle uintptr
|
||||
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func registerBaseObjects(session uintptr) (*baseObjects, error) {
|
||||
bo := &baseObjects{}
|
||||
var err error
|
||||
if bo.provider, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
if bo.filters, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
provider := wtFwpmProvider0{
|
||||
providerKey: bo.provider,
|
||||
displayData: *displayData,
|
||||
}
|
||||
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
sublayer := wtFwpmSublayer0{
|
||||
subLayerKey: bo.filters,
|
||||
displayData: *subDisplay,
|
||||
providerKey: &bo.provider,
|
||||
weight: ^uint16(0),
|
||||
}
|
||||
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return bo, nil
|
||||
}
|
||||
|
||||
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
|
||||
func daemonAppID(path string) (*wtFwpByteBlob, error) {
|
||||
pathPtr, err := windows.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
var appID *wtFwpByteBlob
|
||||
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return appID, nil
|
||||
}
|
||||
38
client/internal/dns/dnsfw/syscall_windows.go
Normal file
38
client/internal/dns/dnsfw/syscall_windows.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
|
||||
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
|
||||
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
|
||||
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
|
||||
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
|
||||
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
|
||||
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
|
||||
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
|
||||
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
|
||||
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
|
||||
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0
|
||||
414
client/internal/dns/dnsfw/types_windows.go
Normal file
414
client/internal/dns/dnsfw/types_windows.go
Normal file
@@ -0,0 +1,414 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
|
||||
|
||||
wtFwpBitmapArray64_Size = 8
|
||||
|
||||
wtFwpByteArray16_Size = 16
|
||||
|
||||
wtFwpByteArray6_Size = 6
|
||||
|
||||
wtFwpmAction0_Size = 20
|
||||
wtFwpmAction0_filterType_Offset = 4
|
||||
|
||||
wtFwpV4AddrAndMask_Size = 8
|
||||
wtFwpV4AddrAndMask_mask_Offset = 4
|
||||
|
||||
wtFwpV6AddrAndMask_Size = 17
|
||||
wtFwpV6AddrAndMask_prefixLength_Offset = 16
|
||||
)
|
||||
|
||||
type wtFwpActionFlag uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
|
||||
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
|
||||
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
|
||||
)
|
||||
|
||||
// FWP_ACTION_TYPE defined in fwptypes.h
|
||||
type wtFwpActionType uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
|
||||
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
|
||||
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
|
||||
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
|
||||
)
|
||||
|
||||
// FWP_BYTE_BLOB defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
|
||||
type wtFwpByteBlob struct {
|
||||
size uint32
|
||||
data *uint8
|
||||
}
|
||||
|
||||
// FWP_MATCH_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
|
||||
type wtFwpMatchType uint32
|
||||
|
||||
const (
|
||||
cFWP_MATCH_EQUAL wtFwpMatchType = 0
|
||||
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
|
||||
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
|
||||
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
|
||||
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
|
||||
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
|
||||
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
|
||||
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
|
||||
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
|
||||
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
|
||||
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
|
||||
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
|
||||
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
|
||||
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
|
||||
)
|
||||
|
||||
// FWPM_ACTION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
|
||||
type wtFwpmAction0 struct {
|
||||
_type wtFwpActionType
|
||||
filterType windows.GUID // Windows type: GUID
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
||||
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
|
||||
Data1: 0x4cd62a49,
|
||||
Data2: 0x59c3,
|
||||
Data3: 0x4969,
|
||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
|
||||
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
|
||||
Data1: 0xb235ae9a,
|
||||
Data2: 0x1d64,
|
||||
Data3: 0x49b8,
|
||||
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
||||
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
|
||||
Data1: 0x3971ef2b,
|
||||
Data2: 0x623e,
|
||||
Data3: 0x4f9a,
|
||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
|
||||
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
|
||||
Data1: 0x0c1ba1af,
|
||||
Data2: 0x5765,
|
||||
Data3: 0x453f,
|
||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
|
||||
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
|
||||
Data1: 0xc35a604d,
|
||||
Data2: 0xd22b,
|
||||
Data3: 0x4e1a,
|
||||
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
|
||||
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
|
||||
Data1: 0xd78e1e87,
|
||||
Data2: 0x8644,
|
||||
Data3: 0x4ea5,
|
||||
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
|
||||
}
|
||||
|
||||
// af043a0a-b34d-4f86-979c-c90371af6e66
|
||||
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
|
||||
Data1: 0xaf043a0a,
|
||||
Data2: 0xb34d,
|
||||
Data3: 0x4f86,
|
||||
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
|
||||
}
|
||||
|
||||
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
|
||||
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
|
||||
Data1: 0xd9ee00de,
|
||||
Data2: 0xc1ef,
|
||||
Data3: 0x4617,
|
||||
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
|
||||
}
|
||||
|
||||
var (
|
||||
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
|
||||
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
|
||||
)
|
||||
|
||||
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
|
||||
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
|
||||
Data1: 0x7bc43cbf,
|
||||
Data2: 0x37ba,
|
||||
Data3: 0x45f1,
|
||||
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
|
||||
}
|
||||
|
||||
type wtFwpmL2Flags uint32
|
||||
|
||||
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
|
||||
|
||||
var cFWPM_CONDITION_FLAGS = windows.GUID{
|
||||
Data1: 0x632ce23b,
|
||||
Data2: 0x5167,
|
||||
Data3: 0x435c,
|
||||
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
|
||||
}
|
||||
|
||||
type wtFwpmFlags uint32
|
||||
|
||||
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
|
||||
|
||||
// Defined in fwpmtypes.h
|
||||
type wtFwpmFilterFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
|
||||
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
|
||||
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
|
||||
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
|
||||
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
|
||||
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
|
||||
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
|
||||
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
|
||||
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
|
||||
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
|
||||
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
|
||||
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
|
||||
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
|
||||
)
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
|
||||
Data1: 0xc38d57d1,
|
||||
Data2: 0x05a7,
|
||||
Data3: 0x4c33,
|
||||
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
|
||||
}
|
||||
|
||||
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
|
||||
Data1: 0xe1cd9fe7,
|
||||
Data2: 0xf4b5,
|
||||
Data3: 0x4273,
|
||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
||||
}
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
|
||||
Data1: 0x4a72393b,
|
||||
Data2: 0x319f,
|
||||
Data3: 0x44bc,
|
||||
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
|
||||
}
|
||||
|
||||
// a3b42c97-9f04-4672-b87e-cee9c483257f
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
|
||||
Data1: 0xa3b42c97,
|
||||
Data2: 0x9f04,
|
||||
Data3: 0x4672,
|
||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
||||
}
|
||||
|
||||
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
|
||||
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0x94c44912,
|
||||
Data2: 0x9d6f,
|
||||
Data3: 0x4ebf,
|
||||
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
|
||||
}
|
||||
|
||||
// d4220bd3-62ce-4f08-ae88-b56e8526df50
|
||||
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0xd4220bd3,
|
||||
Data2: 0x62ce,
|
||||
Data3: 0x4f08,
|
||||
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
|
||||
}
|
||||
|
||||
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
|
||||
type wtFwpBitmapArray64 struct {
|
||||
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY6 defined in fwtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
|
||||
type wtFwpByteArray6 struct {
|
||||
byteArray6 [6]uint8 // Windows type: [6]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY16 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
|
||||
type wtFwpByteArray16 struct {
|
||||
byteArray16 [16]uint8 // Windows type [16]UINT8
|
||||
}
|
||||
|
||||
// FWP_CONDITION_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
|
||||
type wtFwpConditionValue0 wtFwpValue0
|
||||
|
||||
// FWP_DATA_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
|
||||
type wtFwpDataType uint
|
||||
|
||||
const (
|
||||
cFWP_EMPTY wtFwpDataType = 0
|
||||
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
|
||||
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
|
||||
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
|
||||
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
|
||||
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
|
||||
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
|
||||
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
|
||||
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
|
||||
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
|
||||
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
|
||||
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
|
||||
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
|
||||
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
|
||||
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
|
||||
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
|
||||
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
|
||||
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
|
||||
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
|
||||
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
|
||||
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
|
||||
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
|
||||
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
|
||||
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
|
||||
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
|
||||
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
|
||||
)
|
||||
|
||||
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
|
||||
type wtFwpV4AddrAndMask struct {
|
||||
addr uint32
|
||||
mask uint32
|
||||
}
|
||||
|
||||
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
|
||||
type wtFwpV6AddrAndMask struct {
|
||||
addr [16]uint8
|
||||
prefixLength uint8
|
||||
}
|
||||
|
||||
// FWP_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
|
||||
type wtFwpValue0 struct {
|
||||
_type wtFwpDataType
|
||||
value uintptr
|
||||
}
|
||||
|
||||
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
|
||||
type wtFwpmDisplayData0 struct {
|
||||
name *uint16 // Windows type: *wchar_t
|
||||
description *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
|
||||
type wtFwpmFilterCondition0 struct {
|
||||
fieldKey windows.GUID // Windows type: GUID
|
||||
matchType wtFwpMatchType
|
||||
conditionValue wtFwpConditionValue0
|
||||
}
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
|
||||
type wtFwpProvider0 struct {
|
||||
providerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
type wtFwpmSessionFlagsValue uint32
|
||||
|
||||
const (
|
||||
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SESSION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
|
||||
type wtFwpmSession0 struct {
|
||||
sessionKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSessionFlagsValue // Windows type UINT32
|
||||
txnWaitTimeoutInMSec uint32
|
||||
processId uint32 // Windows type: DWORD
|
||||
sid *windows.SID
|
||||
username *uint16 // Windows type: *wchar_t
|
||||
kernelMode uint8 // Windows type: BOOL
|
||||
}
|
||||
|
||||
type wtFwpmSublayerFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SUBLAYER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
|
||||
type wtFwpmSublayer0 struct {
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSublayerFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
weight uint16
|
||||
}
|
||||
|
||||
// Defined in rpcdce.h
|
||||
type wtRpcCAuthN uint32
|
||||
|
||||
const (
|
||||
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
|
||||
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
|
||||
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
|
||||
)
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
|
||||
type wtFwpmProvider0 struct {
|
||||
providerKey windows.GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16
|
||||
}
|
||||
|
||||
type wtIPProto uint32
|
||||
|
||||
const (
|
||||
cIPPROTO_ICMP wtIPProto = 1
|
||||
cIPPROTO_ICMPV6 wtIPProto = 58
|
||||
cIPPROTO_TCP wtIPProto = 6
|
||||
cIPPROTO_UDP wtIPProto = 17
|
||||
)
|
||||
|
||||
const (
|
||||
cFWP_ACTRL_MATCH_FILTER = 1
|
||||
)
|
||||
92
client/internal/dns/dnsfw/types_windows_32.go
Normal file
92
client/internal/dns/dnsfw/types_windows_32.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build windows && (386 || arm)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 8
|
||||
wtFwpByteBlob_data_Offset = 4
|
||||
|
||||
wtFwpConditionValue0_Size = 8
|
||||
wtFwpConditionValue0_uint8_Offset = 4
|
||||
|
||||
wtFwpmDisplayData0_Size = 8
|
||||
wtFwpmDisplayData0_description_Offset = 4
|
||||
|
||||
wtFwpmFilter0_Size = 152
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 24
|
||||
wtFwpmFilter0_providerKey_Offset = 28
|
||||
wtFwpmFilter0_providerData_Offset = 32
|
||||
wtFwpmFilter0_layerKey_Offset = 40
|
||||
wtFwpmFilter0_subLayerKey_Offset = 56
|
||||
wtFwpmFilter0_weight_Offset = 72
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 80
|
||||
wtFwpmFilter0_filterCondition_Offset = 84
|
||||
wtFwpmFilter0_action_Offset = 88
|
||||
wtFwpmFilter0_providerContextKey_Offset = 112
|
||||
wtFwpmFilter0_reserved_Offset = 128
|
||||
wtFwpmFilter0_filterID_Offset = 136
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 144
|
||||
|
||||
wtFwpmFilterCondition0_Size = 28
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 20
|
||||
|
||||
wtFwpmSession0_Size = 48
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 24
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
|
||||
wtFwpmSession0_processId_Offset = 32
|
||||
wtFwpmSession0_sid_Offset = 36
|
||||
wtFwpmSession0_username_Offset = 40
|
||||
wtFwpmSession0_kernelMode_Offset = 44
|
||||
|
||||
wtFwpmSublayer0_Size = 44
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 24
|
||||
wtFwpmSublayer0_providerKey_Offset = 28
|
||||
wtFwpmSublayer0_providerData_Offset = 32
|
||||
wtFwpmSublayer0_weight_Offset = 40
|
||||
|
||||
wtFwpProvider0_Size = 40
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 24
|
||||
wtFwpProvider0_providerData_Offset = 28
|
||||
wtFwpProvider0_serviceName_Offset = 36
|
||||
|
||||
wtFwpTokenInformation_Size = 16
|
||||
|
||||
wtFwpValue0_Size = 8
|
||||
wtFwpValue0_value_Offset = 4
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
offset2 [4]byte // Layout correction field
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
89
client/internal/dns/dnsfw/types_windows_64.go
Normal file
89
client/internal/dns/dnsfw/types_windows_64.go
Normal file
@@ -0,0 +1,89 @@
|
||||
//go:build windows && (amd64 || arm64)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 16
|
||||
wtFwpByteBlob_data_Offset = 8
|
||||
|
||||
wtFwpConditionValue0_Size = 16
|
||||
wtFwpConditionValue0_uint8_Offset = 8
|
||||
|
||||
wtFwpmDisplayData0_Size = 16
|
||||
wtFwpmDisplayData0_description_Offset = 8
|
||||
|
||||
wtFwpmFilter0_Size = 200
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 32
|
||||
wtFwpmFilter0_providerKey_Offset = 40
|
||||
wtFwpmFilter0_providerData_Offset = 48
|
||||
wtFwpmFilter0_layerKey_Offset = 64
|
||||
wtFwpmFilter0_subLayerKey_Offset = 80
|
||||
wtFwpmFilter0_weight_Offset = 96
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 112
|
||||
wtFwpmFilter0_filterCondition_Offset = 120
|
||||
wtFwpmFilter0_action_Offset = 128
|
||||
wtFwpmFilter0_providerContextKey_Offset = 152
|
||||
wtFwpmFilter0_reserved_Offset = 168
|
||||
wtFwpmFilter0_filterID_Offset = 176
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 184
|
||||
|
||||
wtFwpmFilterCondition0_Size = 40
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 24
|
||||
|
||||
wtFwpmSession0_Size = 72
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 32
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
|
||||
wtFwpmSession0_processId_Offset = 40
|
||||
wtFwpmSession0_sid_Offset = 48
|
||||
wtFwpmSession0_username_Offset = 56
|
||||
wtFwpmSession0_kernelMode_Offset = 64
|
||||
|
||||
wtFwpmSublayer0_Size = 72
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 32
|
||||
wtFwpmSublayer0_providerKey_Offset = 40
|
||||
wtFwpmSublayer0_providerData_Offset = 48
|
||||
wtFwpmSublayer0_weight_Offset = 64
|
||||
|
||||
wtFwpProvider0_Size = 64
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 32
|
||||
wtFwpProvider0_providerData_Offset = 40
|
||||
wtFwpProvider0_serviceName_Offset = 56
|
||||
|
||||
wtFwpValue0_Size = 16
|
||||
wtFwpValue0_value_Offset = 8
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags // Windows type: UINT32
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
130
client/internal/dns/dnsfw/zsyscall_windows.go
Normal file
130
client/internal/dns/dnsfw/zsyscall_windows.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
||||
|
||||
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
|
||||
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
|
||||
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
|
||||
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
|
||||
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
|
||||
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
|
||||
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
|
||||
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
|
||||
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
|
||||
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
|
||||
)
|
||||
|
||||
func fwpmEngineClose0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFreeMemory0(p unsafe.Pointer) {
|
||||
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
@@ -71,6 +72,7 @@ type registryConfigurator struct {
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
dnsFirewall dnsfw.Manager
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
@@ -90,8 +92,9 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
}
|
||||
|
||||
configurator := ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
if err := configurator.configureInterface(); err != nil {
|
||||
@@ -169,16 +172,8 @@ func (r *registryConfigurator) disableWINSForInterface() error {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if config.RouteAll {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
if err := r.applyRouteAll(config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
@@ -220,6 +215,35 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
|
||||
if config.RouteAll {
|
||||
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
|
||||
return fmt.Errorf("dns firewall: %w", err)
|
||||
}
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
|
||||
if dErr := r.dnsFirewall.Disable(); dErr != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
if !r.routingAll {
|
||||
return nil
|
||||
}
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
@@ -406,6 +430,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
@@ -34,8 +36,9 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
@@ -134,8 +137,9 @@ func TestNRPTDomainBatching(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -413,40 +413,25 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||
var routeSelection []RoutesSelectionInfo
|
||||
for _, r := range routes {
|
||||
// resolvedDomains is keyed by the resolved domain (e.g. api.ipify.org),
|
||||
// not the configured pattern (e.g. *.ipify.org). Group entries whose
|
||||
// ParentDomain belongs to this route, mirroring the daemon logic in
|
||||
// client/server/network.go.
|
||||
domainList := make([]DomainInfo, 0, len(r.Domains))
|
||||
domainIndex := make(map[domain.Domain]int, len(r.Domains))
|
||||
domainList := make([]DomainInfo, 0)
|
||||
for _, d := range r.Domains {
|
||||
domainIndex[d] = len(domainList)
|
||||
domainList = append(domainList, DomainInfo{Domain: d.SafeString()})
|
||||
}
|
||||
|
||||
for _, info := range resolvedDomains {
|
||||
idx, ok := domainIndex[info.ParentDomain]
|
||||
if !ok {
|
||||
continue
|
||||
domainResp := DomainInfo{
|
||||
Domain: d.SafeString(),
|
||||
}
|
||||
for _, prefix := range info.Prefixes {
|
||||
domainList[idx].AddResolvedIP(prefix.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
if info, exists := resolvedDomains[d]; exists {
|
||||
var ipStrings []string
|
||||
for _, prefix := range info.Prefixes {
|
||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||
}
|
||||
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
||||
}
|
||||
domainList = append(domainList, domainResp)
|
||||
}
|
||||
domainDetails := DomainDetails{items: domainList}
|
||||
|
||||
// For dynamic (DNS) routes, expose the joined domain pattern as the
|
||||
// Network value so it matches the peer.routes entries on the Swift
|
||||
// side (mirroring the Android bridge in client/android/client.go).
|
||||
netStr := r.Network.String()
|
||||
if len(r.Domains) > 0 {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
routeSelection = append(routeSelection, RoutesSelectionInfo{
|
||||
ID: r.NetID,
|
||||
Network: netStr,
|
||||
Network: r.Network.String(),
|
||||
Domains: &domainDetails,
|
||||
Selected: r.Selected,
|
||||
})
|
||||
|
||||
@@ -34,34 +34,7 @@ type DomainDetails struct {
|
||||
|
||||
type DomainInfo struct {
|
||||
Domain string
|
||||
resolvedIPs ResolvedIPs
|
||||
}
|
||||
|
||||
func (d *DomainInfo) AddResolvedIP(ipAddress string) {
|
||||
d.resolvedIPs.Add(ipAddress)
|
||||
}
|
||||
|
||||
func (d *DomainInfo) GetResolvedIPs() *ResolvedIPs {
|
||||
return &d.resolvedIPs
|
||||
}
|
||||
|
||||
type ResolvedIPs struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||
r.items = append(r.items, ipAddress)
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Get(i int) string {
|
||||
if i < 0 || i >= len(r.items) {
|
||||
return ""
|
||||
}
|
||||
return r.items[i]
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Size() int {
|
||||
return len(r.items)
|
||||
ResolvedIPs string
|
||||
}
|
||||
|
||||
// Add new PeerInfo to the collection
|
||||
|
||||
@@ -89,33 +89,21 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector in Dex storage.
|
||||
// It overlays user-mutable config fields (issuer, clientID, clientSecret,
|
||||
// redirectURI) onto the stored connector config, and updates the connector name
|
||||
// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so
|
||||
// partial updates preserve create-time defaults such as scopes, claimMapping,
|
||||
// and userIDKey.
|
||||
// It merges incoming updates with existing values to prevent data loss on partial updates.
|
||||
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error {
|
||||
if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) {
|
||||
if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) {
|
||||
return storage.Connector{}, errors.New("connector type change not allowed")
|
||||
}
|
||||
|
||||
configData, err := overlayConnectorConfig(old.Config, cfg)
|
||||
oldCfg, err := p.parseStorageConnector(old)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err)
|
||||
return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err)
|
||||
}
|
||||
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = old.Name
|
||||
}
|
||||
mergeConnectorConfig(cfg, oldCfg)
|
||||
|
||||
return storage.Connector{
|
||||
ID: cfg.ID,
|
||||
Type: old.Type,
|
||||
Name: name,
|
||||
Config: configData,
|
||||
}, nil
|
||||
storageConn, err := p.buildStorageConnector(cfg)
|
||||
if err != nil {
|
||||
return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err)
|
||||
}
|
||||
return storageConn, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update connector: %w", err)
|
||||
}
|
||||
@@ -124,27 +112,23 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// overlayConnectorConfig writes only the user-mutable fields onto the existing
|
||||
// stored config, preserving every other field (scopes, claimMapping, userIDKey,
|
||||
// insecure flags, etc.). Empty fields on cfg leave the existing value alone.
|
||||
func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) {
|
||||
var m map[string]any
|
||||
if err := decodeConnectorConfig(oldConfig, &m); err != nil {
|
||||
return nil, err
|
||||
// mergeConnectorConfig preserves existing values for empty fields in the update.
|
||||
func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) {
|
||||
if cfg.ClientSecret == "" {
|
||||
cfg.ClientSecret = oldCfg.ClientSecret
|
||||
}
|
||||
if cfg.Issuer != "" {
|
||||
m["issuer"] = cfg.Issuer
|
||||
if cfg.RedirectURI == "" {
|
||||
cfg.RedirectURI = oldCfg.RedirectURI
|
||||
}
|
||||
if cfg.ClientID != "" {
|
||||
m["clientID"] = cfg.ClientID
|
||||
if cfg.Issuer == "" && cfg.Type == oldCfg.Type {
|
||||
cfg.Issuer = oldCfg.Issuer
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
m["clientSecret"] = cfg.ClientSecret
|
||||
if cfg.ClientID == "" {
|
||||
cfg.ClientID = oldCfg.ClientID
|
||||
}
|
||||
if cfg.RedirectURI != "" {
|
||||
m["redirectURI"] = cfg.RedirectURI
|
||||
if cfg.Name == "" {
|
||||
cfg.Name = oldCfg.Name
|
||||
}
|
||||
return encodeConnectorConfig(m)
|
||||
}
|
||||
|
||||
// DeleteConnector removes a connector from Dex storage.
|
||||
@@ -232,10 +216,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["getUserInfo"] = true
|
||||
case "entra":
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
// Use the Entra Object ID (oid) instead of the default OIDC sub claim.
|
||||
// Entra issues sub as a per-app pairwise identifier that does not match
|
||||
// the stable Object ID.
|
||||
oidcConfig["userIDKey"] = "oid"
|
||||
case "okta":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestProvider(t *testing.T) (*Provider, func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "dex-connector-test-*")
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &Provider{storage: s, logger: logger}, func() {
|
||||
_ = s.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) {
|
||||
cfg := &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback")
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &m))
|
||||
|
||||
assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid")
|
||||
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"])
|
||||
}
|
||||
|
||||
func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) {
|
||||
// ensures the Entra userIDKey override does not leak into other OIDC providers,
|
||||
// which already use a stable sub claim.
|
||||
for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} {
|
||||
t.Run(typ, func(t *testing.T) {
|
||||
data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &m))
|
||||
_, ok := m["userIDKey"]
|
||||
assert.False(t, ok, "%s connectors must not have userIDKey set", typ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
created, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "old-secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "entra-test", created.ID)
|
||||
|
||||
// Rotate only the client secret.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "entra",
|
||||
ClientSecret: "new-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
|
||||
assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated")
|
||||
assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)")
|
||||
assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"])
|
||||
assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update")
|
||||
assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update")
|
||||
}
|
||||
|
||||
func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
// Seed a connector directly into storage without userIDKey
|
||||
preFixConfig, err := json.Marshal(map[string]any{
|
||||
"issuer": "https://login.microsoftonline.com/tid/v2.0",
|
||||
"clientID": "client-id",
|
||||
"clientSecret": "old-secret",
|
||||
"redirectURI": "https://example.com/oauth2/callback",
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"claimMapping": map[string]string{"email": "preferred_username"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{
|
||||
ID: "entra-prefix",
|
||||
Type: "oidc",
|
||||
Name: "Entra",
|
||||
Config: preFixConfig,
|
||||
}))
|
||||
|
||||
// Rotate client secret via UpdateConnector.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-prefix",
|
||||
Type: "entra",
|
||||
ClientSecret: "new-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-prefix")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
|
||||
assert.Equal(t, "new-secret", m["clientSecret"])
|
||||
_, has := m["userIDKey"]
|
||||
assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before")
|
||||
}
|
||||
|
||||
func TestUpdateConnector_RejectsTypeChange(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/tid/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to switch the connector to okta.
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "okta",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "connector type change not allowed")
|
||||
|
||||
// stored connector type/config unchanged after the rejected update.
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc", conn.Type)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
assert.Equal(t, "oid", m["userIDKey"])
|
||||
}
|
||||
|
||||
func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
p, cleanup := newTestProvider(t)
|
||||
defer cleanup()
|
||||
|
||||
_, err := p.CreateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Name: "Entra",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/old/v2.0",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
RedirectURI: "https://example.com/oauth2/callback",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = p.UpdateConnector(ctx, &ConnectorConfig{
|
||||
ID: "entra-test",
|
||||
Type: "entra",
|
||||
Issuer: "https://login.microsoftonline.com/new/v2.0",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := p.storage.GetConnector(ctx, "entra-test")
|
||||
require.NoError(t, err)
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(conn.Config, &m))
|
||||
assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"])
|
||||
}
|
||||
@@ -31,7 +31,6 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -72,8 +71,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
var ret []*domain.Domain
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
@@ -127,8 +126,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters for this account
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -274,7 +273,7 @@ func (m Manager) GetClusterDomains() []string {
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -299,17 +298,6 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
|
||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||
}
|
||||
|
||||
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
|
||||
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
bestCluster := ""
|
||||
bestLen := -1
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveClusterAddressesFunc != nil {
|
||||
return m.getActiveClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveClusterAddressesForAccountFunc != nil {
|
||||
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
@@ -11,19 +11,15 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||
Heartbeat(ctx context.Context, p *Proxy) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
|
||||
@@ -16,16 +16,11 @@ type store interface {
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
@@ -49,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -60,10 +55,9 @@ func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddres
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
AccountID: accountID,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: proxy.StatusConnected,
|
||||
Status: "connected",
|
||||
Capabilities: caps,
|
||||
}
|
||||
|
||||
@@ -83,7 +77,7 @@ func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddres
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
@@ -98,7 +92,7 @@ func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) err
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
@@ -110,7 +104,7 @@ func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
@@ -119,6 +113,16 @@ func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, erro
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
clusters, err := m.store.GetActiveProxyClusters(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
@@ -138,44 +142,10 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !conflicting, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,337 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.saveProxyFunc != nil {
|
||||
return m.saveProxyFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
if m.disconnectProxyFunc != nil {
|
||||
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.updateProxyHeartbeatFunc != nil {
|
||||
return m.updateProxyHeartbeatFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
if m.cleanupStaleProxiesFunc != nil {
|
||||
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
if m.getProxyByAccountIDFunc != nil {
|
||||
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
if m.countProxiesByAccountIDFunc != nil {
|
||||
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
if m.isClusterAddressConflictingFunc != nil {
|
||||
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if m.deleteAccountClusterFunc != nil {
|
||||
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
m, err := NewManager(s, meter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestConnect_WithAccountID(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||
assert.Equal(t, "session-1", savedProxy.SessionID)
|
||||
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||
}
|
||||
|
||||
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Nil(t, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
}
|
||||
|
||||
func TestConnect_StoreError(t *testing.T) {
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conflicting bool
|
||||
storeErr error
|
||||
wantResult bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available - no conflict",
|
||||
conflicting: false,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "not available - conflict exists",
|
||||
conflicting: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||
return tt.conflicting, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountAccountProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
storeErr error
|
||||
wantCount int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no proxies",
|
||||
count: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "one proxy",
|
||||
count: 1,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||
return tt.count, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCount, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxy(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
expected := &proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byop.example.com",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||
assert.Equal(t, accountID, accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, p)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteAccountCluster(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var deletedCluster, deletedAccount string
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
|
||||
deletedCluster = clusterAddress
|
||||
deletedAccount = accountID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "cluster.example.com", deletedCluster)
|
||||
assert.Equal(t, "acc-123", deletedAccount)
|
||||
})
|
||||
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||
expected := []string{"byop.example.com"}
|
||||
s := &mockStore{
|
||||
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -136,17 +136,19 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
|
||||
ret0, _ := ret[0].([]Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
@@ -163,65 +165,6 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// GetAccountProxy mocks base method.
|
||||
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountProxy indicates an expected call of GetAccountProxy.
|
||||
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
|
||||
}
|
||||
|
||||
// CountAccountProxies mocks base method.
|
||||
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAccountProxies indicates an expected call of CountAccountProxies.
|
||||
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable mocks base method.
|
||||
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
|
||||
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusConnected = "connected"
|
||||
StatusDisconnected = "disconnected"
|
||||
)
|
||||
import "time"
|
||||
|
||||
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
||||
// Nil fields mean the proxy never reported this capability.
|
||||
@@ -28,7 +21,6 @@ type Proxy struct {
|
||||
SessionID string `gorm:"type:varchar(36)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
@@ -44,8 +36,6 @@ func (Proxy) TableName() string {
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
}
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{store: s, permissionsManager: permissionsManager}
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.ProxyTokenRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" || len(req.Name) > 255 {
|
||||
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
var expiresIn time.Duration
|
||||
if req.ExpiresIn != nil {
|
||||
if *req.ExpiresIn < 0 {
|
||||
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
if *req.ExpiresIn > 0 {
|
||||
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
accountID := userAuth.AccountId
|
||||
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toProxyTokenCreatedResponse(generated)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.ProxyToken, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
resp = append(resp, toProxyTokenResponse(token))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := mux.Vars(r)["tokenId"]
|
||||
if tokenID == "" {
|
||||
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
} else {
|
||||
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||
resp := api.ProxyToken{
|
||||
Id: token.ID,
|
||||
Name: token.Name,
|
||||
Revoked: token.Revoked,
|
||||
}
|
||||
if !token.CreatedAt.IsZero() {
|
||||
resp.CreatedAt = token.CreatedAt
|
||||
}
|
||||
if token.ExpiresAt != nil {
|
||||
resp.ExpiresAt = token.ExpiresAt
|
||||
}
|
||||
if token.LastUsed != nil {
|
||||
resp.LastUsed = token.LastUsed
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
|
||||
base := toProxyTokenResponse(&generated.ProxyAccessToken)
|
||||
plainToken := string(generated.PlainToken)
|
||||
return api.ProxyTokenCreated{
|
||||
Id: base.Id,
|
||||
Name: base.Name,
|
||||
CreatedAt: base.CreatedAt,
|
||||
ExpiresAt: base.ExpiresAt,
|
||||
LastUsed: base.LastUsed,
|
||||
Revoked: base.Revoked,
|
||||
PlainToken: plainToken,
|
||||
}
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func authContext(accountID, userID string) context.Context {
|
||||
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateToken_AccountScoped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "my-token"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp api.ProxyTokenCreated
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
assert.NotEmpty(t, resp.PlainToken)
|
||||
assert.Equal(t, "my-token", resp.Name)
|
||||
assert.False(t, resp.Revoked)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.AccountID)
|
||||
assert.Equal(t, accountID, *savedToken.AccountID)
|
||||
assert.Equal(t, "user-1", savedToken.CreatedBy)
|
||||
}
|
||||
|
||||
func TestCreateToken_WithExpiration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "expiring-token", "expires_in": 3600}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.ExpiresAt)
|
||||
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestCreateToken_EmptyName(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": ""}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "test"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestListTokens(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
now := time.Now()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
|
||||
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
|
||||
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listTokens(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.ProxyToken
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Len(t, resp, 2)
|
||||
assert.Equal(t, "tok-1", resp[0].Id)
|
||||
assert.False(t, resp[0].Revoked)
|
||||
assert.Equal(t, "tok-2", resp[1].Id)
|
||||
assert.True(t, resp[1].Revoked)
|
||||
}
|
||||
|
||||
func TestRevokeToken_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
Name: "test-token",
|
||||
AccountID: &accountID,
|
||||
}, nil)
|
||||
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
otherAccount := "acc-other"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: &otherAccount,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: nil,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
@@ -29,5 +28,4 @@ type Manager interface {
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
|
||||
}
|
||||
|
||||
@@ -79,20 +79,6 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -152,21 +138,6 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -35,7 +35,6 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||
|
||||
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||
@@ -196,33 +195,10 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
})
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiClusters)
|
||||
}
|
||||
|
||||
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
clusterAddress := mux.Vars(r)["clusterAddress"]
|
||||
if clusterAddress == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
@@ -121,21 +121,7 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
// owned by the account.
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
|
||||
return m.store.GetActiveProxyClusters(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
@@ -999,10 +985,6 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
|
||||
@@ -433,7 +433,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -712,7 +712,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1135,7 +1135,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -9,11 +9,8 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -51,11 +48,6 @@ type ProxyOIDCConfig struct {
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// ProxyTokenChecker checks whether a proxy access token is still valid.
|
||||
type ProxyTokenChecker interface {
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -84,56 +76,22 @@ type ProxyServiceServer struct {
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
// Checker for proxy access token validity
|
||||
tokenChecker ProxyTokenChecker
|
||||
|
||||
// OIDC configuration for proxy authentication
|
||||
oidcConfig ProxyOIDCConfig
|
||||
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
|
||||
// tokenTTL is the lifetime of one-time tokens generated for proxy
|
||||
// authentication. Defaults to defaultProxyTokenTTL when zero.
|
||||
tokenTTL time.Duration
|
||||
|
||||
// snapshotBatchSize is the number of mappings per gRPC message during
|
||||
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
|
||||
snapshotBatchSize int
|
||||
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
const defaultProxyTokenTTL = 5 * time.Minute
|
||||
|
||||
const defaultSnapshotBatchSize = 500
|
||||
|
||||
func snapshotBatchSizeFromEnv() int {
|
||||
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultSnapshotBatchSize
|
||||
}
|
||||
|
||||
// proxyTokenTTL returns the configured token TTL or the default when unset.
|
||||
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
|
||||
if s.tokenTTL > 0 {
|
||||
return s.tokenTTL
|
||||
}
|
||||
return defaultProxyTokenTTL
|
||||
}
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
sessionID string
|
||||
address string
|
||||
accountID *string
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
@@ -141,19 +99,8 @@ type proxyConnection struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token == nil || token.AccountID == nil {
|
||||
return nil
|
||||
}
|
||||
if requestAccountID == "" || *token.AccountID != requestAccountID {
|
||||
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -163,8 +110,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
@@ -223,25 +168,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
@@ -259,8 +185,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
@@ -268,6 +192,12 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
@@ -276,43 +206,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.CompareAndDelete(proxyID, conn)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
@@ -333,7 +241,14 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
@@ -345,9 +260,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -357,19 +271,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnectio
|
||||
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
}
|
||||
|
||||
if conn.tokenID != "" && s.tokenChecker != nil {
|
||||
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
|
||||
continue
|
||||
}
|
||||
if !valid {
|
||||
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||
return
|
||||
@@ -377,6 +278,8 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnectio
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only entries matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
@@ -387,40 +290,29 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return err
|
||||
}
|
||||
|
||||
// Send mappings in batches to reduce per-message gRPC overhead while
|
||||
// staying well within the default 4 MB message size limit.
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, m := range mappings {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{m},
|
||||
InitialSyncComplete: i == len(mappings)-1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||
var services []*rpservice.Service
|
||||
var err error
|
||||
if conn.accountID != nil {
|
||||
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
|
||||
} else {
|
||||
services, err = s.serviceManager.GetGlobalServices(ctx)
|
||||
}
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -431,9 +323,13 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
log.WithFields(log.Fields{
|
||||
"service": service.Name,
|
||||
"account": service.AccountID,
|
||||
}).WithError(err).Error("failed to generate auth token for snapshot")
|
||||
continue
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
@@ -445,14 +341,8 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// isProxyAddressValid validates a proxy address (domain name or IP address)
|
||||
// isProxyAddressValid validates a proxy address
|
||||
func isProxyAddressValid(addr string) bool {
|
||||
if addr == "" {
|
||||
return false
|
||||
}
|
||||
if net.ParseIP(addr) != nil {
|
||||
return true
|
||||
}
|
||||
_, err := domain.ValidateDomains([]string{addr})
|
||||
return err == nil
|
||||
}
|
||||
@@ -476,10 +366,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
|
||||
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fields := log.Fields{
|
||||
"service_id": accessLog.GetServiceId(),
|
||||
"account_id": accessLog.GetAccountId(),
|
||||
@@ -517,68 +403,24 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
||||
// Management should call this when services are created/updated/removed.
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
// BYOP proxies only receive updates for their own account's services.
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
updateAccountIDs := make(map[string]struct{})
|
||||
for _, m := range update.Mapping {
|
||||
if m.AccountId != "" {
|
||||
updateAccountIDs[m.AccountId] = struct{}{}
|
||||
}
|
||||
}
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
connUpdate := update
|
||||
if conn.accountID != nil && len(updateAccountIDs) > 0 {
|
||||
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
|
||||
return true
|
||||
}
|
||||
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
|
||||
if len(filtered) == 0 {
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- resp:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
|
||||
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
conn.cancel()
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
|
||||
}
|
||||
}
|
||||
|
||||
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
|
||||
var filtered []*proto.ProxyMapping
|
||||
for _, m := range mappings {
|
||||
if m.AccountId == accountID {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// GetConnectedProxies returns a list of connected proxy IDs
|
||||
func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
||||
var proxies []string
|
||||
@@ -647,25 +489,19 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
|
||||
continue
|
||||
}
|
||||
if !proxyAcceptsMapping(conn, update) {
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||
continue
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -691,8 +527,7 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
// Returns nil if token generation fails; the caller must disconnect the
|
||||
// proxy so it can resync via a fresh snapshot on reconnect.
|
||||
// Returns nil if token generation fails (the proxy should be skipped).
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, mapping := range update.Mapping {
|
||||
@@ -701,7 +536,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
@@ -737,10 +572,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
@@ -860,10 +691,6 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := req.GetAccountId()
|
||||
serviceID := req.GetServiceId()
|
||||
protoStatus := req.GetStatus()
|
||||
@@ -934,10 +761,6 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
|
||||
// CreateProxyPeer handles proxy peer creation with one-time token authentication
|
||||
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceID := req.GetServiceId()
|
||||
accountID := req.GetAccountId()
|
||||
token := req.GetToken()
|
||||
@@ -992,10 +815,6 @@ func strPtr(s string) *string {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
@@ -1124,9 +943,21 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *rpservice.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
break
|
||||
}
|
||||
}
|
||||
if service == nil {
|
||||
return "", fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
if service.SessionPrivateKey == "" {
|
||||
@@ -1224,10 +1055,6 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
@@ -1311,7 +1138,18 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if service.Domain == domain {
|
||||
return service, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsProxyAddressValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
valid bool
|
||||
}{
|
||||
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
|
||||
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
|
||||
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
|
||||
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
|
||||
{name: "valid IPv6", addr: "::1", valid: true},
|
||||
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
|
||||
{name: "empty string", addr: "", valid: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -153,6 +153,9 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||
|
||||
if !token.IsValid() {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||
}
|
||||
|
||||
@@ -53,10 +53,6 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -95,20 +91,6 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
|
||||
|
||||
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, services := range m.proxiesByAccount {
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// recordingStream captures all messages sent via Send so tests can inspect
|
||||
// batching behaviour without a real gRPC transport.
|
||||
type recordingStream struct {
|
||||
grpc.ServerStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
}
|
||||
|
||||
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
s.messages = append(s.messages, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *recordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *recordingStream) SendMsg(any) error { return nil }
|
||||
func (s *recordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// makeServices creates n enabled services assigned to the given cluster.
|
||||
func makeServices(n int, cluster string) []*rpservice.Service {
|
||||
services := make([]*rpservice.Service, n)
|
||||
for i := range n {
|
||||
services[i] = &rpservice.Service{
|
||||
ID: fmt.Sprintf("svc-%d", i),
|
||||
AccountID: "acct-1",
|
||||
Name: fmt.Sprintf("svc-%d", i),
|
||||
Domain: fmt.Sprintf("svc-%d.example.com", i),
|
||||
ProxyCluster: cluster,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
|
||||
},
|
||||
}
|
||||
}
|
||||
return services
|
||||
}
|
||||
|
||||
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
|
||||
t.Helper()
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
|
||||
snapshotBatchSize: batchSize,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
return s
|
||||
}
|
||||
|
||||
func TestSendSnapshot_BatchesMappings(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshot(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect ceil(7/3) = 3 messages
|
||||
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[2].Mapping, 1)
|
||||
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
|
||||
|
||||
// Verify all service IDs are present exactly once
|
||||
seen := make(map[string]bool)
|
||||
for _, msg := range stream.messages {
|
||||
for _, m := range msg.Mapping {
|
||||
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
|
||||
seen[m.Id] = true
|
||||
}
|
||||
}
|
||||
assert.Len(t, seen, totalServices)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 6 // exactly 2 batches
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 2)
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.True(t, stream.messages[1].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_SingleBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
|
||||
assert.Len(t, stream.messages[0].Mapping, totalServices)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
@@ -12,12 +12,9 @@ import (
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -88,14 +85,11 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
||||
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
capabilities: caps,
|
||||
sendChan: ch,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
@@ -319,58 +313,6 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func scopedCtx(accountID string) context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-1",
|
||||
AccountID: &accountID,
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func globalCtx() context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-global",
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
|
||||
err := enforceAccountScope(globalCtx(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "acc-2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
|
||||
err := enforceAccountScope(context.Background(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
@@ -318,17 +318,13 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
|
||||
|
||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -344,10 +340,6 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -356,22 +348,6 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3074,7 +3074,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -145,9 +144,6 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
if proxyGRPCServer != nil {
|
||||
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
||||
|
||||
@@ -7,8 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -42,6 +45,11 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
|
||||
|
||||
// getAllCountries retrieves a list of all countries
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
@@ -63,6 +71,11 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
|
||||
|
||||
// getCitiesByCountry retrieves a list of cities based on the given country code
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
countryCode := vars["country"]
|
||||
if !countryCodeRegex.MatchString(countryCode) {
|
||||
@@ -89,6 +102,27 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
util.WriteJSONObject(r.Context(), w, cities)
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toCountryResponse(country geolocation.Country) api.Country {
|
||||
return api.Country{
|
||||
CountryName: country.CountryName,
|
||||
|
||||
@@ -216,7 +216,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
@@ -390,10 +389,6 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -440,10 +435,6 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
|
||||
|
||||
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
@@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -4495,47 +4495,6 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var tokens []*types.ProxyAccessToken
|
||||
result := tx.Where("account_id = ?", accountID).Find(&tokens)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return token.IsValid(), nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var token types.ProxyAccessToken
|
||||
result := tx.Take(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.Model(&types.ProxyAccessToken{}).
|
||||
@@ -5486,7 +5445,7 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND session_id = ?", proxyID, sessionID).
|
||||
Updates(map[string]any{
|
||||
"status": proxy.StatusDisconnected,
|
||||
"status": "disconnected",
|
||||
"disconnected_at": now,
|
||||
"last_seen": now,
|
||||
})
|
||||
@@ -5517,7 +5476,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
|
||||
if result.RowsAffected == 0 {
|
||||
p.LastSeen = now
|
||||
p.ConnectedAt = &now
|
||||
p.Status = proxy.StatusConnected
|
||||
p.Status = "connected"
|
||||
if err := s.db.Create(p).Error; err != nil {
|
||||
log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err)
|
||||
}
|
||||
@@ -5526,15 +5485,13 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses returns the unique cluster addresses of active
|
||||
// shared proxies (those without an account scope). BYOP cluster addresses are
|
||||
// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them.
|
||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
@@ -5546,75 +5503,13 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
var p proxy.Proxy
|
||||
result := s.db.Where("account_id = ?", accountID).Take(&p)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy not found for account")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
var count int64
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
result := s.db.
|
||||
Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID).
|
||||
Delete(&proxy.Proxy{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "delete account cluster: %v", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "cluster not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
var clusters []proxy.Cluster
|
||||
|
||||
result := s.db.Model(&proxy.Proxy{}).
|
||||
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
|
||||
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
|
||||
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
|
||||
Select("cluster_address as address, COUNT(*) as connected_proxies").
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Group("cluster_address").
|
||||
Scan(&clusters)
|
||||
|
||||
|
||||
@@ -114,9 +114,6 @@ type Store interface {
|
||||
|
||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
||||
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
|
||||
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
||||
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||
@@ -290,16 +287,11 @@ type Store interface {
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
|
||||
@@ -503,9 +495,6 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -165,6 +165,20 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -223,21 +237,6 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
|
||||
}
|
||||
|
||||
// CountProxiesByAccountID mocks base method.
|
||||
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// CreateAccessLog mocks base method.
|
||||
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -576,20 +575,6 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteRoute mocks base method.
|
||||
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1316,34 +1301,19 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddressesForAccount mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
@@ -1419,20 +1389,6 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2002,51 +1958,6 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByID mocks base method.
|
||||
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
|
||||
ret0, _ := ret[0].(*types2.ProxyAccessToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
|
||||
}
|
||||
|
||||
// GetProxyAccessTokensByAccountID mocks base method.
|
||||
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetProxyByAccountID mocks base method.
|
||||
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
|
||||
ret0, _ := ret[0].(*proxy.Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2479,21 +2390,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
|
||||
}
|
||||
|
||||
// IsClusterAddressConflicting mocks base method.
|
||||
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
|
||||
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// IsPrimaryAccount mocks base method.
|
||||
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2510,21 +2406,6 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsProxyAccessTokenValid mocks base method.
|
||||
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
|
||||
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
|
||||
}
|
||||
|
||||
// ListCustomDomains mocks base method.
|
||||
func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -144,11 +144,8 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
replacement := "X"
|
||||
if plainToken[5] == 'X' {
|
||||
replacement = "Y"
|
||||
}
|
||||
modifiedToken := plainToken[:5] + replacement + plainToken[6:]
|
||||
// Modify one character in the secret part
|
||||
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
|
||||
err = ValidateInviteToken(modifiedToken)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -1,409 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type byopTestSetup struct {
|
||||
store store.Store
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
grpcServer *grpc.Server
|
||||
grpcAddr string
|
||||
cleanup func()
|
||||
|
||||
accountA string
|
||||
accountB string
|
||||
accountAToken types.PlainProxyToken
|
||||
accountBToken types.PlainProxyToken
|
||||
accountACluster string
|
||||
accountBCluster string
|
||||
}
|
||||
|
||||
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountAID := "byop-account-a"
|
||||
accountBID := "byop-account-b"
|
||||
|
||||
for _, acc := range []*types.Account{
|
||||
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||
} {
|
||||
require.NoError(t, testStore.SaveAccount(ctx, acc))
|
||||
}
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
clusterA := "byop-a.proxy.test"
|
||||
clusterB := "byop-b.proxy.test"
|
||||
|
||||
services := []*service.Service{
|
||||
{
|
||||
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
|
||||
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
{
|
||||
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
|
||||
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
{
|
||||
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
|
||||
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
}
|
||||
for _, svc := range services {
|
||||
require.NoError(t, testStore.CreateService(ctx, svc))
|
||||
}
|
||||
|
||||
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
|
||||
|
||||
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
|
||||
|
||||
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
|
||||
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
realProxyManager, err := proxymanager.NewManager(testStore, meter)
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||
Issuer: "https://fake-issuer.example.com",
|
||||
ClientID: "test-client",
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
}
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
pkceStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
realProxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
||||
proxyService.SetServiceManager(svcMgr)
|
||||
|
||||
proxyController := &testProxyController{}
|
||||
proxyService.SetProxyController(proxyController)
|
||||
|
||||
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
|
||||
proto.RegisterProxyServiceServer(grpcServer, proxyService)
|
||||
|
||||
go func() {
|
||||
if err := grpcServer.Serve(lis); err != nil {
|
||||
t.Logf("gRPC server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &byopTestSetup{
|
||||
store: testStore,
|
||||
proxyService: proxyService,
|
||||
grpcServer: grpcServer,
|
||||
grpcAddr: lis.Addr().String(),
|
||||
cleanup: func() {
|
||||
grpcServer.GracefulStop()
|
||||
authClose()
|
||||
storeCleanup()
|
||||
},
|
||||
accountA: accountAID,
|
||||
accountB: accountBID,
|
||||
accountAToken: tokenA.PlainToken,
|
||||
accountBToken: tokenB.PlainToken,
|
||||
accountACluster: clusterA,
|
||||
accountBCluster: clusterB,
|
||||
}
|
||||
}
|
||||
|
||||
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
|
||||
md := metadata.Pairs("authorization", "Bearer "+string(token))
|
||||
return metadata.NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
t.Helper()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings := receiveBYOPMappings(t, stream)
|
||||
|
||||
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
|
||||
for _, m := range mappings {
|
||||
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
|
||||
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
|
||||
}
|
||||
|
||||
ids := map[string]bool{}
|
||||
for _, m := range mappings {
|
||||
ids[m.GetId()] = true
|
||||
}
|
||||
assert.True(t, ids["svc-a1"], "should contain svc-a1")
|
||||
assert.True(t, ids["svc-a2"], "should contain svc-a2")
|
||||
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-b",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountBCluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings := receiveBYOPMappings(t, stream)
|
||||
|
||||
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
|
||||
assert.Equal(t, "svc-b1", mappings[0].GetId())
|
||||
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-first",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings1 := receiveBYOPMappings(t, stream1)
|
||||
assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services")
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-second",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings2 := receiveBYOPMappings(t, stream2)
|
||||
assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services")
|
||||
for _, m := range mappings2 {
|
||||
assert.Equal(t, setup.accountA, m.GetAccountId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-cluster",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = receiveBYOPMappings(t, stream1)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-b-conflict",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream2.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
|
||||
t.Logf("expected rejection: %s", st.Message())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
proxyID := "byop-proxy-reconnect"
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveBYOPMappings(t, stream1)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveBYOPMappings(t, stream2)
|
||||
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
|
||||
|
||||
firstIDs := map[string]bool{}
|
||||
for _, m := range firstMappings {
|
||||
firstIDs[m.GetId()] = true
|
||||
}
|
||||
for _, m := range secondMappings {
|
||||
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "no-auth-proxy",
|
||||
Version: "test-v1",
|
||||
Address: "some.cluster.io",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.Unauthenticated, st.Code())
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -141,7 +140,6 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
nil,
|
||||
usersManager,
|
||||
proxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
// Use store-backed service manager
|
||||
@@ -203,8 +201,8 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: nbproxy.StatusConnected}, nil
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
@@ -219,10 +217,6 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -243,22 +237,6 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
|
||||
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
|
||||
type testProxyController struct{}
|
||||
|
||||
@@ -312,10 +290,6 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -362,10 +336,6 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
|
||||
|
||||
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -394,16 +364,14 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings from the snapshot - server sends each mapping individually
|
||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||
for {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
for _, m := range msg.GetMapping() {
|
||||
mappingsByID[m.GetId()] = m
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive 2 mappings total
|
||||
@@ -443,14 +411,12 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
mappings := make([]*proto.ProxyMapping, 0)
|
||||
for {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive the 2 mappings matching the cluster
|
||||
@@ -474,15 +440,13 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-reconnect"
|
||||
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
// Helper to receive all mappings from a stream
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
||||
var mappings []*proto.ProxyMapping
|
||||
for {
|
||||
for i := 0; i < count; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
@@ -496,7 +460,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveMappings(stream1)
|
||||
firstMappings := receiveMappings(stream1, 2)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -512,7 +476,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveMappings(stream2)
|
||||
secondMappings := receiveMappings(stream2, 2)
|
||||
|
||||
// Should receive the same mappings
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
||||
@@ -578,14 +542,12 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to receive and apply all mappings
|
||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
||||
for {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
applyMappings(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -674,14 +636,12 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
count := 0
|
||||
for {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
count += len(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
@@ -721,12 +681,9 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
msg, err := stream1.Recv()
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := stream1.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
@@ -742,12 +699,9 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
msg, err := stream2.Recv()
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := stream2.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cancel1()
|
||||
|
||||
@@ -943,8 +943,6 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
operation := func() error {
|
||||
s.Logger.Debug("connecting to management mapping stream")
|
||||
|
||||
initialSyncDone = false
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
@@ -1002,11 +1000,6 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var snapshotIDs map[types.ServiceID]struct{}
|
||||
if !*initialSyncDone {
|
||||
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
@@ -1027,45 +1020,17 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone {
|
||||
for _, m := range msg.GetMapping() {
|
||||
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||
snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
s.portMu.RLock()
|
||||
var stale []*proto.ProxyMapping
|
||||
for svcID, mapping := range s.lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, mapping)
|
||||
}
|
||||
}
|
||||
s.portMu.RUnlock()
|
||||
|
||||
for _, mapping := range stale {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Info("Removing stale mapping absent from snapshot")
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
|
||||
// so we can verify it without triggering removeMapping (which requires full
|
||||
// server wiring). This keeps the test focused on the detection algorithm.
|
||||
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
|
||||
var stale []types.ServiceID
|
||||
for svcID := range lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, svcID)
|
||||
}
|
||||
}
|
||||
return stale
|
||||
}
|
||||
|
||||
// TestStaleDetection_PartialOverlap verifies that only services absent from
|
||||
// the snapshot are flagged as stale.
|
||||
func TestStaleDetection_PartialOverlap(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
"svc-stale-a": {Id: "svc-stale-a"},
|
||||
"svc-stale-b": {Id: "svc-stale-b"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
"svc-3": {}, // new service, not in local
|
||||
}
|
||||
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Len(t, stale, 2)
|
||||
staleSet := make(map[types.ServiceID]struct{})
|
||||
for _, id := range stale {
|
||||
staleSet[id] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
|
||||
}
|
||||
|
||||
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
|
||||
func TestStaleDetection_AllStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
|
||||
assert.Len(t, stale, 2)
|
||||
}
|
||||
|
||||
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
|
||||
func TestStaleDetection_NoneStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
|
||||
func TestStaleDetection_EmptyLocal(t *testing.T) {
|
||||
stale := collectStaleIDs(
|
||||
map[types.ServiceID]*proto.ProxyMapping{},
|
||||
map[types.ServiceID]struct{}{"svc-1": {}},
|
||||
)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
|
||||
// local mappings are present in the snapshot (removeMapping is never called).
|
||||
func TestReconcileSnapshot_NoStale(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
|
||||
|
||||
snapshotIDs := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
// This should not panic — no stale entries means removeMapping is never called.
|
||||
s.reconcileSnapshot(context.Background(), snapshotIDs)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
|
||||
// no local mappings.
|
||||
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
|
||||
assert.Empty(t, s.lastMappings)
|
||||
}
|
||||
|
||||
// --- handleMappingStream tests for batched snapshot ID accumulation ---
|
||||
|
||||
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
|
||||
// marked done only after the final InitialSyncComplete message, even when
|
||||
// the snapshot arrives in multiple batches.
|
||||
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // batch 1: no sync-complete
|
||||
{}, // batch 2: no sync-complete
|
||||
{InitialSyncComplete: true}, // batch 3: sync done
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync should be marked done after final batch")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
|
||||
// arriving after InitialSyncComplete do not trigger a second reconciliation.
|
||||
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// Simulate state left over from a previous sync.
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // post-sync empty message — must not reconcile
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := true // sync already completed in a previous stream
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2,
|
||||
"post-sync messages must not trigger reconciliation — all entries should survive")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
|
||||
// stream closes before sync completes, no reconciliation occurs.
|
||||
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{} // no messages → immediate EOF
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
|
||||
}
|
||||
|
||||
// mockErrRecvStream returns an error on the second Recv to verify
|
||||
// handleMappingStream returns without completing sync.
|
||||
type mockErrRecvStream struct {
|
||||
mockMappingStream
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &proto.GetMappingUpdateResponse{}, nil
|
||||
}
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, syncDone)
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
|
||||
}
|
||||
@@ -3327,64 +3327,10 @@ components:
|
||||
example: false
|
||||
required:
|
||||
- enabled
|
||||
ProxyTokenRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Human-readable token name
|
||||
example: "my-proxy-token"
|
||||
expires_in:
|
||||
type: integer
|
||||
minimum: 0
|
||||
description: Token expiration in seconds (0 = never expires)
|
||||
example: 0
|
||||
required:
|
||||
- name
|
||||
ProxyToken:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
expires_at:
|
||||
type: string
|
||||
format: date-time
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
last_used:
|
||||
type: string
|
||||
format: date-time
|
||||
revoked:
|
||||
type: boolean
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
- created_at
|
||||
- revoked
|
||||
ProxyTokenCreated:
|
||||
type: object
|
||||
description: Returned on creation — plain_token is shown only once
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/ProxyToken'
|
||||
- type: object
|
||||
properties:
|
||||
plain_token:
|
||||
type: string
|
||||
description: The plain text token (shown only once)
|
||||
example: "nbx_abc123..."
|
||||
required:
|
||||
- plain_token
|
||||
ProxyCluster:
|
||||
type: object
|
||||
description: A proxy cluster represents a group of proxy nodes serving the same address
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique identifier of a proxy in this cluster
|
||||
example: "chlfq4q5r8kc73b0qjpg"
|
||||
address:
|
||||
type: string
|
||||
description: Cluster address used for CNAME targets
|
||||
@@ -3393,15 +3339,9 @@ components:
|
||||
type: integer
|
||||
description: Number of proxy nodes connected in this cluster
|
||||
example: 3
|
||||
self_hosted:
|
||||
type: boolean
|
||||
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||
example: false
|
||||
required:
|
||||
- id
|
||||
- address
|
||||
- connected_proxies
|
||||
- self_hosted
|
||||
ReverseProxyDomainType:
|
||||
type: string
|
||||
description: Type of Reverse Proxy Domain
|
||||
@@ -11407,111 +11347,6 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/clusters/{clusterAddress}:
|
||||
delete:
|
||||
summary: Delete a self-hosted proxy cluster
|
||||
description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account.
|
||||
tags: [ Services ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: clusterAddress
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The address of the proxy cluster
|
||||
responses:
|
||||
'200':
|
||||
description: Proxy cluster deleted successfully
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/proxy-tokens:
|
||||
get:
|
||||
summary: List Proxy Tokens
|
||||
description: Returns all proxy access tokens for the account
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of proxy tokens
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ProxyToken'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a Proxy Token
|
||||
description: Generate an account-scoped proxy access token for self-hosted proxy registration
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProxyTokenRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Proxy token created (plain token shown once)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProxyTokenCreated'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/proxy-tokens/{tokenId}:
|
||||
delete:
|
||||
summary: Revoke a Proxy Token
|
||||
description: Revoke an account-scoped proxy access token
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: tokenId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of the proxy token
|
||||
responses:
|
||||
'200':
|
||||
description: Token revoked
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/services:
|
||||
get:
|
||||
summary: List all Services
|
||||
|
||||
@@ -3764,49 +3764,11 @@ type ProxyAccessLogsResponse struct {
|
||||
|
||||
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
|
||||
type ProxyCluster struct {
|
||||
// Id Unique identifier of a proxy in this cluster
|
||||
Id string `json:"id"`
|
||||
|
||||
// Address Cluster address used for CNAME targets
|
||||
Address string `json:"address"`
|
||||
|
||||
// ConnectedProxies Number of proxy nodes connected in this cluster
|
||||
ConnectedProxies int `json:"connected_proxies"`
|
||||
|
||||
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||
SelfHosted bool `json:"self_hosted"`
|
||||
}
|
||||
|
||||
// ProxyToken defines model for ProxyToken.
|
||||
type ProxyToken struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Id string `json:"id"`
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
// ProxyTokenCreated defines model for ProxyTokenCreated.
|
||||
type ProxyTokenCreated struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Id string `json:"id"`
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
Name string `json:"name"`
|
||||
|
||||
// PlainToken The plain text token (shown only once)
|
||||
PlainToken string `json:"plain_token"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
// ProxyTokenRequest defines model for ProxyTokenRequest.
|
||||
type ProxyTokenRequest struct {
|
||||
// ExpiresIn Token expiration in seconds (0 = never expires)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
|
||||
// Name Human-readable token name
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Resource defines model for Resource.
|
||||
@@ -5177,9 +5139,6 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
|
||||
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
|
||||
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
|
||||
|
||||
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
|
||||
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
|
||||
|
||||
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
|
||||
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest
|
||||
|
||||
|
||||
Reference in New Issue
Block a user