Compare commits

..

2 Commits

Author SHA1 Message Date
mlsmaycon
134b5ce819 Merge branch 'prototype/reverse-proxy' into prototype/reverse-proxy-clusters 2026-02-05 15:01:49 +01:00
mlsmaycon
ad64ba1916 Merge prototype/reverse-proxy with proxy clustering support
Combines token-based authentication from upstream with proxy clustering:
- Session keys for JWT signing (SessionPrivateKey/SessionPublicKey)
- One-time token store for proxy authentication
- Cluster-targeted updates via SendReverseProxyUpdateToCluster
- ProxyCluster field derived from domain
- OIDC validation config support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 14:08:07 +01:00
153 changed files with 2030 additions and 14425 deletions

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
.run
*.iml
dist/
!proxy/web/dist/
bin/
.env
conf.json

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features

View File

@@ -3,6 +3,12 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -11,7 +17,33 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,9 +1,12 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -23,7 +26,33 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() {
return nil

View File

@@ -1,7 +1,6 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -13,13 +12,11 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -27,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -93,7 +89,6 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -234,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -486,15 +480,11 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: string(ruleKey),
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -509,7 +499,6 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -526,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
return r.id == ruleID
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
return fmt.Errorf("route rule not found: %s", ruleID)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -586,40 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {

View File

@@ -1,376 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,158 +263,6 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -189,212 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: add deny and accept rules
networkMap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap1, false)
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
// Second update: remove the deny rule, keep only accept
networkMap2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap2, false)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Should have 1 rule after removing deny rule")
// Third update: remove all rules
networkMap3 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{},
FirewallRulesIsEmpty: true,
}
acl.ApplyFiltering(networkMap3, false)
assert.Equal(t, 0, len(acl.peerRulesPairs),
"Should have 0 rules after removing all rules")
}
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
// accept to deny (or vice versa), the old rule is properly removed and the new
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: accept rule
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Second update: change to deny (same IP/port/proto, different action)
networkMap.FirewallRules = []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
}
acl.ApplyFiltering(networkMap, false)
// Should still have exactly 1 rule (the old accept removed, new deny added)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -6,9 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
@@ -29,8 +27,6 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
@@ -443,17 +439,6 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)

View File

@@ -190,75 +190,50 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return
return nil
}
question := query.Question[0]
qname := strings.ToLower(question.Name)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(qname, question.Qtype, result.IPs)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
dns.ResponseWriter
query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -268,7 +243,30 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -278,7 +276,18 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, w, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -325,7 +334,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
startTime time.Time,
) {
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
@@ -335,7 +343,9 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
@@ -345,7 +355,9 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
@@ -353,7 +365,9 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
}
@@ -361,12 +375,15 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
}
f.writeResponse(logger, w, resp, domain, startTime)
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,9 +318,8 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -330,9 +329,10 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
@@ -466,16 +466,14 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else {
require.NotNil(t, resp, "Expected response")
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
@@ -530,10 +528,9 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -608,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -678,8 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -687,13 +683,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -719,8 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -732,13 +727,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2)
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -789,9 +784,8 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -903,15 +897,26 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, resp.Answer, "Response should have no answer records")
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
@@ -926,8 +931,15 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{}
// Don't set any question
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -828,10 +828,6 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()

View File

@@ -2,7 +2,6 @@ package ice
import (
"context"
"fmt"
"sync"
"time"
@@ -33,6 +32,24 @@ type ThreadSafeAgent struct {
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -76,41 +93,9 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
// Defensive check to prevent nil pointer dereference
// This can happen during sleep/wake transitions or memory corruption scenarios
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
agent := a.Agent
if agent == nil {
log.Warnf("ICE agent is nil during close, skipping")
return
}
done := make(chan error, 1)
go func() {
done <- agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {

View File

@@ -107,10 +107,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if w.agent != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
sessionID, err := NewICESessionID()

View File

@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultAdminURL)
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err

View File

@@ -173,21 +173,12 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
},
)

View File

@@ -4,17 +4,16 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
@@ -25,51 +24,42 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
if flags&unix.RTF_CLONING != 0 {
if flags&syscall.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C")
}
if flags&unix.RTF_WASCLONED != 0 {
if flags&syscall.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W")
}
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -4,18 +4,17 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
@@ -26,46 +25,37 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

1
go.mod
View File

@@ -40,6 +40,7 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1

2
go.sum
View File

@@ -107,6 +107,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=

View File

@@ -327,60 +327,6 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
return nil
}
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return false, fmt.Errorf("failed to list connectors: %w", err)
}
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
for _, conn := range connectors {
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
if conn.ID != "local" || conn.Type != "local" {
p.logger.Info("found non-local connector", "id", conn.ID)
return true, nil
}
}
p.logger.Info("no non-local connectors found")
return false, nil
}
// DisableLocalAuth removes the local (password) connector.
// Returns an error if no other connectors are configured.
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
hasOthers, err := p.HasNonLocalConnectors(ctx)
if err != nil {
return err
}
if !hasOthers {
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
}
// Check if local connector exists
_, err = p.storage.GetConnector(ctx, "local")
if errors.Is(err, storage.ErrNotFound) {
// Already disabled
return nil
}
if err != nil {
return fmt.Errorf("failed to check local connector: %w", err)
}
// Delete the local connector
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
return fmt.Errorf("failed to delete local connector: %w", err)
}
p.logger.Info("local authentication disabled")
return nil
}
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
return ensureLocalConnector(ctx, p.storage)
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {

View File

@@ -19,8 +19,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -215,14 +213,11 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
return nil
}

View File

@@ -80,10 +80,4 @@ func init() {
migrationCmd.AddCommand(upCmd)
rootCmd.AddCommand(migrationCmd)
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
tokenCmd.AddCommand(tokenCreateCmd)
tokenCmd.AddCommand(tokenListCmd)
tokenCmd.AddCommand(tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}

View File

@@ -1,208 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!")
fmt.Printf("Token: %s\n", generated.PlainToken)
fmt.Println()
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.")
fmt.Printf("Token ID: %s\n", generated.ID)
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID)
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -1,101 +0,0 @@
package cmd
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseDuration(t *testing.T) {
tests := []struct {
name string
input string
expected time.Duration
wantErr bool
}{
{
name: "empty string returns zero",
input: "",
expected: 0,
},
{
name: "days suffix",
input: "30d",
expected: 30 * 24 * time.Hour,
},
{
name: "one day",
input: "1d",
expected: 24 * time.Hour,
},
{
name: "365 days",
input: "365d",
expected: 365 * 24 * time.Hour,
},
{
name: "hours via Go duration",
input: "24h",
expected: 24 * time.Hour,
},
{
name: "minutes via Go duration",
input: "30m",
expected: 30 * time.Minute,
},
{
name: "complex Go duration",
input: "1h30m",
expected: 90 * time.Minute,
},
{
name: "invalid day format",
input: "abcd",
wantErr: true,
},
{
name: "negative days",
input: "-1d",
wantErr: true,
},
{
name: "zero days",
input: "0d",
wantErr: true,
},
{
name: "non-numeric days",
input: "xyzd",
wantErr: true,
},
{
name: "negative Go duration",
input: "-24h",
wantErr: true,
},
{
name: "zero Go duration",
input: "0s",
wantErr: true,
},
{
name: "invalid Go duration",
input: "notaduration",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseDuration(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -174,13 +174,14 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
exposedServices := account.GetExposedServicesMap()
proxyPeers := account.GetProxyPeers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
@@ -233,7 +234,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs, exposedServices, proxyPeers)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -248,10 +249,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
}(peer)
}
@@ -327,7 +325,6 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return fmt.Errorf("failed to get validated peers: %v", err)
}
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -358,7 +355,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs, account.GetExposedServicesMap(), account.GetProxyPeers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -375,10 +372,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
return nil
}
@@ -443,8 +437,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -479,7 +471,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers(), account.GetExposedServicesMap(), account.GetProxyPeers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -788,7 +780,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
},
},
},
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
@@ -851,10 +842,9 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers(), account.GetExposedServicesMap(), account.GetProxyPeers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -25,14 +25,11 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
MessageType: network_map.MessageTypeNetworkMap,
}
}}
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
@@ -48,14 +45,11 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
MessageType: network_map.MessageTypeNetworkMap,
}
}}
peersUpdater.SendUpdate(context.Background(), peer, update2)
timeout := time.After(5 * time.Second)

View File

@@ -4,19 +4,6 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
// MessageType indicates the type of update message for debouncing strategy
type MessageType int
const (
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
// These updates can be safely debounced - only the latest state matters
MessageTypeNetworkMap MessageType = iota
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
// These updates should not be dropped as they contain time-sensitive information
MessageTypeControlConfig
)
type UpdateMessage struct {
Update *proto.SyncResponse
MessageType MessageType
Update *proto.SyncResponse
}

View File

@@ -33,7 +33,7 @@ type Manager interface {
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
CreateProxyPeer(ctx context.Context, accountID string, peerKey string) error
}
type managerImpl struct {
@@ -185,7 +185,7 @@ func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, er
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
}
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string) error {
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
if err == nil && existingPeerID != "" {
// Peer already exists
@@ -194,11 +194,8 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
name := fmt.Sprintf("proxy-%s", xid.New().String())
peer := &peer.Peer{
Ephemeral: true,
ProxyMeta: peer.ProxyMeta{
Cluster: cluster,
Embedded: true,
},
Ephemeral: true,
ProxyEmbedded: true,
Name: name,
Key: peerKey,
LoginExpirationEnabled: false,

View File

@@ -13,42 +13,42 @@ import (
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
ProxyID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Method string
Host string
Path string
Duration time.Duration
StatusCode int
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
UserId string
AuthMethodUsed string
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.ID = serviceLog.GetLogId()
a.ServiceID = serviceLog.GetServiceId()
a.Timestamp = serviceLog.GetTimestamp().AsTime()
a.Method = serviceLog.GetMethod()
a.Host = serviceLog.GetHost()
a.Path = serviceLog.GetPath()
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(serviceLog.GetResponseCode())
a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId()
func (a *AccessLogEntry) FromProto(proxyLog *proto.AccessLog) {
a.ID = proxyLog.GetLogId()
a.ProxyID = proxyLog.GetServiceId()
a.Timestamp = proxyLog.GetTimestamp().AsTime()
a.Method = proxyLog.GetMethod()
a.Host = proxyLog.GetHost()
a.Path = proxyLog.GetPath()
a.Duration = time.Duration(proxyLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(proxyLog.GetResponseCode())
a.UserId = proxyLog.GetUserId()
a.AuthMethodUsed = proxyLog.GetAuthMechanism()
a.AccountID = proxyLog.GetAccountId()
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if sourceIP := proxyLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
}
}
if !serviceLog.GetAuthSuccess() {
if !proxyLog.GetAuthSuccess() {
a.Reason = "Authentication failed"
} else if serviceLog.GetResponseCode() >= 400 {
} else if proxyLog.GetResponseCode() >= 400 {
a.Reason = "Request failed"
}
}
@@ -88,7 +88,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
return &api.ProxyAccessLog{
Id: a.ID,
ServiceId: a.ServiceID,
ProxyId: a.ProxyID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,

View File

@@ -1,124 +0,0 @@
package accesslogs
import (
"net/http"
"strconv"
"time"
)
const (
// DefaultPageSize is the default number of records per page
DefaultPageSize = 50
// MaxPageSize is the maximum number of records allowed per page
MaxPageSize = 100
)
// AccessLogFilter holds pagination and filtering parameters for access logs
type AccessLogFilter struct {
// Page is the current page number (1-indexed)
Page int
// PageSize is the number of records per page
PageSize int
// Filtering parameters
Search *string // General search across log ID, host, path, source IP, and user fields
SourceIP *string // Filter by source IP address
Host *string // Filter by host header
Path *string // Filter by request path (supports LIKE pattern)
UserID *string // Filter by authenticated user ID
UserEmail *string // Filter by user email (requires user lookup)
UserName *string // Filter by user name (requires user lookup)
Method *string // Filter by HTTP method
Status *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
StatusCode *int // Filter by HTTP status code
StartDate *time.Time // Filter by timestamp >= start_date
EndDate *time.Time // Filter by timestamp <= end_date
}
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
queryParams := r.URL.Query()
f.Page = 1
if pageStr := queryParams.Get("page"); pageStr != "" {
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
f.Page = page
}
}
f.PageSize = DefaultPageSize
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
f.PageSize = pageSize
if f.PageSize > MaxPageSize {
f.PageSize = MaxPageSize
}
}
}
if search := queryParams.Get("search"); search != "" {
f.Search = &search
}
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
f.SourceIP = &sourceIP
}
if host := queryParams.Get("host"); host != "" {
f.Host = &host
}
if path := queryParams.Get("path"); path != "" {
f.Path = &path
}
if userID := queryParams.Get("user_id"); userID != "" {
f.UserID = &userID
}
if userEmail := queryParams.Get("user_email"); userEmail != "" {
f.UserEmail = &userEmail
}
if userName := queryParams.Get("user_name"); userName != "" {
f.UserName = &userName
}
if method := queryParams.Get("method"); method != "" {
f.Method = &method
}
if status := queryParams.Get("status"); status != "" {
f.Status = &status
}
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
f.StatusCode = &statusCode
}
}
if startDate := queryParams.Get("start_date"); startDate != "" {
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
if err == nil {
f.StartDate = &parsedStartDate
}
}
if endDate := queryParams.Get("end_date"); endDate != "" {
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
if err == nil {
f.EndDate = &parsedEndDate
}
}
}
// GetOffset calculates the database offset for pagination
func (f *AccessLogFilter) GetOffset() int {
return (f.Page - 1) * f.PageSize
}
// GetLimit returns the page size for database queries
func (f *AccessLogFilter) GetLimit() int {
return f.PageSize
}

View File

@@ -1,161 +0,0 @@
package accesslogs
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
tests := []struct {
name string
queryParams map[string]string
expectedPage int
expectedPageSize int
}{
{
name: "default values when no params provided",
queryParams: map[string]string{},
expectedPage: 1,
expectedPageSize: DefaultPageSize,
},
{
name: "valid page and page_size",
queryParams: map[string]string{
"page": "2",
"page_size": "25",
},
expectedPage: 2,
expectedPageSize: 25,
},
{
name: "page_size exceeds max, should cap at MaxPageSize",
queryParams: map[string]string{
"page": "1",
"page_size": "200",
},
expectedPage: 1,
expectedPageSize: MaxPageSize,
},
{
name: "invalid page number, should use default",
queryParams: map[string]string{
"page": "invalid",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "invalid page_size, should use default",
queryParams: map[string]string{
"page": "2",
"page_size": "invalid",
},
expectedPage: 2,
expectedPageSize: DefaultPageSize,
},
{
name: "zero page number, should use default",
queryParams: map[string]string{
"page": "0",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "negative page number, should use default",
queryParams: map[string]string{
"page": "-1",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "zero page_size, should use default",
queryParams: map[string]string{
"page": "1",
"page_size": "0",
},
expectedPage: 1,
expectedPageSize: DefaultPageSize,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
q := req.URL.Query()
for key, value := range tt.queryParams {
q.Set(key, value)
}
req.URL.RawQuery = q.Encode()
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
})
}
}
func TestAccessLogFilter_GetOffset(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
expectedOffset int
}{
{
name: "first page",
page: 1,
pageSize: 50,
expectedOffset: 0,
},
{
name: "second page",
page: 2,
pageSize: 50,
expectedOffset: 50,
},
{
name: "third page with page size 25",
page: 3,
pageSize: 25,
expectedOffset: 50,
},
{
name: "page 10 with page size 10",
page: 10,
pageSize: 10,
expectedOffset: 90,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filter := &AccessLogFilter{
Page: tt.page,
PageSize: tt.pageSize,
}
offset := filter.GetOffset()
assert.Equal(t, tt.expectedOffset, offset)
})
}
}
func TestAccessLogFilter_GetLimit(t *testing.T) {
filter := &AccessLogFilter{
Page: 2,
PageSize: 25,
}
limit := filter.GetLimit()
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
}

View File

@@ -6,5 +6,5 @@ import (
type Manager interface {
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
GetAllAccessLogs(ctx context.Context, accountID, userID string) ([]*AccessLogEntry, error)
}

View File

@@ -30,10 +30,7 @@ func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
return
}
var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r)
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
logs, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -44,21 +41,5 @@ func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
apiLogs = append(apiLogs, *log.ToAPIResponse())
}
response := &api.ProxyAccessLogsResponse{
Data: apiLogs,
Page: filter.Page,
PageSize: filter.PageSize,
TotalRecords: int(totalCount),
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
}
util.WriteJSONObject(r.Context(), w, response)
}
// getTotalPageCount calculates the total number of pages
func getTotalPageCount(totalCount, pageSize int) int {
if pageSize <= 0 {
return 0
}
return (totalCount + pageSize - 1) / pageSize
util.WriteJSONObject(r.Context(), w, apiLogs)
}

View File

@@ -2,7 +2,6 @@ package manager
import (
"context"
"strings"
log "github.com/sirupsen/logrus"
@@ -44,11 +43,11 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"service_id": logEntry.ServiceID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
"proxy_id": logEntry.ProxyID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
}).Errorf("failed to save access log: %v", err)
return err
}
@@ -56,53 +55,20 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
return nil
}
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
// GetAllAccessLogs retrieves all access logs for an account
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string) ([]*accesslogs.AccessLogEntry, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, 0, status.NewPermissionDeniedError()
return nil, status.NewPermissionDeniedError()
}
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
}
logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
logs, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, 0, err
return nil, err
}
return logs, totalCount, nil
}
// resolveUserFilters converts user email/name filters to user ID filter
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
if filter.UserEmail == nil && filter.UserName == nil {
return nil
}
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
var matchingUserIDs []string
for _, user := range users {
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
matchingUserIDs = append(matchingUserIDs, user.Id)
continue
}
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
matchingUserIDs = append(matchingUserIDs, user.Id)
}
}
if len(matchingUserIDs) > 0 {
filter.UserID = &matchingUserIDs[0]
}
return nil
return logs, nil
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -39,16 +40,12 @@ func domainTypeToApi(t domainType) api.ReverseProxyDomainType {
}
func domainToApi(d *Domain) api.ReverseProxyDomain {
resp := api.ReverseProxyDomain{
return api.ReverseProxyDomain{
Domain: d.Domain,
Id: d.ID,
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster
}
return resp
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
@@ -85,7 +82,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
return
}
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, req.Domain, req.TargetCluster)
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, req.Domain)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -19,12 +19,11 @@ const (
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
TargetCluster string // The proxy cluster this domain should be validated against
Type domainType `gorm:"-"`
Validated bool
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
Type domainType `gorm:"-"`
Validated bool
}
type store interface {
@@ -33,7 +32,7 @@ type store interface {
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *Domain) (*Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
@@ -86,43 +85,31 @@ func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, e
// Add custom domains.
for _, domain := range domains {
ret = append(ret, &Domain{
ID: domain.ID,
Domain: domain.Domain,
AccountID: accountID,
TargetCluster: domain.TargetCluster,
Type: TypeCustom,
Validated: domain.Validated,
ID: domain.ID,
Domain: domain.Domain,
AccountID: accountID,
Type: TypeCustom,
Validated: domain.Validated,
})
}
return ret, nil
}
func (m Manager) CreateDomain(ctx context.Context, accountID, domainName, targetCluster string) (*Domain, error) {
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
clusterValid = true
break
}
}
if !clusterValid {
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
}
// Attempt an initial validation against the specified cluster only
func (m Manager) CreateDomain(ctx context.Context, accountID, domainName string) (*Domain, error) {
// Attempt an initial validation; however, a failure is still acceptable for creation
// because the user may not yet have configured their DNS records, or the DNS update
// has not yet reached the servers that are queried by the validation resolver.
var validated bool
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
if m.validator.IsValid(ctx, domainName, m.proxyURLAllowList()) {
validated = true
}
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, validated)
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
return d, nil
}
@@ -149,25 +136,15 @@ func (m Manager) ValidateDomain(accountID, domainID string) {
return
}
// Validate only against the domain's target cluster
targetCluster := d.TargetCluster
if targetCluster == "" {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Warn("domain has no target cluster set, skipping validation")
return
}
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"targetCluster": targetCluster,
}).Info("validating domain against target cluster")
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"proxyAllowList": allowList,
}).Info("validating domain against proxy allow list")
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
if m.validator.IsValid(context.Background(), d.Domain, allowList) {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
@@ -184,11 +161,11 @@ func (m Manager) ValidateDomain(accountID, domainID string) {
}
} else {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"targetCluster": targetCluster,
}).Warn("domain validation failed - CNAME does not match target cluster")
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"proxyAllowList": allowList,
}).Warn("domain validation failed - CNAME does not match any connected proxy")
}
}

View File

@@ -5,17 +5,11 @@ import (
)
type Manager interface {
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*ReverseProxy, error)
GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*ReverseProxy, error)
CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *ReverseProxy) (*ReverseProxy, error)
UpdateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *ReverseProxy) (*ReverseProxy, error)
DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error
SetStatus(ctx context.Context, accountID, reverseProxyID string, status ProxyStatus) error
}

View File

@@ -20,148 +20,148 @@ type handler struct {
manager reverseproxy.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
// RegisterEndpoints registers all reverse proxy HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
domainRouter := router.PathPrefix("/services").Subrouter()
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domain.RegisterEndpoints(domainRouter, domainManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies", h.getAllReverseProxies).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies", h.createReverseProxy).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.getReverseProxy).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.updateReverseProxy).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/{proxyId}", h.deleteReverseProxy).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
func (h *handler) getAllReverseProxies(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
allReverseProxies, err := h.manager.GetAllReverseProxies(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiServices := make([]*api.Service, 0, len(allServices))
for _, service := range allServices {
apiServices = append(apiServices, service.ToAPIResponse())
apiReverseProxies := make([]*api.ReverseProxy, 0, len(allReverseProxies))
for _, reverseProxy := range allReverseProxies {
apiReverseProxies = append(apiReverseProxies, reverseProxy.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiServices)
util.WriteJSONObject(r.Context(), w, apiReverseProxies)
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
func (h *handler) createReverseProxy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ServiceRequest
var req api.ReverseProxyRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
service := new(reverseproxy.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
reverseProxy := new(reverseproxy.ReverseProxy)
reverseProxy.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
if err = reverseProxy.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
createdReverseProxy, err := h.manager.CreateReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, createdReverseProxy.ToAPIResponse())
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
func (h *handler) getReverseProxy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
return
}
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
reverseProxy, err := h.manager.GetReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, reverseProxy.ToAPIResponse())
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
func (h *handler) updateReverseProxy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
return
}
var req api.ServiceRequest
var req api.ReverseProxyRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
service := new(reverseproxy.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
reverseProxy := new(reverseproxy.ReverseProxy)
reverseProxy.ID = reverseProxyID
reverseProxy.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
if err = reverseProxy.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
updatedReverseProxy, err := h.manager.UpdateReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, updatedReverseProxy.ToAPIResponse())
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
func (h *handler) deleteReverseProxy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
reverseProxyID := mux.Vars(r)["proxyId"]
if reverseProxyID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "reverse proxy ID is required"), w)
return
}
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
if err := h.manager.DeleteReverseProxy(r.Context(), userAuth.AccountId, userAuth.UserId, reverseProxyID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -19,8 +19,6 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, domain string) (string, error)
@@ -31,23 +29,23 @@ type managerImpl struct {
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
tokenStore *nbgrpc.OneTimeTokenStore
clusterDeriver ClusterDeriver
tokenStore *nbgrpc.OneTimeTokenStore
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore, clusterDeriver ClusterDeriver) reverseproxy.Manager {
// NewManager creates a new reverse proxy manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver, tokenStore *nbgrpc.OneTimeTokenStore) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
tokenStore: tokenStore,
clusterDeriver: clusterDeriver,
tokenStore: tokenStore,
}
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *managerImpl) GetAllReverseProxies(ctx context.Context, accountID, userID string) ([]*reverseproxy.ReverseProxy, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -56,58 +54,10 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -116,19 +66,10 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -139,46 +80,40 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
var proxyCluster string
if m.clusterDeriver != nil {
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, service.Domain)
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxies", reverseProxy.Domain)
}
}
service.AccountID = accountID
service.ProxyCluster = proxyCluster
service.InitNewRecord()
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
authConfig := reverseProxy.Auth
reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, proxyCluster, reverseProxy.Targets, reverseProxy.Enabled)
reverseProxy.Auth = authConfig
// Generate session JWT signing keys
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return nil, fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
reverseProxy.SessionPrivateKey = keyPair.PrivateKey
reverseProxy.SessionPublicKey = keyPair.PublicKey
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Check for duplicate domain
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
return fmt.Errorf("failed to check existing reverse proxy: %w", err)
}
}
if existingService != nil {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
if existingReverseProxy != nil {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
}
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
if err = transaction.CreateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("failed to create reverse proxy: %w", err)
}
return nil
@@ -187,26 +122,19 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, err
}
token, err := m.tokenStore.GenerateToken(accountID, service.ID, 5*time.Minute)
token, err := m.tokenStore.GenerateToken(accountID, reverseProxy.ID, 5*time.Minute)
if err != nil {
return nil, fmt.Errorf("failed to generate authentication token: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyCreated, reverseProxy.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return reverseProxy, nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -217,67 +145,42 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
var oldCluster string
var domainChanged bool
var serviceEnabledChanged bool
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID)
if err != nil {
return err
}
oldCluster = existingService.ProxyCluster
oldCluster = existingReverseProxy.ProxyCluster
if existingService.Domain != service.Domain {
if existingReverseProxy.Domain != reverseProxy.Domain {
domainChanged = true
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
conflictReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("check existing service: %w", err)
return fmt.Errorf("check existing reverse proxy: %w", err)
}
}
if conflictService != nil && conflictService.ID != service.ID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
if conflictReverseProxy != nil && conflictReverseProxy.ID != reverseProxy.ID {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, service.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", reverseProxy.Domain)
}
service.ProxyCluster = newCluster
reverseProxy.ProxyCluster = newCluster
}
} else {
service.ProxyCluster = existingService.ProxyCluster
reverseProxy.ProxyCluster = existingReverseProxy.ProxyCluster
}
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
reverseProxy.Meta = existingReverseProxy.Meta
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
serviceEnabledChanged = existingService.Enabled != service.Enabled
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
if err = transaction.UpdateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("update reverse proxy: %w", err)
}
return nil
@@ -286,59 +189,20 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
oidcConfig := m.proxyGRPCServer.GetOIDCValidationConfig()
if domainChanged && oldCluster != reverseProxy.ProxyCluster {
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", oidcConfig), oldCluster)
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, "", oidcConfig), reverseProxy.ProxyCluster)
} else {
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, "", oidcConfig), reverseProxy.ProxyCluster)
}
token, err := m.tokenStore.GenerateToken(accountID, service.ID, 5*time.Minute)
if err != nil {
return nil, fmt.Errorf("failed to generate authentication token: %w", err)
}
switch {
case domainChanged && oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
case !service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
case service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return reverseProxy, nil
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -347,16 +211,16 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
var reverseProxy *reverseproxy.ReverseProxy
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
reverseProxy, err = transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
if err != nil {
return err
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
if err = transaction.DeleteReverseProxy(ctx, accountID, reverseProxyID); err != nil {
return fmt.Errorf("failed to delete reverse proxy: %w", err)
}
return nil
@@ -365,144 +229,46 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, reverseProxyID, accountID, activity.ReverseProxyDeleted, reverseProxy.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
proxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
return fmt.Errorf("failed to get reverse proxy: %w", err)
}
service.Meta.CertificateIssuedAt = time.Now()
proxy.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
if err = transaction.UpdateReverseProxy(ctx, proxy); err != nil {
return fmt.Errorf("failed to update reverse proxy certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
// SetStatus updates the status of the reverse proxy (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
proxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxyID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
return fmt.Errorf("failed to get reverse proxy: %w", err)
}
service.Meta.Status = string(status)
proxy.Meta.Status = string(status)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
if err = transaction.UpdateReverseProxy(ctx, proxy); err != nil {
return fmt.Errorf("failed to update reverse proxy status: %w", err)
}
return nil
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}

View File

@@ -2,18 +2,15 @@ package reverseproxy
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/netbirdio/netbird/util/crypt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -36,23 +33,18 @@ const (
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
TargetTypePeer = "peer"
TargetTypeResource = "resource"
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Host string `json:"host"`
Port int `json:"port"`
Protocol string `json:"protocol"`
TargetId string `json:"target_id"`
TargetType string `json:"target_type"`
Enabled bool `json:"enabled"`
}
type PasswordAuthConfig struct {
@@ -76,35 +68,6 @@ type AuthConfig struct {
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
@@ -112,147 +75,127 @@ type OIDCValidationConfig struct {
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
type ReverseProxyMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type Service struct {
type ReverseProxy struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []Target `gorm:"serializer:json"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []Target, enabled bool) *ReverseProxy {
return &ReverseProxy{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
Meta: ReverseProxyMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
},
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
authConfig := api.ReverseProxyAuthConfig{}
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
if r.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
Enabled: r.Auth.PasswordAuth.Enabled,
Password: r.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
if r.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
Enabled: r.Auth.PinAuth.Enabled,
Pin: r.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
if r.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
Enabled: r.Auth.BearerAuth.Enabled,
DistributionGroups: &r.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
for _, target := range r.Targets {
apiTargets = append(apiTargets, api.ReverseProxyTarget{
Path: target.Path,
Host: &target.Host,
Host: target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
Protocol: api.ReverseProxyTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
TargetType: api.ReverseProxyTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
meta := api.ReverseProxyMeta{
CreatedAt: r.Meta.CreatedAt,
Status: api.ReverseProxyMetaStatus(r.Meta.Status),
}
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
if !r.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &r.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
resp := &api.ReverseProxy{
Id: r.ID,
Name: r.Name,
Domain: r.Domain,
Targets: apiTargets,
Enabled: r.Enabled,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
if r.ProxyCluster != "" {
resp.ProxyCluster = &r.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(r.Targets))
for _, target := range r.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: path,
}
if target.Port > 0 {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
@@ -260,32 +203,30 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
SessionKey: r.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
if r.Auth.PinAuth != nil && r.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Type: operationToProtoType(operation),
Id: r.ID,
Domain: r.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: r.AccountID,
}
}
@@ -303,54 +244,36 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID string) {
r.Name = req.Name
r.Domain = req.Domain
r.AccountID = accountID
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
targets := make([]Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
targets = append(targets, Target{
Path: apiTarget.Path,
Host: apiTarget.Host,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
})
}
s.Targets = targets
r.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
r.Enabled = req.Enabled
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
r.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
r.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
@@ -363,81 +286,60 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
r.Auth.BearerAuth = bearerAuth
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
func (r *ReverseProxy) Validate() error {
if r.Name == "" {
return errors.New("reverse proxy name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
if len(r.Name) > 255 {
return errors.New("reverse proxy name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
if r.Domain == "" {
return errors.New("reverse proxy domain is required")
}
if len(s.Targets) == 0 {
if len(r.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
}
return nil
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
func (r *ReverseProxy) EventMeta() map[string]any {
return map[string]any{"name": r.Name, "domain": r.Domain, "proxy_cluster": r.ProxyCluster}
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
func (r *ReverseProxy) Copy() *ReverseProxy {
targets := make([]Target, len(r.Targets))
copy(targets, r.Targets)
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
return &ReverseProxy{
ID: r.ID,
AccountID: r.AccountID,
Name: r.Name,
Domain: r.Domain,
ProxyCluster: r.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
Enabled: r.Enabled,
Auth: r.Auth,
Meta: r.Meta,
SessionPrivateKey: r.SessionPrivateKey,
SessionPublicKey: r.SessionPublicKey,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
if r.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
r.SessionPrivateKey, err = enc.Encrypt(r.SessionPrivateKey)
if err != nil {
return err
}
@@ -446,14 +348,14 @@ func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
func (r *ReverseProxy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
if r.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
r.SessionPrivateKey, err = enc.Decrypt(r.SessionPrivateKey)
if err != nil {
return err
}

View File

@@ -1,405 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"net/netip"
"slices"
"strings"
"time"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
@@ -122,13 +123,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
s.proxyAuthClose = proxyAuthClose
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.Config.HttpConfig.LetsEncryptDomain != "" {
@@ -163,7 +162,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
proxyService := nbgrpc.NewProxyServiceServer(s.Store(), s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager())
})
@@ -173,12 +172,18 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
return Create(s, func() nbgrpc.ProxyOIDCConfig {
// TODO: this is weird, double check
// Build callback URL - this should be the management server's callback endpoint
// For embedded IdP, derive from issuer. For external, use a configured value or derive from issuer.
// The callback URL should be registered in the IdP's allowed redirect URIs for the dashboard client.
callbackURL := strings.TrimSuffix(s.Config.HttpConfig.AuthIssuer, "/oauth2")
callbackURL = callbackURL + "/api/oauth/callback"
return nbgrpc.ProxyOIDCConfig{
Issuer: s.Config.HttpConfig.AuthIssuer,
// todo: double check auth clientID value
ClientID: s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
Issuer: s.Config.HttpConfig.AuthIssuer,
ClientID: "netbird-dashboard", // Reuse dashboard client
Scopes: []string{"openid", "profile", "email"},
CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
CallbackURL: callbackURL,
HMACKey: []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
Audience: s.Config.HttpConfig.AuthAudience,
KeysLocation: s.Config.HttpConfig.AuthKeysLocation,

View File

@@ -100,8 +100,6 @@ type HttpServerConfig struct {
CertFile string
// CertKey is the location of the certificate private key
CertKey string
// AuthClientID is the client id used for proxy SSO auth
AuthClientID string
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
AuthAudience string
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
@@ -119,8 +117,6 @@ type HttpServerConfig struct {
IdpSignKeyRefreshEnabled bool
// Extra audience
ExtraAuthAudience string
// AuthCallbackDomain contains the callback domain
AuthCallbackURL string
}
// Host represents a Netbird host (e.g. STUN, TURN, Signal)

View File

@@ -72,14 +72,7 @@ func (s *BaseServer) UsersManager() users.Manager {
func (s *BaseServer) SettingsManager() settings.Manager {
return Create(s, func() settings.Manager {
extraSettingsManager := integrations.NewManager(s.EventStore())
idpConfig := settings.IdpConfig{}
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpConfig.EmbeddedIdpEnabled = true
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
}
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
})
}
@@ -101,11 +94,6 @@ func (s *BaseServer) AccountManager() account.Manager {
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager())
})
return accountManager
})
}
@@ -162,7 +150,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
})
}
@@ -192,7 +180,8 @@ func (s *BaseServer) RecordsManager() records.Manager {
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ProxyTokenStore(), s.ReverseProxyDomainManager())
domainMgr := s.ReverseProxyDomainManager()
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), domainMgr, s.ProxyTokenStore())
})
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
@@ -20,7 +21,6 @@ import (
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util/wsproxy"
@@ -58,8 +58,6 @@ type BaseServer struct {
mgmtMetricsPort int
mgmtPort int
proxyAuthClose func()
listener net.Listener
certManager *autocert.Manager
update *version.Update
@@ -217,10 +215,6 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
}
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {

View File

@@ -13,7 +13,7 @@ import (
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// a reverse proxy is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
@@ -24,10 +24,10 @@ type OneTimeTokenStore struct {
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ServiceID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
ReverseProxyID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
@@ -48,10 +48,10 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
// GenerateToken creates a new cryptographically secure one-time token
// with the specified TTL. The token is associated with a specific
// accountID and serviceID for validation purposes.
// accountID and reverseProxyID for validation purposes.
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
func (s *OneTimeTokenStore) GenerateToken(accountID, reverseProxyID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
@@ -65,20 +65,20 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
ReverseProxyID: reverseProxyID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl)
reverseProxyID, accountID, ttl)
return token, nil
}
// ValidateAndConsume verifies the token against the provided accountID and
// serviceID, checks expiration, and then deletes it to enforce single-use.
// reverseProxyID, checks expiration, and then deletes it to enforce single-use.
//
// This method uses constant-time comparison to prevent timing attacks.
//
@@ -87,14 +87,14 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Token has expired
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, reverseProxyID string) error {
s.mu.Lock()
defer s.mu.Unlock()
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
serviceID, accountID)
reverseProxyID, accountID)
return fmt.Errorf("invalid token")
}
@@ -102,7 +102,7 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID strin
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
reverseProxyID, accountID)
return fmt.Errorf("token expired")
}
@@ -113,18 +113,18 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID strin
return fmt.Errorf("account ID mismatch")
}
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
// Validate reverse proxy ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ReverseProxyID), []byte(reverseProxyID)) != 1 {
log.Warnf("Token validation failed: reverse proxy ID mismatch (expected: %s, got: %s)",
metadata.ReverseProxyID, reverseProxyID)
return fmt.Errorf("reverse proxy ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
log.Infof("Token validated and consumed for proxy %s in account %s",
serviceID, accountID)
reverseProxyID, accountID)
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"errors"
@@ -25,10 +26,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/store"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -43,6 +42,17 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
type reverseProxyStore interface {
GetReverseProxies(ctx context.Context, lockStrength store.LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength store.LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
GetReverseProxyByID(ctx context.Context, lockStrength store.LockingStrength, accountID string, serviceID string) (*reverseproxy.ReverseProxy, error)
}
type reverseProxyManager interface {
SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error
SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error
}
// ClusterInfo contains information about a proxy cluster.
type ClusterInfo struct {
Address string
@@ -62,18 +72,18 @@ type ProxyServiceServer struct {
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
// Store of reverse proxies
reverseProxyStore reverseProxyStore
// Manager for access logs
accessLogManager accesslogs.Manager
// Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager
reverseProxyManager reverseProxyManager
// Manager for peers
peersManager peers.Manager
// Manager for users
usersManager users.Manager
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
@@ -95,19 +105,19 @@ type proxyConnection struct {
mu sync.RWMutex
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
// NewProxyServiceServer creates a new proxy service server
func NewProxyServiceServer(store reverseProxyStore, accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager) *ProxyServiceServer {
return &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
peersManager: peersManager,
usersManager: usersManager,
updatesChan: make(chan *proto.ProxyMapping, 100),
reverseProxyStore: store,
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
peersManager: peersManager,
}
}
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
func (s *ProxyServiceServer) SetProxyManager(manager reverseProxyManager) {
s.reverseProxyManager = manager
}
@@ -176,40 +186,40 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
// Only reverse proxies matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("get services from store: %w", err)
return fmt.Errorf("get reverse proxies from store: %w", err)
}
proxyClusterAddr := extractClusterAddr(conn.address)
for _, service := range services {
if !service.Enabled {
for _, rp := range reverseProxies {
if !rp.Enabled {
continue
}
if service.ProxyCluster != "" && proxyClusterAddr != "" && service.ProxyCluster != proxyClusterAddr {
if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr {
continue
}
// Generate one-time authentication token for each service in the snapshot
// Generate one-time authentication token for each proxy in the snapshot
// Tokens are not persistent on the proxy, so we need to generate new ones on reconnection
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(rp.AccountID, rp.ID, 5*time.Minute)
if err != nil {
log.WithFields(log.Fields{
"service": service.Name,
"account": service.AccountID,
"proxy": rp.Name,
"account": rp.AccountID,
}).WithError(err).Error("Failed to generate auth token for snapshot")
continue
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
rp.ToProtoMapping(
reverseproxy.Create,
token,
s.GetOIDCValidationConfig(),
),
@@ -259,27 +269,19 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
fields := log.Fields{
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"source_ip": accessLog.GetSourceIp(),
}
if mechanism := accessLog.GetAuthMechanism(); mechanism != "" {
fields["auth_mechanism"] = mechanism
}
if userID := accessLog.GetUserId(); userID != "" {
fields["user_id"] = userID
}
if !accessLog.GetAuthSuccess() {
fields["auth_success"] = false
}
log.WithFields(fields).Debugf("%s %s %d (%dms)",
accessLog.GetMethod(),
accessLog.GetPath(),
accessLog.GetResponseCode(),
accessLog.GetDurationMs(),
)
log.WithFields(log.Fields{
"reverse_proxy_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
"host": accessLog.GetHost(),
"path": accessLog.GetPath(),
"method": accessLog.GetMethod(),
"response_code": accessLog.GetResponseCode(),
"duration_ms": accessLog.GetDurationMs(),
"source_ip": accessLog.GetSourceIp(),
"auth_mechanism": accessLog.GetAuthMechanism(),
"user_id": accessLog.GetUserId(),
"auth_success": accessLog.GetAuthSuccess(),
}).Debug("Access log from proxy")
logEntry := &accesslogs.AccessLogEntry{}
logEntry.FromProto(accessLog)
@@ -292,18 +294,18 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
return &proto.SendAccessLogResponse{}, nil
}
// SendServiceUpdate broadcasts a service update to all connected proxy servers.
// Management should call this when services are created/updated/removed
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
// Send it to all connected proxy servers
log.Debugf("Broadcasting service update to all connected proxy servers")
// SendReverseProxyUpdate broadcasts a reverse proxy update to all connected proxies.
// Management should call this when reverse proxies are created/updated/removed
func (s *ProxyServiceServer) SendReverseProxyUpdate(update *proto.ProxyMapping) {
// Send it to all connected proxies
log.Debugf("Broadcasting reverse proxy update to all connected proxies")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
select {
case conn.sendChan <- update:
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
log.Debugf("Sent reverse proxy update with id %s to proxy %s", update.Id, conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
log.Warnf("Failed to send reverse proxy update to proxy %s (channel full)", conn.proxyID)
}
return true
})
@@ -366,11 +368,11 @@ func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility).
func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
if clusterAddr == "" {
s.SendServiceUpdate(update)
s.SendReverseProxyUpdate(update)
return
}
@@ -380,16 +382,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi
return
}
log.Debugf("Sending service update to cluster %s", clusterAddr)
log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
select {
case conn.sendChan <- update:
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
@@ -424,10 +426,10 @@ func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId())
if err != nil {
// TODO: log the error
return nil, status.Errorf(codes.FailedPrecondition, "failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
}
var authenticated bool
@@ -435,51 +437,33 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
var method proxyauth.Method
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
auth := service.Auth.PinAuth
auth := proxy.Auth.PinAuth
if auth == nil || !auth.Enabled {
// TODO: log
// Break here and use the default authenticated == false.
break
}
err = argon2id.Verify(v.Pin.GetPin(), auth.Pin)
if err != nil {
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
log.WithContext(ctx).Tracef("PIN authentication failed: invalid PIN")
} else {
log.WithContext(ctx).Errorf("PIN authentication error: %v", err)
}
break
}
authenticated = true
authenticated = subtle.ConstantTimeCompare([]byte(auth.Pin), []byte(v.Pin.GetPin())) == 1
userId = "pin-user"
method = proxyauth.MethodPIN
case *proto.AuthenticateRequest_Password:
auth := service.Auth.PasswordAuth
auth := proxy.Auth.PasswordAuth
if auth == nil || !auth.Enabled {
// TODO: log
// Break here and use the default authenticated == false.
break
}
err = argon2id.Verify(v.Password.GetPassword(), auth.Password)
if err != nil {
if errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
log.WithContext(ctx).Tracef("Password authentication failed: invalid password")
} else {
log.WithContext(ctx).Errorf("Password authentication error: %v", err)
}
break
}
authenticated = true
authenticated = subtle.ConstantTimeCompare([]byte(auth.Password), []byte(v.Password.GetPassword())) == 1
userId = "password-user"
method = proxyauth.MethodPassword
}
var token string
if authenticated && service.SessionPrivateKey != "" {
if authenticated && proxy.SessionPrivateKey != "" {
token, err = sessionkey.SignToken(
service.SessionPrivateKey,
proxy.SessionPrivateKey,
userId,
service.Domain,
proxy.Domain,
method,
proxyauth.DefaultSessionExpiry,
)
@@ -498,45 +482,45 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
// SendStatusUpdate handles status updates from proxy clients
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
reverseProxyID := req.GetReverseProxyId()
protoStatus := req.GetStatus()
certificateIssued := req.GetCertificateIssued()
log.WithFields(log.Fields{
"service_id": serviceID,
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"status": protoStatus,
"certificate_issued": certificateIssued,
"error_message": req.GetErrorMessage(),
}).Debug("Status update from proxy server")
}).Debug("Status update from proxy")
if serviceID == "" || accountID == "" {
return nil, status.Errorf(codes.InvalidArgument, "service_id and account_id are required")
if reverseProxyID == "" || accountID == "" {
return nil, status.Errorf(codes.InvalidArgument, "reverse_proxy_id and account_id are required")
}
if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, reverseProxyID); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "failed to update certificate timestamp: %v", err)
}
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
}).Info("Certificate issued timestamp updated")
}
internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set service status")
return nil, status.Errorf(codes.Internal, "failed to update service status: %v", err)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, reverseProxyID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("Failed to set proxy status")
return nil, status.Errorf(codes.Internal, "failed to update proxy status: %v", err)
}
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
"status": internalStatus,
}).Info("Service status updated")
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
"status": internalStatus,
}).Info("Proxy status updated")
return &proto.SendStatusUpdateResponse{}, nil
}
@@ -563,30 +547,28 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStat
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
serviceID := req.GetServiceId()
reverseProxyID := req.GetReverseProxyId()
accountID := req.GetAccountId()
token := req.GetToken()
cluster := req.GetCluster()
key := req.WireguardPublicKey
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
"cluster": cluster,
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
}).Debug("CreateProxyPeer request received")
if serviceID == "" || accountID == "" || token == "" {
if reverseProxyID == "" || accountID == "" || token == "" {
log.Warn("CreateProxyPeer: missing required fields")
return &proto.CreateProxyPeerResponse{
Success: false,
ErrorMessage: strPtr("missing required fields: service_id, account_id, and token are required"),
ErrorMessage: strPtr("missing required fields: reverse_proxy_id, account_id, and token are required"),
}, nil
}
if err := s.tokenStore.ValidateAndConsume(token, accountID, serviceID); err != nil {
if err := s.tokenStore.ValidateAndConsume(token, accountID, reverseProxyID); err != nil {
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
}).WithError(err).Warn("CreateProxyPeer: token validation failed")
return &proto.CreateProxyPeerResponse{
Success: false,
@@ -594,11 +576,11 @@ func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.Cre
}, status.Errorf(codes.Unauthenticated, "token validation failed: %v", err)
}
err := s.peersManager.CreateProxyPeer(ctx, accountID, key, cluster)
err := s.peersManager.CreateProxyPeer(ctx, accountID, key)
if err != nil {
log.WithFields(log.Fields{
"service_id": serviceID,
"account_id": accountID,
"reverse_proxy_id": reverseProxyID,
"account_id": accountID,
}).WithError(err).Error("CreateProxyPeer: failed to create proxy peer")
return &proto.CreateProxyPeerResponse{
Success: false,
@@ -622,22 +604,22 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
// TODO: log
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
proxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, req.GetAccountId())
if err != nil {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "failed to get services from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
}
var found bool
for _, service := range services {
if service.Domain == redirectURL.Hostname() {
for _, proxy := range proxies {
if proxy.Domain == redirectURL.Hostname() {
found = true
break
}
}
if !found {
// TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "service not found in store")
return nil, status.Errorf(codes.FailedPrecondition, "reverse proxy not found in store")
}
provider, err := oidc.NewProvider(ctx, s.oidcConfig.Issuer)
@@ -728,229 +710,32 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
// Find the proxy by domain to get its signing key
proxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
return "", fmt.Errorf("get reverse proxies: %w", err)
}
var service *reverseproxy.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
var proxy *reverseproxy.ReverseProxy
for _, p := range proxies {
if p.Domain == domain {
proxy = p
break
}
}
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
if proxy == nil {
return "", fmt.Errorf("reverse proxy not found for domain: %s", domain)
}
if service.SessionPrivateKey == "" {
if proxy.SessionPrivateKey == "" {
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
return sessionkey.SignToken(
service.SessionPrivateKey,
proxy.SessionPrivateKey,
userID,
domain,
method,
proxyauth.DefaultSessionExpiry,
)
}
// ValidateUserGroupAccess checks if a user has access to a service.
// It looks up the service within the user's account only, then optionally checks
// group membership if BearerAuth with DistributionGroups is configured.
func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain, userID string) error {
user, err := s.usersManager.GetUser(ctx, userID)
if err != nil {
return fmt.Errorf("user not found: %s", userID)
}
service, err := s.getAccountServiceByDomain(ctx, user.AccountID, domain)
if err != nil {
return err
}
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := service.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}
allowedSet := make(map[string]bool, len(allowedGroups))
for _, groupID := range allowedGroups {
allowedSet[groupID] = true
}
for _, groupID := range user.AutoGroups {
if allowedSet[groupID] {
log.WithFields(log.Fields{
"user_id": user.Id,
"group_id": groupID,
"domain": domain,
}).Debug("User granted access via group membership")
return nil
}
}
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
}
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain %s in account %s", domain, accountID)
}
// ValidateSession validates a session token and checks if the user has access to the domain.
func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.ValidateSessionRequest) (*proto.ValidateSessionResponse, error) {
domain := req.GetDomain()
sessionToken := req.GetSessionToken()
if domain == "" || sessionToken == "" {
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "missing domain or session_token",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: service not found")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Error("ValidateSession: decode public key")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_service_config",
}, nil
}
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"error": err.Error(),
}).Debug("ValidateSession: invalid session token")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "invalid_token",
}, nil
}
user, err := s.usersManager.GetUser(ctx, userID)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: user not found")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "user_not_found",
}, nil
}
if user.AccountID != service.AccountID {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"user_account": user.AccountID,
"service_account": service.AccountID,
}).Debug("ValidateSession: user account mismatch")
return &proto.ValidateSessionResponse{
Valid: false,
DeniedReason: "account_mismatch",
}, nil
}
if err := s.checkGroupAccess(service, user); err != nil {
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: access denied")
return &proto.ValidateSessionResponse{
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
}, nil
}
log.WithFields(log.Fields{
"domain": domain,
"user_id": userID,
"email": user.Email,
}).Debug("ValidateSession: access granted")
return &proto.ValidateSessionResponse{
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
}, nil
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}
allowedGroups := service.Auth.BearerAuth.DistributionGroups
if len(allowedGroups) == 0 {
return nil
}
allowedSet := make(map[string]bool, len(allowedGroups))
for _, groupID := range allowedGroups {
allowedSet[groupID] = true
}
for _, groupID := range user.AutoGroups {
if allowedSet[groupID] {
return nil
}
}
return fmt.Errorf("user not in allowed groups")
}

View File

@@ -1,234 +0,0 @@
package grpc
import (
"context"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
lastUsedUpdateInterval = time.Minute
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
lastUsedCleanupInterval = 2 * time.Minute
)
type proxyTokenContextKey struct{}
// ProxyTokenContextKey is the typed key used to store validated token info in context.
var ProxyTokenContextKey = proxyTokenContextKey{}
// proxyTokenID identifies a proxy access token by its database ID.
type proxyTokenID = string
// proxyTokenStore defines the store interface needed for token validation
type proxyTokenStore interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
}
// proxyAuthInterceptor holds state for proxy authentication interceptors.
type proxyAuthInterceptor struct {
store proxyTokenStore
failureLimiter *authFailureLimiter
// lastUsedMu protects lastUsedTimes
lastUsedMu sync.Mutex
lastUsedTimes map[proxyTokenID]time.Time
cancel context.CancelFunc
}
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
ctx, cancel := context.WithCancel(context.Background())
i := &proxyAuthInterceptor{
store: tokenStore,
failureLimiter: newAuthFailureLimiter(),
lastUsedTimes: make(map[proxyTokenID]time.Time),
cancel: cancel,
}
go i.lastUsedCleanupLoop(ctx)
return i
}
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
// The returned close function must be called on shutdown to stop background goroutines.
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
interceptor := newProxyAuthInterceptor(tokenStore)
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(ctx, req)
}
token, err := interceptor.validateProxyToken(ctx)
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
return nil, err
}
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
return handler(ctx, req)
}
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(srv, ss)
}
token, err := interceptor.validateProxyToken(ss.Context())
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
return err
}
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
wrapped := &wrappedServerStream{
ServerStream: ss,
ctx: ctx,
}
return handler(srv, wrapped)
}
return unary, stream, interceptor.close
}
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
}
token, err := i.doValidateProxyToken(ctx)
if err != nil {
if clientIP != "" {
i.failureLimiter.recordFailure(clientIP)
}
return nil, err
}
i.maybeUpdateLastUsed(ctx, token.ID)
return token, nil
}
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
}
authValues := md.Get("authorization")
if len(authValues) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
}
authValue := authValues[0]
if !strings.HasPrefix(authValue, "Bearer ") {
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
}
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
if err := plainToken.Validate(); err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
}
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}
return token, nil
}
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
now := time.Now()
i.lastUsedMu.Lock()
lastUpdate, exists := i.lastUsedTimes[tokenID]
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
i.lastUsedMu.Unlock()
return
}
i.lastUsedTimes[tokenID] = now
i.lastUsedMu.Unlock()
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
}
}
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
ticker := time.NewTicker(lastUsedCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
i.cleanupStaleLastUsed()
case <-ctx.Done():
return
}
}
}
// cleanupStaleLastUsed removes entries older than 2x the update interval.
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
i.lastUsedMu.Lock()
defer i.lastUsedMu.Unlock()
now := time.Now()
staleThreshold := 2 * lastUsedUpdateInterval
for id, lastUpdate := range i.lastUsedTimes {
if now.Sub(lastUpdate) > staleThreshold {
delete(i.lastUsedTimes, id)
}
}
}
func (i *proxyAuthInterceptor) close() {
i.cancel()
i.failureLimiter.stop()
}
// GetProxyTokenFromContext retrieves the validated proxy token from the context
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
if !ok {
return nil
}
return token
}
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

View File

@@ -1,134 +0,0 @@
package grpc
import (
"context"
"net"
"sync"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"golang.org/x/time/rate"
"google.golang.org/grpc/peer"
)
const (
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
proxyAuthFailureBurst = 5
// proxyAuthLimiterCleanup is how often stale limiters are removed.
proxyAuthLimiterCleanup = 5 * time.Minute
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
proxyAuthLimiterTTL = 15 * time.Minute
)
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
// One token every 12 seconds = 5 per minute.
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
// clientIP identifies a client by its IP address for rate limiting purposes.
type clientIP = string
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
type authFailureLimiter struct {
mu sync.Mutex
limiters map[clientIP]*limiterEntry
failureRate rate.Limit
cancel context.CancelFunc
}
func newAuthFailureLimiter() *authFailureLimiter {
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
}
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
ctx, cancel := context.WithCancel(context.Background())
l := &authFailureLimiter{
limiters: make(map[clientIP]*limiterEntry),
failureRate: failureRate,
cancel: cancel,
}
go l.cleanupLoop(ctx)
return l
}
// isLimited returns true if the given IP has exhausted its failure budget.
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
l.mu.Lock()
defer l.mu.Unlock()
entry, exists := l.limiters[ip]
if !exists {
return false
}
return entry.limiter.Tokens() < 1
}
// recordFailure consumes a token from the rate limiter for the given IP.
func (l *authFailureLimiter) recordFailure(ip clientIP) {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
entry, exists := l.limiters[ip]
if !exists {
entry = &limiterEntry{
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
}
l.limiters[ip] = entry
}
entry.lastAccess = now
entry.limiter.Allow()
}
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(proxyAuthLimiterCleanup)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.cleanup()
case <-ctx.Done():
return
}
}
}
func (l *authFailureLimiter) cleanup() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
for ip, entry := range l.limiters {
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
delete(l.limiters, ip)
}
}
}
func (l *authFailureLimiter) stop() {
l.cancel()
}
// peerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP {
if addr, ok := realip.FromContext(ctx); ok {
return addr.String()
}
if p, ok := peer.FromContext(ctx); ok {
host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return p.Addr.String()
}
return host
}
return ""
}

View File

@@ -1,98 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -1,381 +0,0 @@
package grpc
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
err error
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type mockUsersManager struct {
users map[string]*types.User
err error
}
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[userID]
if !ok {
return nil, errors.New("user not found")
}
return user, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
users map[string]*types.User
proxyErr error
userErr error
expectErr bool
expectErrMsg string
}{
{
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
expectErr: true,
expectErrMsg: "user not found",
},
{
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "reverse proxy not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "reverse proxy not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
},
expectErr: true,
expectErrMsg: "not in allowed groups",
},
{
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
},
expectErr: false,
},
{
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
},
expectErr: false,
},
{
name: "proxy manager error",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: nil,
proxyErr: errors.New("database error"),
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account reverse proxies",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
usersManager: &mockUsersManager{
users: tt.users,
err: tt.userErr,
},
}
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErrMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
err error
expectProxy bool
expectErr bool
}{
{
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
},
},
expectProxy: true,
expectErr: false,
},
{
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
expectErr: true,
},
{
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
expectProxy: false,
expectErr: true,
},
{
name: "manager error",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: nil,
err: errors.New("database error"),
expectProxy: false,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},
}
proxy, err := server.getAccountProxyByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)
assert.Nil(t, proxy)
} else {
require.NoError(t, err)
require.NotNil(t, proxy)
assert.Equal(t, tt.domain, proxy.Domain)
}
})
}
}

View File

@@ -77,9 +77,8 @@ type Server struct {
oAuthConfigProvider idp.OAuthConfigProvider
syncSem atomic.Int32
syncLimEnabled bool
syncLim int32
syncSem atomic.Int32
syncLim int32
}
// NewServer creates a new Management server
@@ -109,7 +108,6 @@ func NewServer(
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
syncLim := int32(defaultSyncLim)
syncLimEnabled := true
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
@@ -117,9 +115,6 @@ func NewServer(
} else {
//nolint:gosec
syncLim = int32(syncLimParsed)
if syncLim < 0 {
syncLimEnabled = false
}
}
}
@@ -139,8 +134,7 @@ func NewServer(
loginFilter: newLoginFilter(),
syncLim: syncLim,
syncLimEnabled: syncLimEnabled,
syncLim: syncLim,
}, nil
}
@@ -218,7 +212,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
@@ -300,7 +294,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
@@ -311,7 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}
@@ -319,7 +313,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}
@@ -336,7 +330,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
@@ -404,20 +398,11 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
// It implements a backpressure mechanism that sends the first update immediately,
// then debounces subsequent rapid updates, ensuring only the latest update is sent
// after a quiet period.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
// Create a debouncer for this peer connection
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
defer debouncer.Stop()
for {
select {
// condition when there are some updates
// todo set the updates channel size to 1
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
@@ -425,38 +410,20 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
}
// Timer expired - quiet period reached, send pending updates if any
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
continue
}
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
for _, pendingUpdate := range pendingUpdates {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return srv.Context().Err()
}
}
@@ -464,16 +431,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.Send(&proto.EncryptedMessage{
@@ -481,7 +448,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
@@ -513,15 +480,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
return nil
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}

View File

@@ -242,10 +242,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
@@ -269,10 +266,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {

View File

@@ -1,103 +0,0 @@
package grpc
import (
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
)
// UpdateDebouncer implements a backpressure mechanism that:
// - Sends the first update immediately
// - Coalesces rapid subsequent network map updates (only latest matters)
// - Queues control/config updates (all must be delivered)
// - Preserves the order of messages (important for control configs between network maps)
// - Ensures pending updates are sent after a quiet period
type UpdateDebouncer struct {
debounceInterval time.Duration
timer *time.Timer
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
timerC <-chan time.Time
}
// NewUpdateDebouncer creates a new debouncer with the specified interval
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
return &UpdateDebouncer{
debounceInterval: interval,
}
}
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
if d.timer == nil {
// No active debounce timer, signal to send immediately
// and start the debounce period
d.startTimer()
return true
}
// Already in debounce period, accumulate this update preserving order
// Check if we should coalesce with the last pending update
if len(d.pendingUpdates) > 0 &&
update.MessageType == network_map.MessageTypeNetworkMap &&
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
// Replace the last network map with this one (coalesce consecutive network maps)
d.pendingUpdates[len(d.pendingUpdates)-1] = update
} else {
// Append to the queue (preserves order for control configs and non-consecutive network maps)
d.pendingUpdates = append(d.pendingUpdates, update)
}
d.resetTimer()
return false
}
// TimerChannel returns the timer channel for select statements
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
if d.timer == nil {
return nil
}
return d.timerC
}
// GetPendingUpdates returns and clears all pending updates after timer expiration.
// Updates are returned in the order they were received, with consecutive network maps
// already coalesced to only the latest one.
// If there were pending updates, it restarts the timer to continue debouncing.
// If there were no pending updates, it clears the timer (true quiet period).
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
updates := d.pendingUpdates
d.pendingUpdates = nil
if len(updates) > 0 {
// There were pending updates, so updates are still coming rapidly
// Restart the timer to continue debouncing mode
if d.timer != nil {
d.timer.Reset(d.debounceInterval)
}
} else {
// No pending updates means true quiet period - return to immediate mode
d.timer = nil
d.timerC = nil
}
return updates
}
// Stop stops the debouncer and cleans up resources
func (d *UpdateDebouncer) Stop() {
if d.timer != nil {
d.timer.Stop()
d.timer = nil
d.timerC = nil
}
d.pendingUpdates = nil
}
func (d *UpdateDebouncer) startTimer() {
d.timer = time.NewTimer(d.debounceInterval)
d.timerC = d.timer.C
}
func (d *UpdateDebouncer) resetTimer() {
d.timer.Stop()
d.timer.Reset(d.debounceInterval)
}

View File

@@ -1,587 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
shouldSend := debouncer.ProcessUpdate(update)
if !shouldSend {
t.Error("First update should be sent immediately")
}
if debouncer.TimerChannel() == nil {
t.Error("Timer should be started after first update")
}
}
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update should be sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Rapid subsequent updates should be coalesced
if debouncer.ProcessUpdate(update2) {
t.Error("Second rapid update should not be sent immediately")
}
if debouncer.ProcessUpdate(update3) {
t.Error("Third rapid update should not be sent immediately")
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Send second update within debounce period
debouncer.ProcessUpdate(update2)
// Wait for timer
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update2 {
t.Error("Should get the last update")
}
if pendingUpdates[0] == update1 {
t.Error("Should not get the first update")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Wait a bit, but not the full debounce period
time.Sleep(30 * time.Millisecond)
// Send second update - should reset timer
debouncer.ProcessUpdate(update2)
// Wait a bit more
time.Sleep(30 * time.Millisecond)
// Send third update - should reset timer again
debouncer.ProcessUpdate(update3)
// Now wait for the timer (should fire after last update's reset)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
// Timer should be restarted since there was a pending update
if debouncer.TimerChannel() == nil {
t.Error("Timer should be restarted after sending pending update")
}
case <-time.After(150 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
debouncer.ProcessUpdate(update1)
// Second update coalesced
debouncer.ProcessUpdate(update2)
// Wait for timer to expire
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have pending update")
}
// After sending pending update, timer is restarted, so next update is NOT immediate
if debouncer.ProcessUpdate(update3) {
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
}
// Wait for the restarted timer and verify update3 is pending
select {
case <-debouncer.TimerChannel():
finalUpdates := debouncer.GetPendingUpdates()
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
t.Error("Should get update3 as pending")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired for restarted timer")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send update to start timer
debouncer.ProcessUpdate(update)
// Stop should clean up
debouncer.Stop()
// Multiple stops should be safe
debouncer.Stop()
}
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate high-frequency updates
var lastUpdate *network_map.UpdateMessage
sentImmediately := 0
for i := 0; i < 100; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
lastUpdate = update
if debouncer.ProcessUpdate(update) {
sentImmediately++
}
time.Sleep(1 * time.Millisecond) // Very rapid updates
}
// Only first update should be sent immediately
if sentImmediately != 1 {
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != lastUpdate {
t.Error("Should get the very last update")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
// Wait for timer to expire with no additional updates (true quiet period)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates")
}
// After true quiet period, timer should be cleared
if debouncer.TimerChannel() != nil {
t.Error("Timer should be cleared after quiet period")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
updates := make([]*network_map.UpdateMessage, 5)
for i := range updates {
updates[i] = &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
}
// First update sent immediately
debouncer.ProcessUpdate(updates[0])
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
debouncer.ProcessUpdate(updates[1])
debouncer.ProcessUpdate(updates[2])
debouncer.ProcessUpdate(updates[3])
debouncer.ProcessUpdate(updates[4])
// Wait for debounce
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
}
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Wait for timer without sending any more updates (true quiet period)
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates during quiet period")
}
// After true quiet period, next update should be sent immediately
if !debouncer.ProcessUpdate(update2) {
t.Error("Update after true quiet period should be sent immediately")
}
}
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate continuous high-frequency updates
for i := 0; i < 10; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
if i == 0 {
// First one sent immediately
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
} else {
// All others should be coalesced (not sent immediately)
if debouncer.ProcessUpdate(update) {
t.Errorf("Update %d should not be sent immediately", i)
}
}
// Wait a bit but send next update before debounce expires
time.Sleep(20 * time.Millisecond)
}
// Now wait for final debounce
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have the last update pending")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
tokenUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate)
// Send multiple control config updates - they should all be queued
debouncer.ProcessUpdate(tokenUpdate1)
debouncer.ProcessUpdate(tokenUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get both control config updates
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
}
// Control configs should come first
if pendingUpdates[0] != tokenUpdate1 {
t.Error("First pending update should be tokenUpdate1")
}
if pendingUpdates[1] != tokenUpdate2 {
t.Error("Second pending update should be tokenUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
netmapUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate1)
// Send token update and network map update
debouncer.ProcessUpdate(tokenUpdate)
debouncer.ProcessUpdate(netmapUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get 2 updates in order: token, then network map
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
}
// Token update should come first (preserves order)
if pendingUpdates[0] != tokenUpdate {
t.Error("First pending update should be tokenUpdate")
}
// Network map update should come second
if pendingUpdates[1] != netmapUpdate2 {
t.Error("Second pending update should be netmapUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate: 50 network maps -> 1 control config -> 50 network maps
// Expected result: 3 messages (netmap, controlConfig, netmap)
// Send first network map immediately
firstNetmap := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
MessageType: network_map.MessageTypeNetworkMap,
}
if !debouncer.ProcessUpdate(firstNetmap) {
t.Error("First update should be sent immediately")
}
// Send 49 more network maps (will be coalesced to last one)
var lastNetmapBatch1 *network_map.UpdateMessage
for i := 1; i < 50; i++ {
lastNetmapBatch1 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch1)
}
// Send 1 control config
controlConfig := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
debouncer.ProcessUpdate(controlConfig)
// Send 50 more network maps (will be coalesced to last one)
var lastNetmapBatch2 *network_map.UpdateMessage
for i := 50; i < 100; i++ {
lastNetmapBatch2 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch2)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get exactly 3 updates: netmap, controlConfig, netmap
if len(pendingUpdates) != 3 {
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
}
// First should be the last netmap from batch 1
if pendingUpdates[0] != lastNetmapBatch1 {
t.Error("First pending update should be last netmap from batch 1")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
// Second should be the control config
if pendingUpdates[1] != controlConfig {
t.Error("Second pending update should be control config")
}
// Third should be the last netmap from batch 2
if pendingUpdates[2] != lastNetmapBatch2 {
t.Error("Third pending update should be last netmap from batch 2")
}
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}

View File

@@ -1,304 +0,0 @@
//go:build integration
package grpc
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
type validateSessionTestSetup struct {
proxyService *ProxyServiceServer
store store.Store
cleanup func()
}
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
createTestProxies(t, ctx, testStore)
return &validateSessionTestSetup{
proxyService: proxyService,
store: testStore,
cleanup: storeCleanup,
}
}
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
func generateSessionKeyPair(t *testing.T) (string, string) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
}
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
func TestValidateSession_UserAllowed(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "restricted-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User in different account should be denied")
assert.Equal(t, "account_mismatch", resp.DeniedReason)
}
func TestValidateSession_UserNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Non-existent user should be denied")
assert.Equal(t, "user_not_found", resp.DeniedReason)
}
func TestValidateSession_ProxyNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "unknown-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: "invalid-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Invalid token should be denied")
assert.Equal(t, "invalid_token", resp.DeniedReason)
}
func TestValidateSession_MissingDomain(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
SessionToken: "some-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
func TestValidateSession_MissingToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}

View File

@@ -15,7 +15,6 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +82,8 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
proxyController port_forwarding.Controller
settingsManager settings.Manager
// config contains the management server configuration
config *nbconfig.Config
@@ -115,10 +113,6 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
@@ -327,9 +321,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
updateAccountPeers = true
}
@@ -804,19 +795,6 @@ func IsEmbeddedIdp(i idp.Manager) bool {
return ok
}
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
// Returns true only when using embedded IDP with local auth disabled in config.
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
if isNil(i) {
return false
}
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
if !ok {
return false
}
return embeddedIdp.IsLocalAuthDisabled()
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
@@ -1679,13 +1657,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1693,20 +1671,8 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}

View File

@@ -6,7 +6,6 @@ import (
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -59,7 +58,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
@@ -115,8 +114,8 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@@ -140,5 +139,4 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager)
}

View File

@@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1961,82 +1961,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := time.Now().UTC()
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
})
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
})
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -2059,7 +1983,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3252,7 +3176,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}

View File

@@ -204,9 +204,9 @@ const (
UserInviteLinkRegenerated Activity = 106
UserInviteLinkDeleted Activity = 107
ServiceCreated Activity = 108
ServiceUpdated Activity = 109
ServiceDeleted Activity = 110
ReverseProxyCreated Activity = 108
ReverseProxyUpdated Activity = 109
ReverseProxyDeleted Activity = 110
AccountDeleted Activity = 99999
)
@@ -342,9 +342,9 @@ var activityMap = map[Activity]Code{
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"},
ReverseProxyCreated: {"Reverse proxy created", "reverseproxy.create"},
ReverseProxyUpdated: {"Reverse proxy updated", "reverseproxy.update"},
ReverseProxyDeleted: {"Reverse proxy deleted", "reverseproxy.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -9,7 +9,6 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/types"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
@@ -84,7 +83,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
// OAuth callback for proxy authentication
if err := bypass.AddBypassPath(types.ProxyCallbackEndpointFull); err != nil {
if err := bypass.AddBypassPath("/api/oauth/callback"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
@@ -141,15 +140,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
// Check if embedded IdP is enabled for instance manager
// Check if embedded IdP is enabled
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)

View File

@@ -36,22 +36,24 @@ const (
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager account.Manager
settingsManager settings.Manager
accountManager account.Manager
settingsManager settings.Manager
embeddedIdpEnabled bool
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager)
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
return &handler{
accountManager: accountManager,
settingsManager: settingsManager,
accountManager: accountManager,
settingsManager: settingsManager,
embeddedIdpEnabled: embeddedIdpEnabled,
}
}
@@ -163,7 +165,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -290,7 +292,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
util.WriteJSONObject(r.Context(), w, &resp)
}
@@ -319,7 +321,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
@@ -339,8 +341,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
LocalAuthDisabled: &settings.LocalAuthDisabled,
EmbeddedIdpEnabled: &embeddedIdpEnabled,
}
if settings.NetworkRange.IsValid() {

View File

@@ -33,6 +33,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
AnyTimes()
return &handler{
embeddedIdpEnabled: false,
accountManager: &mock_server.MockAccountManager{
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
@@ -123,7 +124,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: true,
expectedID: accountID,
@@ -148,7 +148,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -173,7 +172,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr("latest"),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -198,7 +196,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -223,7 +220,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -248,7 +244,6 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,

View File

@@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
return
}
log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired)
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
SetupRequired: setupRequired,
})

View File

@@ -17,7 +17,6 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -27,12 +26,11 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
accountManager account.Manager
permissionsManager permissions.Manager
networkMapController network_map.Controller
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
peersHandler := NewHandler(accountManager, networkMapController)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -44,11 +42,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
return &Handler{
accountManager: accountManager,
networkMapController: networkMapController,
permissionsManager: permissionsManager,
}
}
@@ -362,19 +359,13 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userID)
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
user, err := h.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -404,7 +395,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers(), account.GetExposedServicesMap(), account.GetProxyPeers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -13,15 +13,13 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
ugomock "go.uber.org/mock/gomock"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -104,7 +102,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
}
ctrl := ugomock.NewController(t)
ctrl := gomock.NewController(t)
networkMapController := network_map.NewMockController(ctrl)
networkMapController.EXPECT().
@@ -112,10 +110,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
Return("domain").
AnyTimes()
ctrl2 := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl2)
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -205,7 +199,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
},
networkMapController: networkMapController,
permissionsManager: permissionsManager,
}
}

View File

@@ -2,11 +2,8 @@ package proxy
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
@@ -14,45 +11,24 @@ import (
"golang.org/x/oauth2"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
)
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
type AuthCallbackHandler struct {
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
}
// NewAuthCallbackHandler creates a new OAuth callback handler.
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
rateLimiterConfig := &middleware.RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 15,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
return &AuthCallbackHandler{
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
}
}
// RegisterEndpoints registers the OAuth callback endpoint.
func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet)
router.HandleFunc("/oauth/callback", h.handleCallback).Methods(http.MethodGet)
}
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if !h.rateLimiter.Allow(clientIP) {
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
return
}
state := r.URL.Query().Get("state")
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
@@ -69,8 +45,10 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
return
}
// Get OIDC configuration
oidcConfig := h.proxyService.GetOIDCConfig()
// Create OIDC provider to discover endpoints
provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer)
if err != nil {
log.WithError(err).Error("Failed to create OIDC provider")
@@ -89,6 +67,7 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
return
}
// Extract user ID from the OIDC token
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
if userID == "" {
log.Error("Failed to extract user ID from OIDC token")
@@ -96,23 +75,21 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
return
}
// Group validation is performed by the proxy via ValidateSession gRPC call.
// This allows the proxy to show 403 pages directly without redirect dance.
// Generate session JWT instead of passing OIDC access_token
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
if err != nil {
log.WithError(err).Error("Failed to create session token")
redirectURL.Scheme = "https"
query := redirectURL.Query()
query.Set("error", "access_denied")
query.Set("error_description", "Service configuration error")
redirectURL.RawQuery = query.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
redirectURL.Scheme = "https"
// Pass the session token in the URL query parameter. The proxy middleware will
// extract it, validate it, set its own cookie, and redirect to remove the token from the URL.
// We cannot set the cookie here because cookies are domain-scoped (RFC 6265) and the
// management server cannot set cookies for the proxy's domain.
query := redirectURL.Query()
query.Set("session_token", sessionToken)
redirectURL.RawQuery = query.Encode()
@@ -121,7 +98,9 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// extractUserIDFromToken extracts the user ID from an OIDC token.
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
// Try to get ID token from the oauth2 token extras
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
log.Warn("No id_token in OIDC response")
@@ -138,34 +117,27 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
return ""
}
// Extract claims
var claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
UserID string `json:"user_id"`
}
if err := idToken.Claims(&claims); err != nil {
log.WithError(err).Warn("Failed to extract claims from ID token")
return ""
}
return claims.Subject
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return xff
}
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// Fall back to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
// Prefer subject, fall back to user_id or email
if claims.Subject != "" {
return claims.Subject
}
if claims.UserID != "" {
return claims.UserID
}
if claims.Email != "" {
return claims.Email
}
return ""
}

View File

@@ -1,523 +0,0 @@
//go:build integration
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
// fakeOIDCServer creates a minimal OIDC provider for testing.
type fakeOIDCServer struct {
server *httptest.Server
issuer string
signingKey ed25519.PrivateKey
publicKey ed25519.PublicKey
keyID string
tokenSubject string
tokenExpiry time.Duration
failExchange bool
}
func newFakeOIDCServer() *fakeOIDCServer {
pub, priv, _ := ed25519.GenerateKey(rand.Reader)
f := &fakeOIDCServer{
signingKey: priv,
publicKey: pub,
keyID: "test-key-1",
tokenExpiry: time.Hour,
}
f.server = httptest.NewServer(f)
f.issuer = f.server.URL
return f
}
func (f *fakeOIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
f.handleDiscovery(w, r)
case "/token":
f.handleToken(w, r)
case "/keys":
f.handleJWKS(w, r)
default:
http.NotFound(w, r)
}
}
func (f *fakeOIDCServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) {
discovery := map[string]interface{}{
"issuer": f.issuer,
"authorization_endpoint": f.issuer + "/auth",
"token_endpoint": f.issuer + "/token",
"jwks_uri": f.issuer + "/keys",
"response_types_supported": []string{
"code",
"id_token",
"token id_token",
},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{"EdDSA"},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(discovery)
}
func (f *fakeOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) {
if f.failExchange {
http.Error(w, "invalid_grant", http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
idToken := f.createIDToken()
response := map[string]interface{}{
"access_token": "test-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"id_token": idToken,
"refresh_token": "test-refresh-token",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
func (f *fakeOIDCServer) createIDToken() string {
now := time.Now()
claims := jwt.MapClaims{
"iss": f.issuer,
"sub": f.tokenSubject,
"aud": "test-client-id",
"exp": now.Add(f.tokenExpiry).Unix(),
"iat": now.Unix(),
"nbf": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
token.Header["kid"] = f.keyID
signed, _ := token.SignedString(f.signingKey)
return signed
}
func (f *fakeOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) {
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "OKP",
"crv": "Ed25519",
"kid": f.keyID,
"x": base64.RawURLEncoding.EncodeToString(f.publicKey),
"use": "sig",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(jwks)
}
func (f *fakeOIDCServer) Close() {
f.server.Close()
}
// testSetup contains all test dependencies.
type testSetup struct {
store store.Store
oidcServer *fakeOIDCServer
proxyService *nbgrpc.ProxyServiceServer
handler *AuthCallbackHandler
router *mux.Router
cleanup func()
}
// testAccessLogManager is a minimal mock for accesslogs.Manager.
type testAccessLogManager struct{}
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
return nil
}
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string) ([]*accesslogs.AccessLogEntry, error) {
return nil, nil
}
func setupAuthCallbackTest(t *testing.T) *testSetup {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
createTestAccountsAndUsers(t, ctx, testStore)
createTestReverseProxies(t, ctx, testStore)
oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
usersManager := users.NewManager(testStore)
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: oidcServer.issuer,
ClientID: "test-client-id",
Scopes: []string{"openid", "profile", "email"},
CallbackURL: "https://management.example.com/reverse-proxy/callback",
HMACKey: []byte("test-hmac-key-for-state-signing"),
}
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
oidcConfig,
nil,
usersManager,
)
proxyService.SetProxyManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService)
router := mux.NewRouter()
handler.RegisterEndpoints(router)
return &testSetup{
store: testStore,
oidcServer: oidcServer,
proxyService: proxyService,
handler: handler,
router: router,
cleanup: func() {
cleanup()
oidcServer.Close()
},
}
}
func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"restrictedGroupId"},
},
},
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.Service{
ID: "noAuthProxyId",
AccountID: "testAccountId",
Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: false,
},
},
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
}
require.NoError(t, testStore.CreateService(ctx, noAuthProxy))
}
func strPtr(s string) *string {
return &s
}
func createTestAccountsAndUsers(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
testAccount := &types.Account{
Id: "testAccountId",
Domain: "test.com",
DomainCategory: "private",
IsDomainPrimaryAccount: true,
CreatedAt: time.Now(),
}
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
allowedGroup := &types.Group{
ID: "allowedGroupId",
AccountID: "testAccountId",
Name: "Allowed Group",
Issued: "api",
}
require.NoError(t, testStore.CreateGroup(ctx, allowedGroup))
allowedUser := &types.User{
Id: "allowedUserId",
AccountID: "testAccountId",
Role: types.UserRoleUser,
AutoGroups: []string{"allowedGroupId"},
CreatedAt: time.Now(),
Issued: "api",
}
require.NoError(t, testStore.SaveUser(ctx, allowedUser))
}
// testServiceManager is a minimal implementation for testing.
type testServiceManager struct {
store store.Store
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
}
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()
resp, err := ps.GetOIDCURL(context.Background(), &proto.GetOIDCURLRequest{
RedirectUrl: redirectURL,
AccountId: "testAccountId",
})
require.NoError(t, err)
parsedURL, err := url.Parse(resp.Url)
require.NoError(t, err)
return parsedURL.Query().Get("state")
}
func TestAuthCallback_UserAllowedToLogin(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/dashboard")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
require.NotEmpty(t, location)
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "test-proxy.example.com", parsedLocation.Host)
require.NotEmpty(t, parsedLocation.Query().Get("session_token"), "Should include session token")
require.Empty(t, parsedLocation.Query().Get("error"), "Should not have error parameter")
}
func TestAuthCallback_ProxyNotFound(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
require.NoError(t, setup.store.DeleteService(context.Background(), "testAccountId", "testProxyId"))
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusFound, rec.Code)
location := rec.Header().Get("Location")
parsedLocation, err := url.Parse(location)
require.NoError(t, err)
require.Equal(t, "access_denied", parsedLocation.Query().Get("error"))
}
func TestAuthCallback_InvalidToken(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.failExchange = true
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=invalid-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
require.Contains(t, rec.Body.String(), "Failed to exchange code")
}
func TestAuthCallback_ExpiredToken(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
setup.oidcServer.tokenSubject = "allowedUserId"
setup.oidcServer.tokenExpiry = -time.Hour
state := createTestState(t, setup.proxyService, "https://test-proxy.example.com/")
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state="+url.QueryEscape(state), nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
require.Contains(t, rec.Body.String(), "Failed to validate token")
}
func TestAuthCallback_InvalidState(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code&state=invalid-state", nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid state")
}
func TestAuthCallback_MissingState(t *testing.T) {
setup := setupAuthCallbackTest(t)
defer setup.cleanup()
req := httptest.NewRequest(http.MethodGet, "/reverse-proxy/callback?code=test-auth-code", nil)
rec := httptest.NewRecorder()
setup.router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -1,156 +0,0 @@
package proxy
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
)
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
req.RemoteAddr = "192.168.1.100:12345"
t.Run("allows requests under limit", func(t *testing.T) {
for i := 0; i < 15; i++ {
allowed := handler.rateLimiter.Allow("192.168.1.100")
assert.True(t, allowed, "Request %d should be allowed", i+1)
}
})
t.Run("blocks requests over limit", func(t *testing.T) {
handler.rateLimiter.Reset("192.168.1.200")
for i := 0; i < 15; i++ {
handler.rateLimiter.Allow("192.168.1.200")
}
allowed := handler.rateLimiter.Allow("192.168.1.200")
assert.False(t, allowed, "Request over limit should be blocked")
})
t.Run("different IPs have separate limits", func(t *testing.T) {
ip1 := "192.168.1.201"
ip2 := "192.168.1.202"
handler.rateLimiter.Reset(ip1)
handler.rateLimiter.Reset(ip2)
for i := 0; i < 15; i++ {
handler.rateLimiter.Allow(ip1)
}
assert.False(t, handler.rateLimiter.Allow(ip1), "IP1 should be blocked")
assert.True(t, handler.rateLimiter.Allow(ip2), "IP2 should be allowed")
})
}
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
testIP := "10.0.0.50"
handler.rateLimiter.Reset(testIP)
t.Run("returns 429 when rate limited", func(t *testing.T) {
for i := 0; i < 15; i++ {
handler.rateLimiter.Allow(testIP)
}
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
req.RemoteAddr = testIP + ":12345"
rr := httptest.NewRecorder()
handler.handleCallback(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should return 429 status code")
assert.Contains(t, rr.Body.String(), "Too many requests", "Should contain rate limit message")
})
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
expectedIP string
}{
{
name: "extract from RemoteAddr",
remoteAddr: "192.168.1.100:12345",
expectedIP: "192.168.1.100",
},
{
name: "extract from X-Forwarded-For single IP",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195",
expectedIP: "203.0.113.195",
},
{
name: "extract from X-Forwarded-For multiple IPs",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195, 70.41.3.18, 150.172.238.178",
expectedIP: "203.0.113.195",
},
{
name: "extract from X-Real-IP",
remoteAddr: "10.0.0.1:54321",
xRealIP: "198.51.100.42",
expectedIP: "198.51.100.42",
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195",
xRealIP: "198.51.100.42",
expectedIP: "203.0.113.195",
},
{
name: "handle RemoteAddr without port",
remoteAddr: "192.168.1.100",
expectedIP: "192.168.1.100",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
ip := getClientIP(req)
assert.Equal(t, tt.expectedIP, ip, "Extracted IP should match expected")
})
}
}
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
testIP := "192.168.1.250"
handler.rateLimiter.Reset(testIP)
for i := 0; i < 15; i++ {
allowed := handler.rateLimiter.Allow(testIP)
assert.True(t, allowed, "Should allow request %d within burst limit", i+1)
}
allowed := handler.rateLimiter.Allow(testIP)
assert.False(t, allowed, "Should block request that exceeds burst limit")
}

View File

@@ -205,14 +205,6 @@ func TestCreateInvite(t *testing.T) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "local auth disabled",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
},
},
{
name: "invalid JSON",
requestBody: `{invalid json}`,
@@ -384,15 +376,6 @@ func TestAcceptInvite(t *testing.T) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "local auth disabled",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
},
},
{
name: "missing token",
token: "",

View File

@@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)

View File

@@ -43,11 +43,6 @@ type EmbeddedIdPConfig struct {
Owner *OwnerConfig
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
SignKeyRefreshEnabled bool
// LocalAuthDisabled disables the local (email/password) authentication connector.
// When true, users cannot authenticate via email/password, only via external identity providers.
// Existing local users are preserved and will be able to login again if re-enabled.
// Cannot be enabled if no external identity provider connectors are configured.
LocalAuthDisabled bool
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
@@ -94,8 +89,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
dashboardRedirectURIs := c.DashboardRedirectURIs
baseURL := strings.TrimSuffix(c.Issuer, "/oauth2")
// todo: resolve import cycle
dashboardRedirectURIs = append(dashboardRedirectURIs, baseURL+"/api/reverse-proxy/callback")
dashboardRedirectURIs = append(dashboardRedirectURIs, baseURL+"/api/oauth/callback")
cfg := &dex.YAMLConfig{
Issuer: c.Issuer,
@@ -116,8 +110,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
Issuer: "NetBird",
Theme: "light",
},
// Always enable password DB initially - we disable the local connector after startup if needed.
// This ensures Dex has at least one connector during initialization.
EnablePasswordDB: true,
StaticClients: []storage.Client{
{
@@ -205,32 +197,11 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe
return nil, err
}
log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config)
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
if err != nil {
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
}
// If local auth is disabled, validate that other connectors exist
if config.LocalAuthDisabled {
hasOthers, err := provider.HasNonLocalConnectors(ctx)
if err != nil {
_ = provider.Stop(ctx)
return nil, fmt.Errorf("failed to check connectors: %w", err)
}
if !hasOthers {
_ = provider.Stop(ctx)
return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured")
}
// Ensure local connector is removed (it might exist from a previous run)
if err := provider.DisableLocalAuth(ctx); err != nil {
_ = provider.Stop(ctx)
return nil, fmt.Errorf("failed to disable local auth: %w", err)
}
log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used")
}
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
return &EmbeddedIdPManager{
@@ -315,8 +286,6 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
return nil, fmt.Errorf("failed to list users: %w", err)
}
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users))
indexedUsers := make(map[string][]*UserData)
for _, user := range users {
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
@@ -326,17 +295,11 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
})
}
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID]))
return indexedUsers, nil
}
// CreateUser creates a new user in the embedded IdP.
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
if m.config.LocalAuthDisabled {
return nil, fmt.Errorf("local user creation is disabled")
}
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountCreateUser()
}
@@ -406,10 +369,6 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) (
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
// This is useful for instance setup where the user provides their own password.
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
if m.config.LocalAuthDisabled {
return nil, fmt.Errorf("local user creation is disabled")
}
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountCreateUser()
}
@@ -599,13 +558,3 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string {
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
return defaultUserIDClaim
}
// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration.
func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
return m.config.LocalAuthDisabled
}
// HasNonLocalConnectors checks if there are any identity provider connectors other than local.
func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) {
return m.provider.HasNonLocalConnectors(ctx)
}

View File

@@ -370,234 +370,3 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
})
}
}
func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) {
ctx := context.Background()
t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
_, err = NewEmbeddedIdPManager(ctx, config, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "no other identity providers configured")
})
t.Run("local auth enabled by default", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Verify local auth is enabled by default
assert.False(t, manager.IsLocalAuthDisabled())
})
t.Run("start with local auth disabled when connector exists", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First, create a manager with local auth enabled and add a connector
config1 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
require.NoError(t, err)
// Create a user
userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com")
require.NoError(t, err)
userID := userData.ID
// Add an external connector (Google doesn't require OIDC discovery)
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
ID: "google-test",
Name: "Google Test",
Type: "google",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
})
require.NoError(t, err)
// Stop the first manager
err = manager1.Stop(ctx)
require.NoError(t, err)
// Now create a new manager with local auth disabled
config2 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
require.NoError(t, err)
defer func() { _ = manager2.Stop(ctx) }()
// Verify local auth is disabled via config
assert.True(t, manager2.IsLocalAuthDisabled())
// Verify the user still exists in storage (just can't login via local)
lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{})
require.NoError(t, err)
assert.Equal(t, "preserved@example.com", lookedUp.Email)
})
t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First, create a manager and add an external connector
config1 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
require.NoError(t, err)
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
ID: "google-test",
Name: "Google Test",
Type: "google",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
})
require.NoError(t, err)
err = manager1.Stop(ctx)
require.NoError(t, err)
// Create manager with local auth disabled
config2 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
require.NoError(t, err)
defer func() { _ = manager2.Stop(ctx) }()
// Try to create a user - should fail
_, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "local user creation is disabled")
})
t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First, create a manager and add an external connector
config1 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
require.NoError(t, err)
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
ID: "google-test",
Name: "Google Test",
Type: "google",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
})
require.NoError(t, err)
err = manager1.Stop(ctx)
require.NoError(t, err)
// Create manager with local auth disabled
config2 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
require.NoError(t, err)
defer func() { _ = manager2.Stop(ctx) }()
// Try to create a user with password - should fail
_, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User")
require.Error(t, err)
assert.Contains(t, err.Error(), "local user creation is disabled")
})
}

View File

@@ -104,22 +104,13 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
}
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
// Check if there are any accounts in the NetBird store
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return err
}
hasAccounts := numAccounts > 0
// Check if there are any users in the embedded IdP (Dex)
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
if err != nil {
return err
}
hasLocalUsers := len(users) > 0
m.setupMu.Lock()
m.setupRequired = !(hasAccounts || hasLocalUsers)
m.setupRequired = len(users) == 0
m.setupMu.Unlock()
return nil

View File

@@ -610,7 +610,6 @@ func TestSync10PeersGetUpdates(t *testing.T) {
initialPeers := 10
additionalPeers := 10
expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself
var peers []wgtypes.Key
for i := 0; i < initialPeers; i++ {
@@ -619,19 +618,8 @@ func TestSync10PeersGetUpdates(t *testing.T) {
peers = append(peers, key)
}
// Track the maximum peer count each peer has seen
type peerState struct {
mu sync.Mutex
maxPeerCount int
done bool
}
peerStates := make(map[string]*peerState)
for _, pk := range peers {
peerStates[pk.PublicKey().String()] = &peerState{}
}
var wg sync.WaitGroup
wg.Add(initialPeers) // One completion per initial peer
wg.Add(initialPeers + initialPeers*additionalPeers)
var syncClients []mgmtProto.ManagementService_SyncClient
for _, pk := range peers {
@@ -655,9 +643,6 @@ func TestSync10PeersGetUpdates(t *testing.T) {
syncClients = append(syncClients, s)
go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) {
pubKey := pk.PublicKey().String()
state := peerStates[pubKey]
for {
encMsg := &mgmtProto.EncryptedMessage{}
err := syncStream.RecvMsg(encMsg)
@@ -666,28 +651,19 @@ func TestSync10PeersGetUpdates(t *testing.T) {
}
decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk)
if decErr != nil {
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr)
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr)
return
}
resp := &mgmtProto.SyncResponse{}
umErr := pb.Unmarshal(decryptedBytes, resp)
if umErr != nil {
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr)
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr)
return
}
// Track the maximum peer count seen (due to debouncing, updates are coalesced)
peerCount := len(resp.GetRemotePeers())
state.mu.Lock()
if peerCount > state.maxPeerCount {
state.maxPeerCount = peerCount
}
// Signal completion when this peer has seen all expected peers
if !state.done && state.maxPeerCount >= expectedPeerCount {
state.done = true
// We only count if there's a new peer update
if len(resp.GetRemotePeers()) > 0 {
wg.Done()
}
state.mu.Unlock()
}
}(pk, s)
}
@@ -701,30 +677,7 @@ func TestSync10PeersGetUpdates(t *testing.T) {
time.Sleep(time.Duration(n) * time.Millisecond)
}
// Wait for debouncer to flush final updates (debounce interval is 1000ms)
time.Sleep(1500 * time.Millisecond)
// Wait with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success - all peers received expected peer count
case <-time.After(5 * time.Second):
// Timeout - report which peers didn't receive all updates
t.Error("Timeout waiting for all peers to receive updates")
for pubKey, state := range peerStates {
state.mu.Lock()
if state.maxPeerCount < expectedPeerCount {
t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount)
}
state.mu.Unlock()
}
}
wg.Wait()
for _, sc := range syncClients {
err := sc.CloseSend()

View File

@@ -12,7 +12,6 @@ import (
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
@@ -38,8 +37,8 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -148,10 +147,6 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
// Mock implementation - no-op
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
if am.CreatePeerJobFunc != nil {
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
@@ -219,15 +214,16 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime)
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
return nil
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
@@ -327,9 +323,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}

View File

@@ -5,9 +5,6 @@ import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -33,23 +30,21 @@ type Manager interface {
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
reverseProxyManager reverseproxy.Manager
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
reverseProxyManager: reverseproxyManager,
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
}
}
@@ -262,14 +257,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
event()
}
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
}
}()
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -322,14 +309,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}
if serviceID != "" {
return status.NewResourceInUseError(resourceID, serviceID)
}
var events []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)

View File

@@ -103,13 +103,11 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
// syncTime is used as the LastSeen timestamp and for stale request detection
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
var err error
var skipped bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
@@ -117,19 +115,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return err
}
if connected && !syncTime.After(peer.Status.LastSeen) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
skipped = true
return nil
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
return err
})
if skipped {
return nil
}
if err != nil {
return err
}
@@ -159,10 +147,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = syncTime
newStatus.LastSeen = time.Now().UTC()
newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {
@@ -489,14 +477,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings
var eventsToStore []func()
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}
if serviceID != "" {
return status.NewPeerInUseError(peerID, serviceID)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
@@ -565,7 +545,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
if setupKey == "" && userID == "" && !peer.ProxyEmbedded {
// no auth method provided => reject access
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
}
@@ -594,9 +574,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
var setupKeyID string
var setupKeyName string
var ephemeral bool
var groupsToAdd []string
var allowExtraDNSLabels bool
ephemeral := peer.Ephemeral
switch {
case addedByUser:
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
@@ -645,7 +625,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
}
default:
if peer.ProxyMeta.Embedded {
if peer.ProxyEmbedded {
log.WithContext(ctx).Debugf("adding peer for proxy embedded, accountID: %s", accountID)
} else {
log.WithContext(ctx).Warnf("adding peer without setup key or userID, accountID: %s", accountID)
@@ -685,7 +665,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser && !temporary,
Ephemeral: ephemeral,
ProxyMeta: peer.ProxyMeta,
ProxyEmbedded: peer.ProxyEmbedded,
Location: peer.Location,
InactivityExpirationEnabled: addedByUser && !temporary,
ExtraDNSLabels: peer.ExtraDNSLabels,
@@ -752,11 +732,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
if !peer.ProxyMeta.Embedded {
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
switch {

View File

@@ -24,8 +24,6 @@ type Peer struct {
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// ProxyMeta is metadata related to proxy peers
ProxyMeta ProxyMeta `gorm:"embedded;embeddedPrefix:proxy_meta_"`
// Name is peer's name (machine name)
Name string `gorm:"index"`
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
@@ -50,7 +48,8 @@ type Peer struct {
CreatedAt time.Time
// Indicate ephemeral peer attribute
Ephemeral bool `gorm:"index"`
// ProxyEmbedded indicates whether the peer is embedded in a reverse proxy
ProxyEmbedded bool `gorm:"index"`
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"`
@@ -60,11 +59,6 @@ type Peer struct {
AllowExtraDNSLabels bool
}
type ProxyMeta struct {
Embedded bool `gorm:"index"`
Cluster string `gorm:"index"`
}
type PeerStatus struct { //nolint:revive
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
@@ -232,7 +226,7 @@ func (p *Peer) Copy() *Peer {
LastLogin: p.LastLogin,
CreatedAt: p.CreatedAt,
Ephemeral: p.Ephemeral,
ProxyMeta: p.ProxyMeta,
ProxyEmbedded: p.ProxyEmbedded,
Location: p.Location,
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
ExtraDNSLabels: slices.Clone(p.ExtraDNSLabels),

View File

@@ -37,5 +37,4 @@ var All = map[Module]struct{}{
SetupKeys: {},
Pats: {},
IdentityProviders: {},
Services: {},
}

View File

@@ -24,28 +24,19 @@ type Manager interface {
UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error)
}
// IdpConfig holds IdP-related configuration that is set at runtime
// and not stored in the database.
type IdpConfig struct {
EmbeddedIdpEnabled bool
LocalAuthDisabled bool
}
type managerImpl struct {
store store.Store
extraSettingsManager extra_settings.Manager
userManager users.Manager
permissionsManager permissions.Manager
idpConfig IdpConfig
}
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager {
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
extraSettingsManager: extraSettingsManager,
userManager: userManager,
permissionsManager: permissionsManager,
idpConfig: idpConfig,
}
}
@@ -83,10 +74,6 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
}
// Fill in IdP-related runtime settings
settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled
settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled
return settings, nil
}

View File

@@ -126,12 +126,11 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.ProxyAccessToken{},
&types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &domain.Domain{},
&accesslogs.AccessLogEntry{},
)
if err != nil {
@@ -1281,12 +1280,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
wg.Add(1)
go func() {
defer wg.Done()
services, err := s.getServices(ctx, accountID)
proxies, err := s.getProxies(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Services = services
account.ReverseProxies = proxies
}()
wg.Add(1)
@@ -1690,7 +1689,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster FROM peers WHERE account_id = $1`
location_geo_name_id, proxy_embedded FROM peers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -1708,7 +1707,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString
metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString
locationCountryCode, locationCityName, proxyCluster sql.NullString
locationCountryCode, locationCityName sql.NullString
locationGeoNameID sql.NullInt64
)
@@ -1718,7 +1717,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files,
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster)
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded)
if err == nil {
if lastLogin.Valid {
@@ -1803,10 +1802,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
p.Location.GeoNameID = uint(locationGeoNameID.Int64)
}
if proxyEmbedded.Valid {
p.ProxyMeta.Embedded = proxyEmbedded.Bool
}
if proxyCluster.Valid {
p.ProxyMeta.Cluster = proxyCluster.String
p.ProxyEmbedded = proxyEmbedded.Bool
}
if ip != nil {
_ = json.Unmarshal(ip, &p.IP)
@@ -2063,129 +2059,63 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil
}
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
target_id, target_type, enabled
FROM targets WHERE service_id = ANY($1)`
serviceRows, err := s.pool.Query(ctx, serviceQuery, accountID)
func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
const query = `SELECT id, account_id, name, domain, targets, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status
FROM reverse_proxies WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
}
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
var s reverseproxy.Service
proxies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) {
var p reverseproxy.ReverseProxy
var auth []byte
var targets []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var status sql.NullString
err := row.Scan(
&s.ID,
&s.AccountID,
&s.Name,
&s.Domain,
&s.Enabled,
&p.ID,
&p.AccountID,
&p.Name,
&p.Domain,
&targets,
&p.Enabled,
&auth,
&createdAt,
&certIssuedAt,
&status,
&proxyCluster,
&s.PassHostHeader,
&s.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
)
if err != nil {
return nil, err
}
// Unmarshal JSON fields
if auth != nil {
if err := json.Unmarshal(auth, &s.Auth); err != nil {
if err := json.Unmarshal(auth, &p.Auth); err != nil {
return nil, err
}
}
s.Meta = reverseproxy.ServiceMeta{}
if targets != nil {
if err := json.Unmarshal(targets, &p.Targets); err != nil {
return nil, err
}
}
p.Meta = reverseproxy.ReverseProxyMeta{}
if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time
p.Meta.CreatedAt = createdAt.Time
}
if certIssuedAt.Valid {
s.Meta.CertificateIssuedAt = certIssuedAt.Time
p.Meta.CertificateIssuedAt = certIssuedAt.Time
}
if status.Valid {
s.Meta.Status = status.String
}
if proxyCluster.Valid {
s.ProxyCluster = proxyCluster.String
}
if sessionPrivateKey.Valid {
s.SessionPrivateKey = sessionPrivateKey.String
}
if sessionPublicKey.Valid {
s.SessionPublicKey = sessionPublicKey.String
p.Meta.Status = status.String
}
s.Targets = []*reverseproxy.Target{}
return &s, nil
return &p, nil
})
if err != nil {
return nil, err
}
if len(services) == 0 {
return services, nil
}
serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service)
for i, s := range services {
serviceIDs[i] = s.ID
serviceMap[s.ID] = s
}
targetRows, err := s.pool.Query(ctx, targetsQuery, serviceIDs)
if err != nil {
return nil, err
}
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
var t reverseproxy.Target
var path sql.NullString
err := row.Scan(
&t.ID,
&t.AccountID,
&t.ServiceID,
&path,
&t.Host,
&t.Port,
&t.Protocol,
&t.TargetId,
&t.TargetType,
&t.Enabled,
)
if err != nil {
return nil, err
}
if path.Valid {
t.Path = &path.String
}
return &t, nil
})
if err != nil {
return nil, err
}
for _, target := range targets {
if service, ok := serviceMap[target.ServiceID]; ok {
service.Targets = append(service.Targets, target)
}
}
return services, nil
return proxies, nil
}
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
@@ -4379,79 +4309,6 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
return nil
}
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var token types.ProxyAccessToken
result := tx.Take(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy access token not found")
}
return nil, status.Errorf(status.Internal, "get proxy access token: %v", result.Error)
}
return &token, nil
}
// GetAllProxyAccessTokens retrieves all proxy access tokens.
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var tokens []*types.ProxyAccessToken
result := tx.Find(&tokens)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "get proxy access tokens: %v", result.Error)
}
return tokens, nil
}
// SaveProxyAccessToken saves a proxy access token to the database.
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
}
return nil
}
// RevokeProxyAccessToken revokes a proxy access token by its ID.
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
if result.Error != nil {
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy access token not found")
}
return nil
}
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
Where(idQueryCondition, tokenID).
Update("last_used", time.Now().UTC())
if result.Error != nil {
return status.Errorf(status.Internal, "mark proxy access token as used: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy access token not found")
}
return nil
}
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
@@ -4825,148 +4682,133 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil
}
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
proxyCopy := proxy.Copy()
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
}
result := s.db.Create(serviceCopy)
result := s.db.Create(proxyCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create service to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create service to store")
log.WithContext(ctx).Errorf("failed to create reverse proxy to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create reverse proxy to store")
}
return nil
}
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
proxyCopy := proxy.Copy()
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
}
// Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
return err
}
// Update the service and create new targets
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(serviceCopy).Error; err != nil {
return err
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update service to store: %v", err)
return status.Errorf(status.Internal, "failed to update service to store")
result := s.db.Select("*").Save(proxyCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
}
return nil
}
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
func (s *SqlStore) DeleteReverseProxy(ctx context.Context, accountID, proxyID string) error {
result := s.db.Delete(&reverseproxy.ReverseProxy{}, accountAndIDQueryCondition, accountID, proxyID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store")
log.WithContext(ctx).Errorf("failed to delete reverse proxy from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete reverse proxy from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "service %s not found", serviceID)
return status.Errorf(status.NotFound, "reverse proxy %s not found", proxyID)
}
return nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var service *reverseproxy.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
var proxy *reverseproxy.ReverseProxy
result := tx.Take(&proxy, accountAndIDQueryCondition, accountID, proxyID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "service %s not found", serviceID)
return nil, status.Errorf(status.NotFound, "reverse proxy %s not found", proxyID)
}
log.WithContext(ctx).Errorf("failed to get service from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get service from store")
log.WithContext(ctx).Errorf("failed to get reverse proxy from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
return service, nil
return proxy, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
var service *reverseproxy.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) {
var proxy *reverseproxy.ReverseProxy
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
return nil, status.Errorf(status.NotFound, "reverse proxy with domain %s not found", domain)
}
log.WithContext(ctx).Errorf("failed to get service by domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get service by domain from store")
log.WithContext(ctx).Errorf("failed to get reverse proxy by domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy by domain from store")
}
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
return service, nil
return proxy, nil
}
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList)
var proxyList []*reverseproxy.ReverseProxy
result := tx.Find(&proxyList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
log.WithContext(ctx).Errorf("failed to get reverse proxy from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
for _, proxy := range proxyList {
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
}
return serviceList, nil
return proxyList, nil
}
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
var proxyList []*reverseproxy.ReverseProxy
result := tx.Find(&proxyList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
log.WithContext(ctx).Errorf("failed to get reverse proxy from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
for _, proxy := range proxyList {
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
}
return serviceList, nil
return proxyList, nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
@@ -4976,11 +4818,11 @@ func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domain
result := tx.Take(&customDomain, accountAndIDQueryCondition, accountID, domainID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "custom domain %s not found", domainID)
return nil, status.Errorf(status.NotFound, "reverse proxy custom domain %s not found", domainID)
}
log.WithContext(ctx).Errorf("failed to get custom domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get custom domain from store")
log.WithContext(ctx).Errorf("failed to get reverse proxy custom domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy custom domain from store")
}
return customDomain, nil
@@ -5003,14 +4845,13 @@ func (s *SqlStore) ListCustomDomains(ctx context.Context, accountID string) ([]*
return domains, nil
}
func (s *SqlStore) CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error) {
func (s *SqlStore) CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*domain.Domain, error) {
newDomain := &domain.Domain{
ID: xid.New().String(), // Generate our own ID because gorm doesn't always configure the database to handle this for us.
Domain: domainName,
AccountID: accountID,
TargetCluster: targetCluster,
Type: domain.TypeCustom,
Validated: validated,
ID: xid.New().String(), // Generate our own ID because gorm doesn't always configure the database to handle this for us.
Domain: domainName,
AccountID: accountID,
Type: domain.TypeCustom,
Validated: validated,
}
result := s.db.Create(newDomain)
if result.Error != nil {
@@ -5051,41 +4892,24 @@ func (s *SqlStore) CreateAccessLog(ctx context.Context, logEntry *accesslogs.Acc
result := s.db.Create(logEntry)
if result.Error != nil {
log.WithContext(ctx).WithFields(log.Fields{
"service_id": logEntry.ServiceID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"proxy_id": logEntry.ProxyID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
}).Errorf("failed to create access log entry in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create access log entry in store")
}
return nil
}
// GetAccountAccessLogs retrieves access logs for a given account with pagination and filtering
func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
// GetAccountAccessLogs retrieves all access logs for a given account
func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*accesslogs.AccessLogEntry, error) {
var logs []*accesslogs.AccessLogEntry
var totalCount int64
baseQuery := s.db.WithContext(ctx).
Model(&accesslogs.AccessLogEntry{}).
Where(accountIDCondition, accountID)
baseQuery = s.applyAccessLogFilters(baseQuery, filter)
if err := baseQuery.Count(&totalCount).Error; err != nil {
log.WithContext(ctx).Errorf("failed to count access logs: %v", err)
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
}
query := s.db.WithContext(ctx).
Where(accountIDCondition, accountID)
query = s.applyAccessLogFilters(query, filter)
query = query.
Where(accountIDCondition, accountID).
Order("timestamp DESC").
Limit(filter.GetLimit()).
Offset(filter.GetOffset())
Limit(1000)
if lockStrength != LockingStrengthNone {
query = query.Clauses(clause.Locking{Strength: string(lockStrength)})
@@ -5094,82 +4918,8 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
result := query.Find(&logs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get access logs from store: %v", result.Error)
return nil, 0, status.Errorf(status.Internal, "failed to get access logs from store")
return nil, status.Errorf(status.Internal, "failed to get access logs from store")
}
return logs, totalCount, nil
}
// applyAccessLogFilters applies filter conditions to the query
func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.AccessLogFilter) *gorm.DB {
if filter.Search != nil {
searchPattern := "%" + *filter.Search + "%"
query = query.Where(
"id LIKE ? OR location_connection_ip LIKE ? OR host LIKE ? OR path LIKE ? OR CONCAT(host, path) LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)",
searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern,
)
}
if filter.SourceIP != nil {
query = query.Where("location_connection_ip = ?", *filter.SourceIP)
}
if filter.Host != nil {
query = query.Where("host = ?", *filter.Host)
}
if filter.Path != nil {
// Support LIKE pattern for path filtering
query = query.Where("path LIKE ?", "%"+*filter.Path+"%")
}
if filter.UserID != nil {
query = query.Where("user_id = ?", *filter.UserID)
}
if filter.Method != nil {
query = query.Where("method = ?", *filter.Method)
}
if filter.Status != nil {
if *filter.Status == "success" {
query = query.Where("status_code >= ? AND status_code < ?", 200, 400)
} else if *filter.Status == "failed" {
query = query.Where("status_code < ? OR status_code >= ?", 200, 400)
}
}
if filter.StatusCode != nil {
query = query.Where("status_code = ?", *filter.StatusCode)
}
if filter.StartDate != nil {
query = query.Where("timestamp >= ?", *filter.StartDate)
}
if filter.EndDate != nil {
query = query.Where("timestamp <= ?", *filter.EndDate)
}
return query
}
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var target *reverseproxy.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "service target with ID %s not found", targetID)
}
log.WithContext(ctx).Errorf("failed to get service target from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get service target from store")
}
return target, nil
return logs, nil
}

View File

@@ -109,12 +109,6 @@ type Store interface {
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
DeletePAT(ctx context.Context, userID, patID string) error
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
@@ -250,24 +244,23 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
CreateReverseProxy(ctx context.Context, service *reverseproxy.ReverseProxy) error
UpdateReverseProxy(ctx context.Context, service *reverseproxy.ReverseProxy) error
DeleteReverseProxy(ctx context.Context, accountID, serviceID string) error
GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.ReverseProxy, error)
GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error)
GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*domain.Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*accesslogs.AccessLogEntry, error)
}
const (

View File

@@ -1,17 +0,0 @@
-- Schema definitions (must match GORM auto-migrate order)
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
-- Test accounts
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
-- Test groups
INSERT INTO "groups" VALUES('allowedGroupId','testAccountId','Allowed Group','api','[]',0,'');
INSERT INTO "groups" VALUES('restrictedGroupId','testAccountId','Restricted Group','api','[]',0,'');
-- Test users
INSERT INTO users VALUES('allowedUserId','testAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('nonGroupUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherAccountUserId','otherAccountId','user',0,0,'','["allowedGroupId"]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');

View File

@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
ReverseProxies []*reverseproxy.ReverseProxy `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -285,6 +285,8 @@ func (a *Account) GetPeerNetworkMap(
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
exposedServices map[string][]*reverseproxy.ReverseProxy, // routerPeer -> list of exposed services
proxyPeers []*nbpeer.Peer,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
@@ -302,10 +304,21 @@ func (a *Account) GetPeerNetworkMap(
peerGroups := a.GetPeerGroups(peerID)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
var aclPeers []*nbpeer.Peer
var firewallRules []*FirewallRule
var authorizedUsers map[string]map[string]struct{}
var enableSSH bool
if peer.ProxyEmbedded {
aclPeers, firewallRules = a.GetProxyConnectionResources(exposedServices)
} else {
aclPeers, firewallRules, authorizedUsers, enableSSH = a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
proxyAclPeers, proxyFirewallRules := a.GetPeerProxyResources(exposedServices[peerID], proxyPeers)
aclPeers = append(aclPeers, proxyAclPeers...)
firewallRules = append(firewallRules, proxyFirewallRules...)
}
var peersToConnect, expiredPeers []*nbpeer.Peer
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
@@ -374,11 +387,9 @@ func (a *Account) GetPeerNetworkMap(
return nm
}
// GetProxyConnectionResources returns ACL peers for the proxy-embedded peer based on exposed services.
// No firewall rules are generated here; the proxy peer is always a new on-demand client with a stateful
// firewall, so OUT rules are unnecessary. Inbound rules are handled on the target/router peer side.
func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServices map[string][]*reverseproxy.Service) []*nbpeer.Peer {
func (a *Account) GetProxyConnectionResources(exposedServices map[string][]*reverseproxy.ReverseProxy) ([]*nbpeer.Peer, []*FirewallRule) {
var aclPeers []*nbpeer.Peer
var firewallRules []*FirewallRule
for _, peerServices := range exposedServices {
for _, service := range peerServices {
@@ -389,24 +400,32 @@ func (a *Account) GetProxyConnectionResources(ctx context.Context, exposedServic
if !target.Enabled {
continue
}
if target.TargetType == reverseproxy.TargetTypePeer {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
tpeer := a.GetPeer(target.TargetId)
if tpeer == nil {
continue
}
aclPeers = append(aclPeers, tpeer)
firewallRules = append(firewallRules, &FirewallRule{
PolicyID: "proxy-" + service.ID,
PeerIP: tpeer.IP.String(),
Direction: FirewallRuleDirectionOUT,
Action: "allow",
Protocol: string(PolicyRuleProtocolTCP),
PortRange: RulePortRange{Start: uint16(target.Port), End: uint16(target.Port)},
})
case reverseproxy.TargetTypeResource:
// TODO: handle resource type targets
}
}
}
}
return aclPeers
return aclPeers, firewallRules
}
// GetPeerProxyResources returns ACL peers and inbound firewall rules for a peer that is targeted by reverse proxy services.
// Only IN rules are generated; OUT rules are omitted since proxy peers are always new clients with stateful firewalls.
// Rules use PortRange only (not the legacy Port field) as this feature only targets current peer versions.
func (a *Account) GetPeerProxyResources(peerID string, services []*reverseproxy.Service, proxyPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) GetPeerProxyResources(services []*reverseproxy.ReverseProxy, proxyPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*FirewallRule) {
var aclPeers []*nbpeer.Peer
var firewallRules []*FirewallRule
@@ -418,24 +437,18 @@ func (a *Account) GetPeerProxyResources(peerID string, services []*reverseproxy.
if !target.Enabled {
continue
}
aclPeers = proxyPeers
needsPeerRules := (target.TargetType == reverseproxy.TargetTypePeer && target.TargetId == peerID) ||
(target.TargetType == reverseproxy.TargetTypeHost || target.TargetType == reverseproxy.TargetTypeSubnet || target.TargetType == reverseproxy.TargetTypeDomain)
if needsPeerRules {
for _, proxyPeer := range proxyPeers {
firewallRules = append(firewallRules, &FirewallRule{
PolicyID: "proxy-" + service.ID,
PeerIP: proxyPeer.IP.String(),
Direction: FirewallRuleDirectionIN,
Action: "allow",
Protocol: string(PolicyRuleProtocolTCP),
PortRange: RulePortRange{Start: uint16(target.Port), End: uint16(target.Port)},
})
}
for _, peer := range aclPeers {
firewallRules = append(firewallRules, &FirewallRule{
PolicyID: "proxy-" + service.ID,
PeerIP: peer.IP.String(),
Direction: FirewallRuleDirectionIN,
Action: "allow",
Protocol: string(PolicyRuleProtocolTCP),
PortRange: RulePortRange{Start: uint16(target.Port), End: uint16(target.Port)},
})
}
// TODO: handle routes
}
}
@@ -1285,7 +1298,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peer, ok := a.Peers[p]
if !ok || peer == nil || peer.ProxyMeta.Embedded {
if !ok || peer == nil || peer.ProxyEmbedded {
continue
}
@@ -1312,7 +1325,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
peer := a.GetPeer(resource.ID)
if peer == nil {
if peer == nil || peer.ProxyEmbedded {
return []*nbpeer.Peer{}, false
}
@@ -1848,152 +1861,38 @@ func (a *Account) GetActiveGroupUsers() map[string][]string {
return groups
}
func (a *Account) GetProxyPeers() map[string][]*nbpeer.Peer {
proxyPeers := make(map[string][]*nbpeer.Peer)
func (a *Account) GetProxyPeers() []*nbpeer.Peer {
var proxyPeers []*nbpeer.Peer
for _, peer := range a.Peers {
if peer.ProxyMeta.Embedded {
proxyPeers[peer.ProxyMeta.Cluster] = append(proxyPeers[peer.ProxyMeta.Cluster], peer)
if peer.ProxyEmbedded {
proxyPeers = append(proxyPeers, peer)
}
}
return proxyPeers
}
func (a *Account) GetPeerProxyRoutes(ctx context.Context, peer *nbpeer.Peer, proxies map[string][]*reverseproxy.Service, resourcesMap map[string]*resourceTypes.NetworkResource, routers map[string]map[string]*routerTypes.NetworkRouter, proxyPeers []*nbpeer.Peer) ([]*route.Route, []*RouteFirewallRule, []*nbpeer.Peer) {
sourceRanges := make([]string, 0, len(proxyPeers))
for _, proxyPeer := range proxyPeers {
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, proxyPeer.IP))
}
peers := make(map[string]*nbpeer.Peer, len(resourcesMap))
var routes []*route.Route
var firewallRules []*RouteFirewallRule
for _, proxyPerResource := range proxies {
for _, proxy := range proxyPerResource {
for _, target := range proxy.Targets {
if target.TargetType == reverseproxy.TargetTypeHost || target.TargetType == reverseproxy.TargetTypeSubnet || target.TargetType == reverseproxy.TargetTypeDomain {
resource, ok := resourcesMap[target.TargetId]
if !ok {
log.WithContext(ctx).Warnf("proxy target %s not found in resources map", target.TargetId)
continue
}
networkRouters, ok := routers[resource.NetworkID]
if !ok {
log.WithContext(ctx).Warnf("proxy target %s not found in routers map", target.TargetId)
continue
}
for peerID, router := range networkRouters {
routePeer := a.GetPeer(peerID)
route := resource.ToRoute(routePeer, router)
routes = append(routes, route)
rule := RouteFirewallRule{
PolicyID: fmt.Sprintf("proxy-%s-%s", proxy.ID, route.ID),
RouteID: route.ID,
SourceRanges: sourceRanges,
Action: string(PolicyTrafficActionAccept),
Destination: route.Network.String(),
Protocol: string(PolicyRuleProtocolTCP),
Domains: route.Domains,
IsDynamic: route.IsDynamic(),
PortRange: RulePortRange{
Start: uint16(target.Port),
End: uint16(target.Port),
},
}
firewallRules = append(firewallRules, &rule)
peers[peerID] = routePeer
}
}
}
}
}
resultPeers := make([]*nbpeer.Peer, 0, len(peers))
for _, peer := range peers {
resultPeers = append(resultPeers, peer)
}
return routes, firewallRules, resultPeers
}
func (a *Account) GetResourcesMap() map[string]*resourceTypes.NetworkResource {
resourcesMap := make(map[string]*resourceTypes.NetworkResource, len(a.NetworkResources))
func (a *Account) GetExposedServicesMap() map[string][]*reverseproxy.ReverseProxy {
services := make(map[string][]*reverseproxy.ReverseProxy)
resourcesMap := make(map[string]*resourceTypes.NetworkResource)
for _, resource := range a.NetworkResources {
resourcesMap[resource.ID] = resource
}
return resourcesMap
}
func (a *Account) InjectProxyPolicies(ctx context.Context) {
if len(a.Services) == 0 {
return
}
proxyPeersByCluster := a.GetProxyPeers()
if len(proxyPeersByCluster) == 0 {
return
}
for _, service := range a.Services {
if !service.Enabled {
continue
}
for _, target := range service.Targets {
if !target.Enabled {
continue
}
for _, proxyPeer := range proxyPeersByCluster[service.ProxyCluster] {
port := target.Port
if port == 0 {
switch target.Protocol {
case "https":
port = 443
case "http":
port = 80
default:
log.WithContext(ctx).Warnf("unsupported protocol %s for proxy target %s, skipping policy injection", target.Protocol, target.TargetId)
continue
}
routersMap := a.GetResourceRoutersMap()
for _, proxy := range a.ReverseProxies {
for _, target := range proxy.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
services[target.TargetId] = append(services[target.TargetId], proxy)
case reverseproxy.TargetTypeResource:
resource := resourcesMap[target.TargetId]
routers := routersMap[resource.NetworkID]
for peerID := range routers {
services[peerID] = append(services[peerID], proxy)
}
path := ""
if target.Path != nil {
path = *target.Path
}
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
a.Policies = append(a.Policies, &Policy{
ID: policyID,
Name: fmt.Sprintf("Proxy Access to %s", service.Name),
Enabled: true,
Rules: []*PolicyRule{
{
ID: policyID,
PolicyID: policyID,
Name: fmt.Sprintf("Allow access to %s", service.Name),
Enabled: true,
SourceResource: Resource{
ID: proxyPeer.ID,
Type: ResourceTypePeer,
},
DestinationResource: Resource{
ID: target.TargetId,
Type: ResourceType(target.TargetType),
},
Bidirectional: false,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
PortRanges: []RulePortRange{
{
Start: uint16(port),
End: uint16(port),
},
},
},
},
})
}
}
}
return services
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules

View File

@@ -70,7 +70,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -115,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -177,7 +177,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -240,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -317,7 +317,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -402,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -458,7 +458,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -537,7 +537,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, nil, account.GetActiveGroupUsers(), nil, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -597,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, nil, account.GetActiveGroupUsers(), nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})

View File

@@ -1,7 +0,0 @@
package types
// ProxyCallbackEndpoint holds the proxy callback endpoint
const ProxyCallbackEndpoint = "/reverse-proxy/callback"
// ProxyCallbackEndpointFull holds the proxy callback endpoint with api suffix
const ProxyCallbackEndpointFull = "/api" + ProxyCallbackEndpoint

View File

@@ -1,137 +0,0 @@
package types
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"hash/crc32"
"strings"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/management/server/util"
)
const (
// ProxyTokenPrefix is the globally used prefix for proxy access tokens
ProxyTokenPrefix = "nbx_"
// ProxyTokenSecretLength is the number of characters used for the secret
ProxyTokenSecretLength = 30
// ProxyTokenChecksumLength is the number of characters used for the encoded checksum
ProxyTokenChecksumLength = 6
// ProxyTokenLength is the total number of characters used for the token
ProxyTokenLength = 40
)
// HashedProxyToken is a SHA-256 hash of a plain proxy token, base64-encoded.
type HashedProxyToken string
// PlainProxyToken is the raw token string displayed once at creation time.
type PlainProxyToken string
// ProxyAccessToken holds information about a proxy access token including a hashed version for verification
type ProxyAccessToken struct {
ID string `gorm:"primaryKey"`
Name string
HashedToken HashedProxyToken `gorm:"uniqueIndex"`
// AccountID is nil for management-wide tokens, set for account-scoped tokens
AccountID *string `gorm:"index"`
ExpiresAt *time.Time
CreatedBy string
CreatedAt time.Time
LastUsed *time.Time
Revoked bool
}
// IsExpired returns true if the token has expired
func (t *ProxyAccessToken) IsExpired() bool {
if t.ExpiresAt == nil {
return false
}
return time.Now().After(*t.ExpiresAt)
}
// IsValid returns true if the token is not revoked and not expired
func (t *ProxyAccessToken) IsValid() bool {
return !t.Revoked && !t.IsExpired()
}
// ProxyAccessTokenGenerated holds the new token and the plain text version
type ProxyAccessTokenGenerated struct {
PlainToken PlainProxyToken
ProxyAccessToken
}
// CreateNewProxyAccessToken generates a new proxy access token.
// Returns the token with hashed value stored and plain token for one-time display.
func CreateNewProxyAccessToken(name string, expiresIn time.Duration, accountID *string, createdBy string) (*ProxyAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateProxyToken()
if err != nil {
return nil, err
}
currentTime := time.Now().UTC()
var expiresAt *time.Time
if expiresIn > 0 {
expiresAt = util.ToPtr(currentTime.Add(expiresIn))
}
return &ProxyAccessTokenGenerated{
ProxyAccessToken: ProxyAccessToken{
ID: xid.New().String(),
Name: name,
HashedToken: hashedToken,
AccountID: accountID,
ExpiresAt: expiresAt,
CreatedBy: createdBy,
CreatedAt: currentTime,
Revoked: false,
},
PlainToken: plainToken,
}, nil
}
func generateProxyToken() (HashedProxyToken, PlainProxyToken, error) {
secret, err := b.Random(ProxyTokenSecretLength)
if err != nil {
return "", "", err
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
plainToken := PlainProxyToken(ProxyTokenPrefix + secret + paddedChecksum)
return plainToken.Hash(), plainToken, nil
}
// Hash returns the SHA-256 hash of the plain token, base64-encoded.
func (t PlainProxyToken) Hash() HashedProxyToken {
h := sha256.Sum256([]byte(t))
return HashedProxyToken(base64.StdEncoding.EncodeToString(h[:]))
}
// Validate checks the format of a proxy token without checking the database.
func (t PlainProxyToken) Validate() error {
if !strings.HasPrefix(string(t), ProxyTokenPrefix) {
return fmt.Errorf("invalid token prefix")
}
if len(t) != ProxyTokenLength {
return fmt.Errorf("invalid token length")
}
secret := t[len(ProxyTokenPrefix) : len(t)-ProxyTokenChecksumLength]
checksumStr := t[len(t)-ProxyTokenChecksumLength:]
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
expectedChecksumStr := fmt.Sprintf("%06s", base62.Encode(expectedChecksum))
if string(checksumStr) != expectedChecksumStr {
return fmt.Errorf("invalid token checksum")
}
return nil
}

View File

@@ -1,155 +0,0 @@
package types
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPlainProxyToken_Validate(t *testing.T) {
tests := []struct {
name string
token PlainProxyToken
wantErr bool
errMsg string
}{
{
name: "valid token",
token: "", // will be generated
wantErr: false,
},
{
name: "wrong prefix",
token: "xyz_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM",
wantErr: true,
errMsg: "invalid token prefix",
},
{
name: "too short",
token: "nbx_short",
wantErr: true,
errMsg: "invalid token length",
},
{
name: "too long",
token: "nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNMextra",
wantErr: true,
errMsg: "invalid token length",
},
{
name: "correct length but invalid checksum",
token: "nbx_invalidtoken123456789012345678901234", // exactly 40 chars, invalid checksum
wantErr: true,
errMsg: "invalid token checksum",
},
{
name: "empty token",
token: "",
wantErr: true,
errMsg: "invalid token prefix",
},
{
name: "only prefix",
token: "nbx_",
wantErr: true,
errMsg: "invalid token length",
},
}
// Generate a valid token for the first test
generated, err := CreateNewProxyAccessToken("test", 0, nil, "test")
require.NoError(t, err)
tests[0].token = generated.PlainToken
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.token.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestPlainProxyToken_Hash(t *testing.T) {
token1 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
token2 := PlainProxyToken("nbx_8FbPkxioCFmlvCTJbD1RafygfVmS9z15lyNM")
token3 := PlainProxyToken("nbx_differenttoken1234567890123456789X")
hash1 := token1.Hash()
hash2 := token2.Hash()
hash3 := token3.Hash()
assert.Equal(t, hash1, hash2, "same token should produce same hash")
assert.NotEqual(t, hash1, hash3, "different tokens should produce different hashes")
assert.NotEmpty(t, hash1)
}
func TestCreateNewProxyAccessToken(t *testing.T) {
t.Run("creates valid token", func(t *testing.T) {
generated, err := CreateNewProxyAccessToken("test-token", 0, nil, "test-user")
require.NoError(t, err)
assert.NotEmpty(t, generated.ID)
assert.Equal(t, "test-token", generated.Name)
assert.Equal(t, "test-user", generated.CreatedBy)
assert.NotEmpty(t, generated.HashedToken)
assert.NotEmpty(t, generated.PlainToken)
assert.Nil(t, generated.ExpiresAt)
assert.False(t, generated.Revoked)
assert.NoError(t, generated.PlainToken.Validate())
assert.Equal(t, ProxyTokenLength, len(generated.PlainToken))
assert.Equal(t, ProxyTokenPrefix, string(generated.PlainToken[:len(ProxyTokenPrefix)]))
})
t.Run("tokens are unique", func(t *testing.T) {
gen1, err := CreateNewProxyAccessToken("test1", 0, nil, "user")
require.NoError(t, err)
gen2, err := CreateNewProxyAccessToken("test2", 0, nil, "user")
require.NoError(t, err)
assert.NotEqual(t, gen1.PlainToken, gen2.PlainToken)
assert.NotEqual(t, gen1.HashedToken, gen2.HashedToken)
assert.NotEqual(t, gen1.ID, gen2.ID)
})
}
func TestProxyAccessToken_IsExpired(t *testing.T) {
past := time.Now().Add(-1 * time.Hour)
future := time.Now().Add(1 * time.Hour)
t.Run("expired token", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: &past}
assert.True(t, token.IsExpired())
})
t.Run("not expired token", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: &future}
assert.False(t, token.IsExpired())
})
t.Run("no expiration", func(t *testing.T) {
token := &ProxyAccessToken{ExpiresAt: nil}
assert.False(t, token.IsExpired())
})
}
func TestProxyAccessToken_IsValid(t *testing.T) {
token := &ProxyAccessToken{
Revoked: false,
}
assert.True(t, token.IsValid())
token.Revoked = true
assert.False(t, token.IsValid())
}

View File

@@ -55,14 +55,6 @@ type Settings struct {
// AutoUpdateVersion client auto-update version
AutoUpdateVersion string `gorm:"default:'disabled'"`
// EmbeddedIdpEnabled indicates if the embedded identity provider is enabled.
// This is a runtime-only field, not stored in the database.
EmbeddedIdpEnabled bool `gorm:"-"`
// LocalAuthDisabled indicates if local (email/password) authentication is disabled.
// This is a runtime-only field, not stored in the database.
LocalAuthDisabled bool `gorm:"-"`
}
// Copy copies the Settings struct
@@ -84,8 +76,6 @@ func (s *Settings) Copy() *Settings {
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
AutoUpdateVersion: s.AutoUpdateVersion,
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
LocalAuthDisabled: s.LocalAuthDisabled,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -191,10 +191,6 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
// Unlike createNewIdpUser, this method fetches user data directly from the database
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
if IsLocalAuthDisabled(ctx, am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
}
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
if err != nil {
return nil, fmt.Errorf("failed to get inviter user: %w", err)
@@ -1466,10 +1462,6 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if IsLocalAuthDisabled(ctx, am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
}
if err := validateUserInvite(invite); err != nil {
return nil, err
}
@@ -1629,10 +1621,6 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if IsLocalAuthDisabled(ctx, am.idpManager) {
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
}
if password == "" {
return status.Errorf(status.InvalidArgument, "password is required")
}

View File

@@ -10,14 +10,14 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/
RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \
echo "netbird:x:1000:netbird" > /tmp/group && \
mkdir -p /tmp/var/lib/netbird && \
mkdir -p /tmp/certs
mkdir -p /tmp/cert
FROM gcr.io/distroless/base:debug
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
COPY --from=builder --chown=1000:1000 /tmp/certs /certs
COPY --from=builder --chown=1000:1000 /tmp/cert /cert
USER netbird:netbird
ENV HOME=/var/lib/netbird
ENV NB_PROXY_ADDRESS=":8443"

View File

@@ -4,7 +4,6 @@ package auth
import (
"crypto/ed25519"
"crypto/tls"
"fmt"
"time"
@@ -29,21 +28,6 @@ const (
SessionJWTIssuer = "netbird-management"
)
// ResolveProto determines the protocol scheme based on the forwarded proto
// configuration. When set to "http" or "https" the value is used directly.
// Otherwise TLS state is used: if conn is non-nil "https" is returned, else "http".
func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string {
switch forwardedProto {
case "http", "https":
return forwardedProto
default:
if conn != nil {
return "https"
}
return "http"
}
}
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
if publicKey == nil {

View File

@@ -68,14 +68,8 @@ var debugPingCmd = &cobra.Command{
SilenceUsage: true,
}
var debugLogCmd = &cobra.Command{
Use: "log",
Short: "Manage client logging",
Long: "Commands to manage logging settings for a client connected through the proxy.",
}
var debugLogLevelCmd = &cobra.Command{
Use: "level <account-id> <level>",
Use: "loglevel <account-id> <level>",
Short: "Set client log level",
Long: "Set the log level for a client (trace, debug, info, warn, error).",
Args: cobra.ExactArgs(2),
@@ -115,8 +109,7 @@ func init() {
debugCmd.AddCommand(debugStatusCmd)
debugCmd.AddCommand(debugSyncCmd)
debugCmd.AddCommand(debugPingCmd)
debugLogCmd.AddCommand(debugLogLevelCmd)
debugCmd.AddCommand(debugLogCmd)
debugCmd.AddCommand(debugLogLevelCmd)
debugCmd.AddCommand(debugStartCmd)
debugCmd.AddCommand(debugStopCmd)

View File

@@ -2,27 +2,20 @@ package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme"
"github.com/netbirdio/netbird/proxy"
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/util"
)
const DefaultManagementURL = "https://api.netbird.io:443"
// envProxyToken is the environment variable name for the proxy access token.
const envProxyToken = "NB_PROXY_TOKEN"
var (
Version = "dev"
Commit = "unknown"
@@ -46,21 +39,14 @@ var (
oidcClientSecret string
oidcEndpoint string
oidcScopes string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wgPort int
)
var rootCmd = &cobra.Command{
Use: "proxy",
Short: "NetBird reverse proxy server",
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
Version: Version,
SilenceUsage: true,
RunE: runServer,
Use: "proxy",
Short: "NetBird reverse proxy server",
Long: "NetBird reverse proxy server for proxying traffic to NetBird networks.",
Version: Version,
RunE: runServer,
}
func init() {
@@ -79,12 +65,6 @@ func init() {
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
}
// Execute runs the root command.
@@ -105,11 +85,6 @@ func SetVersionInfo(version, commit, buildDate, goVersion string) {
}
func runServer(cmd *cobra.Command, args []string) error {
proxyToken := os.Getenv(envProxyToken)
if proxyToken == "" {
return fmt.Errorf("proxy token is required: set %s environment variable", envProxyToken)
}
level := "error"
if debugLogs {
level = "debug"
@@ -120,26 +95,12 @@ func runServer(cmd *cobra.Command, args []string) error {
log.Infof("configured log level: %s", level)
switch forwardedProto {
case "auto", "http", "https":
default:
return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto)
}
parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies)
if err != nil {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
}
srv := proxy.Server{
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
ProxyURL: proxyURL,
ProxyToken: proxyToken,
CertificateDirectory: certDir,
CertificateFile: certFile,
CertificateKeyFile: certKeyFile,
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir,
@@ -150,16 +111,9 @@ func runServer(cmd *cobra.Command, args []string) error {
OIDCClientSecret: oidcClientSecret,
OIDCEndpoint: oidcEndpoint,
OIDCScopes: strings.Split(oidcScopes, ","),
ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
WireguardPort: wgPort,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil {
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
log.Fatal(err)
}
return nil
@@ -184,15 +138,3 @@ func envStringOrDefault(key string, def string) string {
}
return v
}
func envIntOrDefault(key string, def int) int {
v, exists := os.LookupEnv(key)
if !exists {
return def
}
parsed, err := strconv.Atoi(v)
if err != nil {
return def
}
return parsed
}

View File

@@ -0,0 +1,108 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: netbird-proxy
labels:
app: netbird-proxy
spec:
replicas: 1
selector:
matchLabels:
app: netbird-proxy
template:
metadata:
labels:
app: netbird-proxy
spec:
hostAliases:
- ip: "192.168.100.1"
hostnames:
- "host.docker.internal"
containers:
- name: proxy
image: netbird-proxy
ports:
- containerPort: 8443
name: https
- containerPort: 8080
name: health
- containerPort: 8444
name: debug
env:
- name: USER
value: "netbird"
- name: HOME
value: "/tmp"
- name: NB_PROXY_DEBUG_LOGS
value: "true"
- name: NB_PROXY_MANAGEMENT_ADDRESS
value: "http://host.docker.internal:8080"
- name: NB_PROXY_ADDRESS
value: ":8443"
- name: NB_PROXY_HEALTH_ADDRESS
value: ":8080"
- name: NB_PROXY_DEBUG_ENDPOINT
value: "true"
- name: NB_PROXY_DEBUG_ENDPOINT_ADDRESS
value: ":8444"
- name: NB_PROXY_URL
value: "https://proxy.local"
- name: NB_PROXY_CERTIFICATE_DIRECTORY
value: "/certs"
volumeMounts:
- name: tls-certs
mountPath: /certs
readOnly: true
livenessProbe:
httpGet:
path: /healthz/live
port: health
initialDelaySeconds: 5
periodSeconds: 10
timeoutSeconds: 5
failureThreshold: 3
readinessProbe:
httpGet:
path: /healthz/ready
port: health
initialDelaySeconds: 5
periodSeconds: 10
timeoutSeconds: 5
failureThreshold: 3
startupProbe:
httpGet:
path: /healthz/startup
port: health
periodSeconds: 2
timeoutSeconds: 10
failureThreshold: 60
resources:
requests:
memory: "64Mi"
cpu: "100m"
limits:
memory: "256Mi"
cpu: "500m"
volumes:
- name: tls-certs
secret:
secretName: netbird-proxy-tls
---
apiVersion: v1
kind: Service
metadata:
name: netbird-proxy
spec:
selector:
app: netbird-proxy
ports:
- name: https
port: 8443
targetPort: 8443
- name: health
port: 8080
targetPort: 8080
- name: debug
port: 8444
targetPort: 8444
type: ClusterIP

View File

@@ -0,0 +1,11 @@
kind: Cluster
apiVersion: kind.x-k8s.io/v1alpha4
nodes:
- role: control-plane
extraPortMappings:
- containerPort: 30080
hostPort: 30080
protocol: TCP
- containerPort: 30443
hostPort: 30443
protocol: TCP

View File

@@ -2,7 +2,6 @@ package accesslog
import (
"context"
"net/netip"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -15,24 +14,18 @@ type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
}
// Logger sends access log entries to the management server via gRPC.
type Logger struct {
client gRPCClient
logger *log.Logger
trustedProxies []netip.Prefix
client gRPCClient
logger *log.Logger
}
// NewLogger creates a new access log Logger. The trustedProxies parameter
// configures which upstream proxy IP ranges are trusted for extracting
// the real client IP from X-Forwarded-For headers.
func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Prefix) *Logger {
func NewLogger(client gRPCClient, logger *log.Logger) *Logger {
if logger == nil {
logger = log.StandardLogger()
}
return &Logger{
client: client,
logger: logger,
trustedProxies: trustedProxies,
client: client,
logger: logger,
}
}

View File

@@ -3,43 +3,34 @@ package accesslog
import (
"net"
"net/http"
"strings"
"time"
"github.com/rs/xid"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/web"
)
func (l *Logger) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip logging for internal proxy assets (CSS, JS, etc.)
if strings.HasPrefix(r.URL.Path, web.PathPrefix+"/") {
next.ServeHTTP(w, r)
return
}
// Generate request ID early so it can be used by error pages and log correlation.
requestID := xid.New().String()
l.logger.Debugf("request: request_id=%s method=%s host=%s path=%s", requestID, r.Method, r.Host, r.URL.Path)
l.logger.Debugf("access log middleware invoked for %s %s", r.Method, r.URL.Path)
// Use a response writer wrapper so we can access the status code later.
sw := &statusWriter{
w: w,
status: http.StatusOK,
status: http.StatusOK, // Default status is OK unless otherwise modified.
}
// Resolve the source IP using trusted proxy configuration before passing
// the request on, as the proxy will modify forwarding headers.
sourceIp := extractSourceIP(r, l.trustedProxies)
// Get the source IP before passing the request on as the proxy will modify
// headers that we wish to use to gather that information on the request.
sourceIp := extractSourceIP(r)
// Generate request ID early so it can be used by error pages.
requestID := xid.New().String()
// Create a mutable struct to capture data from downstream handlers.
// We pass a pointer in the context - the pointer itself flows down immutably,
// but the struct it points to can be mutated by inner handlers.
capturedData := &proxy.CapturedData{RequestID: requestID}
capturedData.SetClientIP(sourceIp)
ctx := proxy.WithCapturedData(r.Context(), capturedData)
start := time.Now()
@@ -62,13 +53,10 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
Method: r.Method,
ResponseCode: int32(sw.status),
SourceIp: sourceIp,
AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(),
AuthMechanism: auth.MethodFromContext(r.Context()).String(),
UserId: auth.UserFromContext(r.Context()),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
}
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
l.log(r.Context(), entry)
})
}

Some files were not shown because too many files have changed in this diff Show More