mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
1 Commits
feature/ad
...
deploy/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c049f6f09 |
@@ -1,202 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
|
||||||
var cache cache.DNSConfigCache
|
|
||||||
|
|
||||||
// Create two different configs
|
|
||||||
config1 := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "example.com",
|
|
||||||
Records: []nbdns.SimpleRecord{
|
|
||||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
ID: "group1",
|
|
||||||
Name: "Group 1",
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
config2 := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "example.org",
|
|
||||||
Records: []nbdns.SimpleRecord{
|
|
||||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
ID: "group2",
|
|
||||||
Name: "Group 2",
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// First run with config1
|
|
||||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
|
||||||
|
|
||||||
// Second run with config2
|
|
||||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
|
||||||
|
|
||||||
// Third run with config1 again
|
|
||||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
|
||||||
|
|
||||||
// Verify that result1 and result3 are identical
|
|
||||||
if !reflect.DeepEqual(result1, result3) {
|
|
||||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify that result2 is different from result1 and result3
|
|
||||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
|
||||||
t.Errorf("Results should be different for different inputs")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
|
||||||
t.Errorf("Cache should contain name server group 'group1'")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
|
||||||
t.Errorf("Cache should contain name server group 'group2'")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
|
||||||
sizes := []int{10, 100, 1000}
|
|
||||||
|
|
||||||
for _, size := range sizes {
|
|
||||||
testData := generateTestData(size)
|
|
||||||
|
|
||||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
|
||||||
cache := &cache.DNSConfigCache{}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
cache := &cache.DNSConfigCache{}
|
|
||||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateTestData(size int) nbdns.Config {
|
|
||||||
config := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: make([]nbdns.CustomZone, size),
|
|
||||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < size; i++ {
|
|
||||||
config.CustomZones[i] = nbdns.CustomZone{
|
|
||||||
Domain: fmt.Sprintf("domain%d.com", i),
|
|
||||||
Records: []nbdns.SimpleRecord{
|
|
||||||
{
|
|
||||||
Name: fmt.Sprintf("record%d", i),
|
|
||||||
Type: 1,
|
|
||||||
Class: "IN",
|
|
||||||
TTL: 3600,
|
|
||||||
RData: "192.168.1.1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
|
||||||
ID: fmt.Sprintf("group%d", i),
|
|
||||||
Primary: i == 0,
|
|
||||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
|
||||||
SearchDomainsEnabled: true,
|
|
||||||
NameServers: []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
Port: 53,
|
|
||||||
NSType: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildJWTConfig_Audiences(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
authAudience string
|
|
||||||
cliAuthAudience string
|
|
||||||
expectedAudiences []string
|
|
||||||
expectedAudience string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "only_auth_audience",
|
|
||||||
authAudience: "dashboard-aud",
|
|
||||||
cliAuthAudience: "",
|
|
||||||
expectedAudiences: []string{"dashboard-aud"},
|
|
||||||
expectedAudience: "dashboard-aud",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "both_audiences_different",
|
|
||||||
authAudience: "dashboard-aud",
|
|
||||||
cliAuthAudience: "cli-aud",
|
|
||||||
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
|
|
||||||
expectedAudience: "cli-aud",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "both_audiences_same",
|
|
||||||
authAudience: "same-aud",
|
|
||||||
cliAuthAudience: "same-aud",
|
|
||||||
expectedAudiences: []string{"same-aud"},
|
|
||||||
expectedAudience: "same-aud",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
config := &nbconfig.HttpServerConfig{
|
|
||||||
AuthIssuer: "https://issuer.example.com",
|
|
||||||
AuthAudience: tc.authAudience,
|
|
||||||
CLIAuthAudience: tc.cliAuthAudience,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := buildJWTConfig(config, nil)
|
|
||||||
|
|
||||||
assert.NotNil(t, result)
|
|
||||||
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
|
|
||||||
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
|
|
||||||
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,276 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"hash/fnv"
|
|
||||||
"math"
|
|
||||||
"math/rand"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testAdvancedCfg() *lfConfig {
|
|
||||||
return &lfConfig{
|
|
||||||
reconnThreshold: 50 * time.Millisecond,
|
|
||||||
baseBlockDuration: 100 * time.Millisecond,
|
|
||||||
reconnLimitForBan: 3,
|
|
||||||
metaChangeLimit: 2,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type LoginFilterTestSuite struct {
|
|
||||||
suite.Suite
|
|
||||||
filter *loginFilter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) SetupTest() {
|
|
||||||
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoginFilterTestSuite(t *testing.T) {
|
|
||||||
suite.Run(t, new(LoginFilterTestSuite))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
meta := uint64(1)
|
|
||||||
|
|
||||||
s.True(s.filter.allowLogin(pubKey, meta))
|
|
||||||
|
|
||||||
s.filter.addLogin(pubKey, meta)
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
meta := uint64(1)
|
|
||||||
limit := s.filter.cfg.reconnLimitForBan
|
|
||||||
|
|
||||||
for i := 0; i <= limit; i++ {
|
|
||||||
s.filter.addLogin(pubKey, meta)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.False(s.filter.allowLogin(pubKey, meta))
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.True(s.filter.logged[pubKey].isBanned)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
meta := uint64(1)
|
|
||||||
limit := s.filter.cfg.reconnLimitForBan
|
|
||||||
baseBan := s.filter.cfg.baseBlockDuration
|
|
||||||
|
|
||||||
for i := 0; i <= limit; i++ {
|
|
||||||
s.filter.addLogin(pubKey, meta)
|
|
||||||
}
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.True(s.filter.logged[pubKey].isBanned)
|
|
||||||
s.Equal(1, s.filter.logged[pubKey].banLevel)
|
|
||||||
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
|
||||||
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
|
|
||||||
|
|
||||||
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
|
|
||||||
s.filter.logged[pubKey].isBanned = false
|
|
||||||
|
|
||||||
for i := 0; i <= limit; i++ {
|
|
||||||
s.filter.addLogin(pubKey, meta)
|
|
||||||
}
|
|
||||||
s.True(s.filter.logged[pubKey].isBanned)
|
|
||||||
s.Equal(2, s.filter.logged[pubKey].banLevel)
|
|
||||||
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
|
||||||
// nolint
|
|
||||||
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
|
|
||||||
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
meta := uint64(1)
|
|
||||||
|
|
||||||
s.filter.logged[pubKey] = &peerState{
|
|
||||||
isBanned: true,
|
|
||||||
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
|
|
||||||
}
|
|
||||||
|
|
||||||
s.True(s.filter.allowLogin(pubKey, meta))
|
|
||||||
|
|
||||||
s.filter.addLogin(pubKey, meta)
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.False(s.filter.logged[pubKey].isBanned)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
meta := uint64(1)
|
|
||||||
|
|
||||||
s.filter.logged[pubKey] = &peerState{
|
|
||||||
currentHash: meta,
|
|
||||||
banLevel: 3,
|
|
||||||
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
|
|
||||||
}
|
|
||||||
|
|
||||||
s.filter.addLogin(pubKey, meta)
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.Equal(0, s.filter.logged[pubKey].banLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
limit := s.filter.cfg.metaChangeLimit
|
|
||||||
|
|
||||||
for i := range limit {
|
|
||||||
s.filter.addLogin(pubKey, uint64(i+1))
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
|
|
||||||
|
|
||||||
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
|
|
||||||
|
|
||||||
s.False(isAllowed, "should block new meta hash after limit is reached")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
|
|
||||||
pubKey := "PUB_KEY_A"
|
|
||||||
meta1 := uint64(1)
|
|
||||||
meta2 := uint64(2)
|
|
||||||
meta3 := uint64(3)
|
|
||||||
|
|
||||||
s.filter.addLogin(pubKey, meta1)
|
|
||||||
s.filter.addLogin(pubKey, meta2)
|
|
||||||
s.Require().Contains(s.filter.logged, pubKey)
|
|
||||||
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
|
|
||||||
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
|
|
||||||
|
|
||||||
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
|
|
||||||
|
|
||||||
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
|
|
||||||
|
|
||||||
s.filter.addLogin(pubKey, meta3)
|
|
||||||
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkHashingMethods(b *testing.B) {
|
|
||||||
meta := nbpeer.PeerSystemMeta{
|
|
||||||
WtVersion: "1.25.1",
|
|
||||||
OSVersion: "Ubuntu 22.04.3 LTS",
|
|
||||||
KernelVersion: "5.15.0-76-generic",
|
|
||||||
Hostname: "prod-server-database-01",
|
|
||||||
SystemSerialNumber: "PC-1234567890",
|
|
||||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
|
||||||
}
|
|
||||||
pubip := "8.8.8.8"
|
|
||||||
|
|
||||||
var resultString string
|
|
||||||
var resultUint uint64
|
|
||||||
|
|
||||||
b.Run("BuilderString", func(b *testing.B) {
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
resultString = builderString(meta, pubip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("FnvHashToString", func(b *testing.B) {
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
resultString = fnvHashToString(meta, pubip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
resultUint = metaHash(meta, pubip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
_ = resultString
|
|
||||||
_ = resultUint
|
|
||||||
}
|
|
||||||
|
|
||||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
|
||||||
h := fnv.New64a()
|
|
||||||
|
|
||||||
if len(meta.NetworkAddresses) != 0 {
|
|
||||||
for _, na := range meta.NetworkAddresses {
|
|
||||||
h.Write([]byte(na.Mac))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Write([]byte(meta.WtVersion))
|
|
||||||
h.Write([]byte(meta.OSVersion))
|
|
||||||
h.Write([]byte(meta.KernelVersion))
|
|
||||||
h.Write([]byte(meta.Hostname))
|
|
||||||
h.Write([]byte(meta.SystemSerialNumber))
|
|
||||||
h.Write([]byte(pubip))
|
|
||||||
|
|
||||||
return strconv.FormatUint(h.Sum64(), 16)
|
|
||||||
}
|
|
||||||
|
|
||||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
|
||||||
mac := getMacAddress(meta.NetworkAddresses)
|
|
||||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
|
||||||
len(pubip) + len(mac) + 6
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.Grow(estimatedSize)
|
|
||||||
|
|
||||||
b.WriteString(meta.WtVersion)
|
|
||||||
b.WriteByte('|')
|
|
||||||
b.WriteString(meta.OSVersion)
|
|
||||||
b.WriteByte('|')
|
|
||||||
b.WriteString(meta.KernelVersion)
|
|
||||||
b.WriteByte('|')
|
|
||||||
b.WriteString(meta.Hostname)
|
|
||||||
b.WriteByte('|')
|
|
||||||
b.WriteString(meta.SystemSerialNumber)
|
|
||||||
b.WriteByte('|')
|
|
||||||
b.WriteString(pubip)
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
|
||||||
if len(nas) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
macs := make([]string, 0, len(nas))
|
|
||||||
for _, na := range nas {
|
|
||||||
macs = append(macs, na.Mac)
|
|
||||||
}
|
|
||||||
return strings.Join(macs, "/")
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
|
||||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
|
||||||
numKeys := 100000
|
|
||||||
pubKeys := make([]string, numKeys)
|
|
||||||
for i := range numKeys {
|
|
||||||
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
||||||
|
|
||||||
for pb.Next() {
|
|
||||||
key := pubKeys[r.Intn(numKeys)]
|
|
||||||
meta := r.Uint64()
|
|
||||||
|
|
||||||
if filter.allowLogin(key, meta) {
|
|
||||||
filter.addLogin(key, meta)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
ip := "192.168.1.1"
|
|
||||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
|
||||||
l.recordFailure(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
|
||||||
l.recordFailure("192.168.1.1")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, l.isLimited("192.168.1.1"))
|
|
||||||
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
ip := "10.0.0.1"
|
|
||||||
|
|
||||||
// Exhaust burst
|
|
||||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
|
||||||
l.recordFailure(ip)
|
|
||||||
}
|
|
||||||
require.True(t, l.isLimited(ip))
|
|
||||||
|
|
||||||
// Wait for token replenishment
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
l.recordFailure("10.0.0.1")
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
require.Len(t, l.limiters, 1)
|
|
||||||
// Backdate the entry so it looks stale
|
|
||||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
|
||||||
l.mu.Unlock()
|
|
||||||
|
|
||||||
l.cleanup()
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
|
||||||
l.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
|
||||||
l := newAuthFailureLimiter()
|
|
||||||
defer l.stop()
|
|
||||||
|
|
||||||
l.recordFailure("10.0.0.1")
|
|
||||||
l.recordFailure("10.0.0.2")
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
// Only backdate one entry
|
|
||||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
|
||||||
l.mu.Unlock()
|
|
||||||
|
|
||||||
l.cleanup()
|
|
||||||
|
|
||||||
l.mu.Lock()
|
|
||||||
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
|
||||||
assert.Contains(t, l.limiters, "10.0.0.2")
|
|
||||||
l.mu.Unlock()
|
|
||||||
}
|
|
||||||
@@ -1,381 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockReverseProxyManager struct {
|
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
return m.proxiesByAccount[accountID], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
|
||||||
return []*reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
|
||||||
return &reverseproxy.Service{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockUsersManager struct {
|
|
||||||
users map[string]*types.User
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
|
||||||
if m.err != nil {
|
|
||||||
return nil, m.err
|
|
||||||
}
|
|
||||||
user, ok := m.users[userID]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("user not found")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateUserGroupAccess(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
domain string
|
|
||||||
userID string
|
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
|
||||||
users map[string]*types.User
|
|
||||||
proxyErr error
|
|
||||||
userErr error
|
|
||||||
expectErr bool
|
|
||||||
expectErrMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "user not found",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "unknown-user",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "user not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy not found in user's account",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "service not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy exists in different account - not accessible",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "service not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no bearer auth configured - same account allows access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bearer auth disabled - same account allows access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bearer auth enabled but no groups configured - same account allows access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user not in allowed groups",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "not in allowed groups",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user in one of the allowed groups - allow access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user in all allowed groups - allow access",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{
|
|
||||||
Domain: "app.example.com",
|
|
||||||
AccountID: "account1",
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"group1", "group2"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy manager error",
|
|
||||||
domain: "app.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: nil,
|
|
||||||
proxyErr: errors.New("database error"),
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
expectErrMsg: "get account services",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple proxies in account - finds correct one",
|
|
||||||
domain: "app2.example.com",
|
|
||||||
userID: "user1",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {
|
|
||||||
{Domain: "app1.example.com", AccountID: "account1"},
|
|
||||||
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
|
||||||
{Domain: "app3.example.com", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
users: map[string]*types.User{
|
|
||||||
"user1": {Id: "user1", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
server := &ProxyServiceServer{
|
|
||||||
serviceManager: &mockReverseProxyManager{
|
|
||||||
proxiesByAccount: tt.proxiesByAccount,
|
|
||||||
err: tt.proxyErr,
|
|
||||||
},
|
|
||||||
usersManager: &mockUsersManager{
|
|
||||||
users: tt.users,
|
|
||||||
err: tt.userErr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), tt.expectErrMsg)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accountID string
|
|
||||||
domain string
|
|
||||||
proxiesByAccount map[string][]*reverseproxy.Service
|
|
||||||
err error
|
|
||||||
expectProxy bool
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "proxy found",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "app.example.com",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {
|
|
||||||
{Domain: "other.example.com", AccountID: "account1"},
|
|
||||||
{Domain: "app.example.com", AccountID: "account1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectProxy: true,
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy not found in account",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "unknown.example.com",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
|
||||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
|
||||||
},
|
|
||||||
expectProxy: false,
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty proxy list for account",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "app.example.com",
|
|
||||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
|
||||||
expectProxy: false,
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "manager error",
|
|
||||||
accountID: "account1",
|
|
||||||
domain: "app.example.com",
|
|
||||||
proxiesByAccount: nil,
|
|
||||||
err: errors.New("database error"),
|
|
||||||
expectProxy: false,
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
server := &ProxyServiceServer{
|
|
||||||
serviceManager: &mockReverseProxyManager{
|
|
||||||
proxiesByAccount: tt.proxiesByAccount,
|
|
||||||
err: tt.err,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Nil(t, proxy)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, proxy)
|
|
||||||
assert.Equal(t, tt.domain, proxy.Domain)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,230 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
|
||||||
// and returns the channel where messages will be received.
|
|
||||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
|
||||||
ch := make(chan *proto.ProxyMapping, 10)
|
|
||||||
conn := &proxyConnection{
|
|
||||||
proxyID: proxyID,
|
|
||||||
address: clusterAddr,
|
|
||||||
sendChan: ch,
|
|
||||||
}
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
|
||||||
|
|
||||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
|
||||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
|
||||||
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
|
||||||
select {
|
|
||||||
case msg := <-ch:
|
|
||||||
return msg
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|
||||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
tokenStore: tokenStore,
|
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
|
||||||
const numProxies = 3
|
|
||||||
|
|
||||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
|
||||||
for i := range numProxies {
|
|
||||||
id := "proxy-" + string(rune('a'+i))
|
|
||||||
channels[i] = registerFakeProxy(s, id, cluster)
|
|
||||||
}
|
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
|
||||||
Id: "service-1",
|
|
||||||
AccountId: "account-1",
|
|
||||||
Domain: "test.example.com",
|
|
||||||
Path: []*proto.PathMapping{
|
|
||||||
{Path: "/", Target: "http://10.0.0.1:8080/"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
|
|
||||||
|
|
||||||
tokens := make([]string, numProxies)
|
|
||||||
for i, ch := range channels {
|
|
||||||
msg := drainChannel(ch)
|
|
||||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
|
||||||
assert.Equal(t, update.Domain, msg.Domain)
|
|
||||||
assert.Equal(t, update.Id, msg.Id)
|
|
||||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
|
||||||
tokens[i] = msg.AuthToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// All tokens must be unique
|
|
||||||
tokenSet := make(map[string]struct{})
|
|
||||||
for i, tok := range tokens {
|
|
||||||
_, exists := tokenSet[tok]
|
|
||||||
assert.False(t, exists, "proxy %d got duplicate token", i)
|
|
||||||
tokenSet[tok] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each token must be independently consumable
|
|
||||||
for i, tok := range tokens {
|
|
||||||
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
|
|
||||||
assert.NoError(t, err, "proxy %d token should validate successfully", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|
||||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
tokenStore: tokenStore,
|
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
|
||||||
Id: "service-1",
|
|
||||||
AccountId: "account-1",
|
|
||||||
Domain: "test.example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
|
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
|
||||||
msg2 := drainChannel(ch2)
|
|
||||||
require.NotNil(t, msg1)
|
|
||||||
require.NotNil(t, msg2)
|
|
||||||
|
|
||||||
// Delete operations should not generate tokens
|
|
||||||
assert.Empty(t, msg1.AuthToken)
|
|
||||||
assert.Empty(t, msg2.AuthToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
|
||||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
tokenStore: tokenStore,
|
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
|
||||||
Id: "service-1",
|
|
||||||
AccountId: "account-1",
|
|
||||||
Domain: "test.example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
s.SendServiceUpdate(update)
|
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
|
||||||
msg2 := drainChannel(ch2)
|
|
||||||
require.NotNil(t, msg1)
|
|
||||||
require.NotNil(t, msg2)
|
|
||||||
|
|
||||||
assert.NotEmpty(t, msg1.AuthToken)
|
|
||||||
assert.NotEmpty(t, msg2.AuthToken)
|
|
||||||
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
|
|
||||||
|
|
||||||
// Both tokens should validate
|
|
||||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
|
||||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateState creates a state using the same format as GetOIDCURL.
|
|
||||||
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
|
||||||
nonce := make([]byte, 16)
|
|
||||||
_, _ = rand.Read(nonce)
|
|
||||||
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
|
||||||
|
|
||||||
payload := redirectURL + "|" + nonceB64
|
|
||||||
hmacSum := s.generateHMAC(payload)
|
|
||||||
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
oidcConfig: ProxyOIDCConfig{
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectURL := "https://app.example.com/callback"
|
|
||||||
|
|
||||||
// Generate 100 states for the same redirect URL
|
|
||||||
states := make(map[string]bool)
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
state := generateState(s, redirectURL)
|
|
||||||
|
|
||||||
// State must have 3 parts: base64(url)|nonce|hmac
|
|
||||||
parts := strings.Split(state, "|")
|
|
||||||
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
|
||||||
|
|
||||||
// State must be unique
|
|
||||||
require.False(t, states[state], "state %d is a duplicate", i)
|
|
||||||
states[state] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
oidcConfig: ProxyOIDCConfig{
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Old format had only 2 parts: base64(url)|hmac
|
|
||||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
|
||||||
|
|
||||||
_, _, err := s.ValidateState("base64url|hmac")
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "invalid state format")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
|
||||||
s := &ProxyServiceServer{
|
|
||||||
oidcConfig: ProxyOIDCConfig{
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store with tampered HMAC
|
|
||||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
|
||||||
|
|
||||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "invalid state signature")
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
|
||||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
inputFlow *config.DeviceAuthorizationFlow
|
|
||||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
|
||||||
expectedErrFunc require.ErrorAssertionFunc
|
|
||||||
expectedErrMSG string
|
|
||||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
|
||||||
expectedComparisonMSG string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Testing No Device Flow Config",
|
|
||||||
inputFlow: nil,
|
|
||||||
expectedErrFunc: require.Error,
|
|
||||||
expectedErrMSG: "should return error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Testing Invalid Device Flow Provider Config",
|
|
||||||
inputFlow: &config.DeviceAuthorizationFlow{
|
|
||||||
Provider: "NoNe",
|
|
||||||
ProviderConfig: config.ProviderConfig{
|
|
||||||
ClientID: "test",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedErrFunc: require.Error,
|
|
||||||
expectedErrMSG: "should return error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Testing Full Device Flow Config",
|
|
||||||
inputFlow: &config.DeviceAuthorizationFlow{
|
|
||||||
Provider: "hosted",
|
|
||||||
ProviderConfig: config.ProviderConfig{
|
|
||||||
ClientID: "test",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
|
||||||
Provider: 0,
|
|
||||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
|
||||||
ClientID: "test",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedErrFunc: require.NoError,
|
|
||||||
expectedErrMSG: "should not return error",
|
|
||||||
expectedComparisonFunc: require.Equal,
|
|
||||||
expectedComparisonMSG: "should match",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
mgmtServer := &Server{
|
|
||||||
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
|
|
||||||
config: &config.Config{
|
|
||||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
|
||||||
key, err := mgmtServer.secretsManager.GetWGKey()
|
|
||||||
require.NoError(t, err, "should be able to get server key")
|
|
||||||
|
|
||||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
|
|
||||||
require.NoError(t, err, "should be able to encrypt message")
|
|
||||||
|
|
||||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
|
||||||
context.TODO(),
|
|
||||||
&mgmtProto.EncryptedMessage{
|
|
||||||
WgPubKey: testingClientKey.PublicKey().String(),
|
|
||||||
Body: encryptedMSG,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
|
||||||
if testCase.expectedComparisonFunc != nil {
|
|
||||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
|
||||||
|
|
||||||
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
|
||||||
require.NoError(t, err, "should be able to decrypt")
|
|
||||||
|
|
||||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
|
||||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,250 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"hash"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
var TurnTestHost = &config.Host{
|
|
||||||
Proto: config.UDP,
|
|
||||||
URI: "turn:turn.netbird.io:77777",
|
|
||||||
Username: "username",
|
|
||||||
Password: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
|
||||||
ttl := util.Duration{Duration: time.Hour}
|
|
||||||
secret := "some_secret"
|
|
||||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
|
||||||
|
|
||||||
rc := &config.Relay{
|
|
||||||
Addresses: []string{"localhost:0"},
|
|
||||||
CredentialsTTL: ttl,
|
|
||||||
Secret: secret,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
|
||||||
CredentialsTTL: ttl,
|
|
||||||
Secret: secret,
|
|
||||||
Turns: []*config.Host{TurnTestHost},
|
|
||||||
TimeBasedCredentials: true,
|
|
||||||
}, rc, settingsMockManager, groupsManager)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
turnCredentials, err := tested.GenerateTurnToken()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
if turnCredentials.Payload == "" {
|
|
||||||
t.Errorf("expected generated TURN username not to be empty, got empty")
|
|
||||||
}
|
|
||||||
if turnCredentials.Signature == "" {
|
|
||||||
t.Errorf("expected generated TURN password not to be empty, got empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
|
|
||||||
|
|
||||||
relayCredentials, err := tested.GenerateRelayToken()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
if relayCredentials.Payload == "" {
|
|
||||||
t.Errorf("expected generated relay payload not to be empty, got empty")
|
|
||||||
}
|
|
||||||
if relayCredentials.Signature == "" {
|
|
||||||
t.Errorf("expected generated relay signature not to be empty, got empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
hashedSecret := sha256.Sum256([]byte(secret))
|
|
||||||
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
|
||||||
ttl := util.Duration{Duration: 2 * time.Second}
|
|
||||||
secret := "some_secret"
|
|
||||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
|
||||||
peer := "some_peer"
|
|
||||||
updateChannel := peersManager.CreateChannel(context.Background(), peer)
|
|
||||||
|
|
||||||
rc := &config.Relay{
|
|
||||||
Addresses: []string{"localhost:0"},
|
|
||||||
CredentialsTTL: ttl,
|
|
||||||
Secret: secret,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
|
||||||
CredentialsTTL: ttl,
|
|
||||||
Secret: secret,
|
|
||||||
Turns: []*config.Host{TurnTestHost},
|
|
||||||
TimeBasedCredentials: true,
|
|
||||||
}, rc, settingsMockManager, groupsManager)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tested.SetupRefresh(ctx, "someAccountID", peer)
|
|
||||||
|
|
||||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
|
||||||
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := tested.relayCancelMap[peer]; !ok {
|
|
||||||
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
|
|
||||||
}
|
|
||||||
|
|
||||||
var updates []*network_map.UpdateMessage
|
|
||||||
|
|
||||||
loop:
|
|
||||||
for timeout := time.After(5 * time.Second); ; {
|
|
||||||
select {
|
|
||||||
case update := <-updateChannel:
|
|
||||||
updates = append(updates, update)
|
|
||||||
case <-timeout:
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(updates) >= 2 {
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(updates) < 2 {
|
|
||||||
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
|
|
||||||
}
|
|
||||||
|
|
||||||
var turnUpdates, relayUpdates int
|
|
||||||
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
|
|
||||||
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
|
|
||||||
|
|
||||||
for _, update := range updates {
|
|
||||||
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
|
|
||||||
turnUpdates++
|
|
||||||
if turnUpdates == 1 {
|
|
||||||
firstTurnUpdate = turns[0]
|
|
||||||
} else {
|
|
||||||
secondTurnUpdate = turns[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
|
|
||||||
// avoid updating on turn updates since they also send relay credentials
|
|
||||||
if update.Update.GetNetbirdConfig().GetTurns() == nil {
|
|
||||||
relayUpdates++
|
|
||||||
if relayUpdates == 1 {
|
|
||||||
firstRelayUpdate = relay
|
|
||||||
} else {
|
|
||||||
secondRelayUpdate = relay
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if turnUpdates < 1 {
|
|
||||||
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
|
|
||||||
}
|
|
||||||
if relayUpdates < 1 {
|
|
||||||
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
|
|
||||||
}
|
|
||||||
|
|
||||||
if firstTurnUpdate != nil && secondTurnUpdate != nil {
|
|
||||||
if firstTurnUpdate.Password == secondTurnUpdate.Password {
|
|
||||||
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if firstRelayUpdate != nil && secondRelayUpdate != nil {
|
|
||||||
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
|
|
||||||
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
|
||||||
ttl := util.Duration{Duration: time.Hour}
|
|
||||||
secret := "some_secret"
|
|
||||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
|
||||||
peer := "some_peer"
|
|
||||||
|
|
||||||
rc := &config.Relay{
|
|
||||||
Addresses: []string{"localhost:0"},
|
|
||||||
CredentialsTTL: ttl,
|
|
||||||
Secret: secret,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
|
||||||
CredentialsTTL: ttl,
|
|
||||||
Secret: secret,
|
|
||||||
Turns: []*config.Host{TurnTestHost},
|
|
||||||
TimeBasedCredentials: true,
|
|
||||||
}, rc, settingsMockManager, groupsManager)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
|
||||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
|
||||||
t.Errorf("expecting peer to be present in turn cancel map, got not present")
|
|
||||||
}
|
|
||||||
if _, ok := tested.relayCancelMap[peer]; !ok {
|
|
||||||
t.Errorf("expecting peer to be present in relay cancel map, got not present")
|
|
||||||
}
|
|
||||||
|
|
||||||
tested.CancelRefresh(peer)
|
|
||||||
if _, ok := tested.turnCancelMap[peer]; ok {
|
|
||||||
t.Errorf("expecting peer to be not present in turn cancel map, got present")
|
|
||||||
}
|
|
||||||
if _, ok := tested.relayCancelMap[peer]; ok {
|
|
||||||
t.Errorf("expecting peer to be not present in relay cancel map, got present")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
|
|
||||||
t.Helper()
|
|
||||||
mac := hmac.New(algo, key)
|
|
||||||
|
|
||||||
_, err := mac.Write([]byte(username))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedMAC := mac.Sum(nil)
|
|
||||||
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
equal := hmac.Equal(decodedMAC, expectedMAC)
|
|
||||||
|
|
||||||
if !equal {
|
|
||||||
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,587 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldSend := debouncer.ProcessUpdate(update)
|
|
||||||
|
|
||||||
if !shouldSend {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
if debouncer.TimerChannel() == nil {
|
|
||||||
t.Error("Timer should be started after first update")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update3 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update should be sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update1) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rapid subsequent updates should be coalesced
|
|
||||||
if debouncer.ProcessUpdate(update2) {
|
|
||||||
t.Error("Second rapid update should not be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
if debouncer.ProcessUpdate(update3) {
|
|
||||||
t.Error("Third rapid update should not be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != update3 {
|
|
||||||
t.Error("Should get the last update (update3)")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send first update
|
|
||||||
debouncer.ProcessUpdate(update1)
|
|
||||||
|
|
||||||
// Send second update within debounce period
|
|
||||||
debouncer.ProcessUpdate(update2)
|
|
||||||
|
|
||||||
// Wait for timer
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != update2 {
|
|
||||||
t.Error("Should get the last update")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] == update1 {
|
|
||||||
t.Error("Should not get the first update")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update3 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send first update
|
|
||||||
debouncer.ProcessUpdate(update1)
|
|
||||||
|
|
||||||
// Wait a bit, but not the full debounce period
|
|
||||||
time.Sleep(30 * time.Millisecond)
|
|
||||||
|
|
||||||
// Send second update - should reset timer
|
|
||||||
debouncer.ProcessUpdate(update2)
|
|
||||||
|
|
||||||
// Wait a bit more
|
|
||||||
time.Sleep(30 * time.Millisecond)
|
|
||||||
|
|
||||||
// Send third update - should reset timer again
|
|
||||||
debouncer.ProcessUpdate(update3)
|
|
||||||
|
|
||||||
// Now wait for the timer (should fire after last update's reset)
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != update3 {
|
|
||||||
t.Error("Should get the last update (update3)")
|
|
||||||
}
|
|
||||||
// Timer should be restarted since there was a pending update
|
|
||||||
if debouncer.TimerChannel() == nil {
|
|
||||||
t.Error("Timer should be restarted after sending pending update")
|
|
||||||
}
|
|
||||||
case <-time.After(150 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update3 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(update1)
|
|
||||||
|
|
||||||
// Second update coalesced
|
|
||||||
debouncer.ProcessUpdate(update2)
|
|
||||||
|
|
||||||
// Wait for timer to expire
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
|
|
||||||
if len(pendingUpdates) == 0 {
|
|
||||||
t.Fatal("Should have pending update")
|
|
||||||
}
|
|
||||||
|
|
||||||
// After sending pending update, timer is restarted, so next update is NOT immediate
|
|
||||||
if debouncer.ProcessUpdate(update3) {
|
|
||||||
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the restarted timer and verify update3 is pending
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
finalUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
|
||||||
t.Error("Should get update3 as pending")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired for restarted timer")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send update to start timer
|
|
||||||
debouncer.ProcessUpdate(update)
|
|
||||||
|
|
||||||
// Stop should clean up
|
|
||||||
debouncer.Stop()
|
|
||||||
|
|
||||||
// Multiple stops should be safe
|
|
||||||
debouncer.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
// Simulate high-frequency updates
|
|
||||||
var lastUpdate *network_map.UpdateMessage
|
|
||||||
sentImmediately := 0
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: uint64(i),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
lastUpdate = update
|
|
||||||
if debouncer.ProcessUpdate(update) {
|
|
||||||
sentImmediately++
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only first update should be sent immediately
|
|
||||||
if sentImmediately != 1 {
|
|
||||||
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0] != lastUpdate {
|
|
||||||
t.Error("Should get the very last update")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
|
||||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send first update
|
|
||||||
if !debouncer.ProcessUpdate(update) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for timer to expire with no additional updates (true quiet period)
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) != 0 {
|
|
||||||
t.Error("Should have no pending updates")
|
|
||||||
}
|
|
||||||
// After true quiet period, timer should be cleared
|
|
||||||
if debouncer.TimerChannel() != nil {
|
|
||||||
t.Error("Timer should be cleared after quiet period")
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
updates := make([]*network_map.UpdateMessage, 5)
|
|
||||||
for i := range updates {
|
|
||||||
updates[i] = &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: uint64(i),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(updates[0])
|
|
||||||
|
|
||||||
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
|
||||||
debouncer.ProcessUpdate(updates[1])
|
|
||||||
debouncer.ProcessUpdate(updates[2])
|
|
||||||
debouncer.ProcessUpdate(updates[3])
|
|
||||||
debouncer.ProcessUpdate(updates[4])
|
|
||||||
|
|
||||||
// Wait for debounce
|
|
||||||
<-debouncer.TimerChannel()
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
|
|
||||||
if len(pendingUpdates) != 1 {
|
|
||||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
|
||||||
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
update1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
update2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update1) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for timer without sending any more updates (true quiet period)
|
|
||||||
<-debouncer.TimerChannel()
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
|
|
||||||
if len(pendingUpdates) != 0 {
|
|
||||||
t.Error("Should have no pending updates during quiet period")
|
|
||||||
}
|
|
||||||
|
|
||||||
// After true quiet period, next update should be sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update2) {
|
|
||||||
t.Error("Update after true quiet period should be sent immediately")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
// Simulate continuous high-frequency updates
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
update := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
|
||||||
Serial: uint64(i),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == 0 {
|
|
||||||
// First one sent immediately
|
|
||||||
if !debouncer.ProcessUpdate(update) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// All others should be coalesced (not sent immediately)
|
|
||||||
if debouncer.ProcessUpdate(update) {
|
|
||||||
t.Errorf("Update %d should not be sent immediately", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait a bit but send next update before debounce expires
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now wait for final debounce
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
if len(pendingUpdates) == 0 {
|
|
||||||
t.Fatal("Should have the last update pending")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
|
||||||
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
netmapUpdate := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
tokenUpdate1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
tokenUpdate2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(netmapUpdate)
|
|
||||||
|
|
||||||
// Send multiple control config updates - they should all be queued
|
|
||||||
debouncer.ProcessUpdate(tokenUpdate1)
|
|
||||||
debouncer.ProcessUpdate(tokenUpdate2)
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
// Should get both control config updates
|
|
||||||
if len(pendingUpdates) != 2 {
|
|
||||||
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
// Control configs should come first
|
|
||||||
if pendingUpdates[0] != tokenUpdate1 {
|
|
||||||
t.Error("First pending update should be tokenUpdate1")
|
|
||||||
}
|
|
||||||
if pendingUpdates[1] != tokenUpdate2 {
|
|
||||||
t.Error("Second pending update should be tokenUpdate2")
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
netmapUpdate1 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
netmapUpdate2 := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
tokenUpdate := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update sent immediately
|
|
||||||
debouncer.ProcessUpdate(netmapUpdate1)
|
|
||||||
|
|
||||||
// Send token update and network map update
|
|
||||||
debouncer.ProcessUpdate(tokenUpdate)
|
|
||||||
debouncer.ProcessUpdate(netmapUpdate2)
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
// Should get 2 updates in order: token, then network map
|
|
||||||
if len(pendingUpdates) != 2 {
|
|
||||||
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
// Token update should come first (preserves order)
|
|
||||||
if pendingUpdates[0] != tokenUpdate {
|
|
||||||
t.Error("First pending update should be tokenUpdate")
|
|
||||||
}
|
|
||||||
// Network map update should come second
|
|
||||||
if pendingUpdates[1] != netmapUpdate2 {
|
|
||||||
t.Error("Second pending update should be netmapUpdate2")
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
|
||||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
|
||||||
defer debouncer.Stop()
|
|
||||||
|
|
||||||
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
|
||||||
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
|
||||||
|
|
||||||
// Send first network map immediately
|
|
||||||
firstNetmap := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
if !debouncer.ProcessUpdate(firstNetmap) {
|
|
||||||
t.Error("First update should be sent immediately")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send 49 more network maps (will be coalesced to last one)
|
|
||||||
var lastNetmapBatch1 *network_map.UpdateMessage
|
|
||||||
for i := 1; i < 50; i++ {
|
|
||||||
lastNetmapBatch1 = &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
debouncer.ProcessUpdate(lastNetmapBatch1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send 1 control config
|
|
||||||
controlConfig := &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
|
||||||
MessageType: network_map.MessageTypeControlConfig,
|
|
||||||
}
|
|
||||||
debouncer.ProcessUpdate(controlConfig)
|
|
||||||
|
|
||||||
// Send 50 more network maps (will be coalesced to last one)
|
|
||||||
var lastNetmapBatch2 *network_map.UpdateMessage
|
|
||||||
for i := 50; i < 100; i++ {
|
|
||||||
lastNetmapBatch2 = &network_map.UpdateMessage{
|
|
||||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
|
||||||
MessageType: network_map.MessageTypeNetworkMap,
|
|
||||||
}
|
|
||||||
debouncer.ProcessUpdate(lastNetmapBatch2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for debounce period
|
|
||||||
select {
|
|
||||||
case <-debouncer.TimerChannel():
|
|
||||||
pendingUpdates := debouncer.GetPendingUpdates()
|
|
||||||
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
|
||||||
if len(pendingUpdates) != 3 {
|
|
||||||
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
|
||||||
}
|
|
||||||
// First should be the last netmap from batch 1
|
|
||||||
if pendingUpdates[0] != lastNetmapBatch1 {
|
|
||||||
t.Error("First pending update should be last netmap from batch 1")
|
|
||||||
}
|
|
||||||
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
|
||||||
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
// Second should be the control config
|
|
||||||
if pendingUpdates[1] != controlConfig {
|
|
||||||
t.Error("Second pending update should be control config")
|
|
||||||
}
|
|
||||||
// Third should be the last netmap from batch 2
|
|
||||||
if pendingUpdates[2] != lastNetmapBatch2 {
|
|
||||||
t.Error("Third pending update should be last netmap from batch 2")
|
|
||||||
}
|
|
||||||
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
|
||||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Error("Timer should have fired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,307 +0,0 @@
|
|||||||
//go:build integration
|
|
||||||
|
|
||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/proxy/auth"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type validateSessionTestSetup struct {
|
|
||||||
proxyService *ProxyServiceServer
|
|
||||||
store store.Store
|
|
||||||
cleanup func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
|
||||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
|
||||||
|
|
||||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager)
|
|
||||||
proxyService.SetProxyManager(proxyManager)
|
|
||||||
|
|
||||||
createTestProxies(t, ctx, testStore)
|
|
||||||
|
|
||||||
return &validateSessionTestSetup{
|
|
||||||
proxyService: proxyService,
|
|
||||||
store: testStore,
|
|
||||||
cleanup: storeCleanup,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
pubKey, privKey := generateSessionKeyPair(t)
|
|
||||||
|
|
||||||
testProxy := &reverseproxy.Service{
|
|
||||||
ID: "testProxyId",
|
|
||||||
AccountID: "testAccountId",
|
|
||||||
Name: "Test Proxy",
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
Enabled: true,
|
|
||||||
SessionPrivateKey: privKey,
|
|
||||||
SessionPublicKey: pubKey,
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
|
||||||
|
|
||||||
restrictedProxy := &reverseproxy.Service{
|
|
||||||
ID: "restrictedProxyId",
|
|
||||||
AccountID: "testAccountId",
|
|
||||||
Name: "Restricted Proxy",
|
|
||||||
Domain: "restricted-proxy.example.com",
|
|
||||||
Enabled: true,
|
|
||||||
SessionPrivateKey: privKey,
|
|
||||||
SessionPublicKey: pubKey,
|
|
||||||
Auth: reverseproxy.AuthConfig{
|
|
||||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DistributionGroups: []string{"allowedGroupId"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateSessionKeyPair(t *testing.T) (string, string) {
|
|
||||||
t.Helper()
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
|
||||||
t.Helper()
|
|
||||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserAllowed(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, resp.Valid, "User should be allowed access")
|
|
||||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
|
||||||
assert.Empty(t, resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "restricted-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
|
||||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
|
||||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "User in different account should be denied")
|
|
||||||
assert.Equal(t, "account_mismatch", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_UserNotFound(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "Non-existent user should be denied")
|
|
||||||
assert.Equal(t, "user_not_found", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_ProxyNotFound(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "unknown-proxy.example.com",
|
|
||||||
SessionToken: token,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
|
||||||
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_InvalidToken(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
SessionToken: "invalid-token",
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid, "Invalid token should be denied")
|
|
||||||
assert.Equal(t, "invalid_token", resp.DeniedReason)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_MissingDomain(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
SessionToken: "some-token",
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid)
|
|
||||||
assert.Contains(t, resp.DeniedReason, "missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateSession_MissingToken(t *testing.T) {
|
|
||||||
setup := setupValidateSessionTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
|
||||||
Domain: "test-proxy.example.com",
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, resp.Valid)
|
|
||||||
assert.Contains(t, resp.DeniedReason, "missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
type testValidateSessionProxyManager struct {
|
|
||||||
store store.Store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testValidateSessionUsersManager struct {
|
|
||||||
store store.Store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
|
||||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
|
||||||
}
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockMappingStream struct {
|
|
||||||
grpc.ClientStream
|
|
||||||
messages []*proto.GetMappingUpdateResponse
|
|
||||||
idx int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
|
||||||
if m.idx >= len(m.messages) {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
msg := m.messages[m.idx]
|
|
||||||
m.idx++
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockMappingStream) Header() (metadata.MD, error) {
|
|
||||||
return nil, nil //nolint:nilnil
|
|
||||||
}
|
|
||||||
func (m *mockMappingStream) Trailer() metadata.MD { return nil }
|
|
||||||
func (m *mockMappingStream) CloseSend() error { return nil }
|
|
||||||
func (m *mockMappingStream) Context() context.Context { return context.Background() }
|
|
||||||
func (m *mockMappingStream) SendMsg(any) error { return nil }
|
|
||||||
func (m *mockMappingStream) RecvMsg(any) error { return nil }
|
|
||||||
|
|
||||||
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
|
|
||||||
checker := health.NewChecker(nil, nil)
|
|
||||||
s := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
healthChecker: checker,
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := &mockMappingStream{
|
|
||||||
messages: []*proto.GetMappingUpdateResponse{
|
|
||||||
{InitialSyncComplete: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
syncDone := false
|
|
||||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
|
|
||||||
checker := health.NewChecker(nil, nil)
|
|
||||||
s := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
healthChecker: checker,
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := &mockMappingStream{
|
|
||||||
messages: []*proto.GetMappingUpdateResponse{
|
|
||||||
{}, // no sync flag
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
syncDone := false
|
|
||||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.False(t, syncDone, "initial sync should not be marked done without flag")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
|
|
||||||
s := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := &mockMappingStream{
|
|
||||||
messages: []*proto.GetMappingUpdateResponse{
|
|
||||||
{InitialSyncComplete: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
syncDone := false
|
|
||||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, syncDone, "sync done flag should be set even without health checker")
|
|
||||||
}
|
|
||||||
@@ -1,561 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
|
||||||
proxytypes "github.com/netbirdio/netbird/proxy/internal/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// integrationTestSetup contains all real components for testing.
|
|
||||||
type integrationTestSetup struct {
|
|
||||||
store store.Store
|
|
||||||
proxyService *nbgrpc.ProxyServiceServer
|
|
||||||
grpcServer *grpc.Server
|
|
||||||
grpcAddr string
|
|
||||||
cleanup func()
|
|
||||||
services []*reverseproxy.Service
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Create real SQLite store
|
|
||||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create test account
|
|
||||||
testAccount := &types.Account{
|
|
||||||
Id: "test-account-1",
|
|
||||||
Domain: "test.com",
|
|
||||||
DomainCategory: "private",
|
|
||||||
IsDomainPrimaryAccount: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
|
|
||||||
|
|
||||||
// Generate session keys for reverse proxies
|
|
||||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
require.NoError(t, err)
|
|
||||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
|
||||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
|
||||||
|
|
||||||
// Create test services in the store
|
|
||||||
services := []*reverseproxy.Service{
|
|
||||||
{
|
|
||||||
ID: "rp-1",
|
|
||||||
AccountID: "test-account-1",
|
|
||||||
Name: "Test App 1",
|
|
||||||
Domain: "app1.test.proxy.io",
|
|
||||||
Targets: []*reverseproxy.Target{{
|
|
||||||
Path: strPtr("/"),
|
|
||||||
Host: "10.0.0.1",
|
|
||||||
Port: 8080,
|
|
||||||
Protocol: "http",
|
|
||||||
TargetId: "peer1",
|
|
||||||
TargetType: "peer",
|
|
||||||
Enabled: true,
|
|
||||||
}},
|
|
||||||
Enabled: true,
|
|
||||||
ProxyCluster: "test.proxy.io",
|
|
||||||
SessionPrivateKey: privKey,
|
|
||||||
SessionPublicKey: pubKey,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "rp-2",
|
|
||||||
AccountID: "test-account-1",
|
|
||||||
Name: "Test App 2",
|
|
||||||
Domain: "app2.test.proxy.io",
|
|
||||||
Targets: []*reverseproxy.Target{{
|
|
||||||
Path: strPtr("/"),
|
|
||||||
Host: "10.0.0.2",
|
|
||||||
Port: 8080,
|
|
||||||
Protocol: "http",
|
|
||||||
TargetId: "peer2",
|
|
||||||
TargetType: "peer",
|
|
||||||
Enabled: true,
|
|
||||||
}},
|
|
||||||
Enabled: true,
|
|
||||||
ProxyCluster: "test.proxy.io",
|
|
||||||
SessionPrivateKey: privKey,
|
|
||||||
SessionPublicKey: pubKey,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, svc := range services {
|
|
||||||
require.NoError(t, testStore.CreateService(ctx, svc))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create real token store
|
|
||||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create real users manager
|
|
||||||
usersManager := users.NewManager(testStore)
|
|
||||||
|
|
||||||
// Create real proxy service server with minimal config
|
|
||||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
|
||||||
Issuer: "https://fake-issuer.example.com",
|
|
||||||
ClientID: "test-client",
|
|
||||||
HMACKey: []byte("test-hmac-key"),
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(
|
|
||||||
&testAccessLogManager{},
|
|
||||||
tokenStore,
|
|
||||||
oidcConfig,
|
|
||||||
nil,
|
|
||||||
usersManager,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Use store-backed service manager
|
|
||||||
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
|
||||||
proxyService.SetProxyManager(svcMgr)
|
|
||||||
|
|
||||||
// Start real gRPC server
|
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
grpcServer := grpc.NewServer()
|
|
||||||
proto.RegisterProxyServiceServer(grpcServer, proxyService)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := grpcServer.Serve(lis); err != nil {
|
|
||||||
t.Logf("gRPC server error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return &integrationTestSetup{
|
|
||||||
store: testStore,
|
|
||||||
proxyService: proxyService,
|
|
||||||
grpcServer: grpcServer,
|
|
||||||
grpcAddr: lis.Addr().String(),
|
|
||||||
services: services,
|
|
||||||
cleanup: func() {
|
|
||||||
grpcServer.GracefulStop()
|
|
||||||
cleanup()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// testAccessLogManager provides access log storage for testing.
|
|
||||||
type testAccessLogManager struct{}
|
|
||||||
|
|
||||||
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
|
||||||
// noop
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testAccessLogManager) StopPeriodicCleanup() {
|
|
||||||
// noop
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
|
||||||
return nil, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// storeBackedServiceManager reads directly from the real store.
|
|
||||||
type storeBackedServiceManager struct {
|
|
||||||
store store.Store
|
|
||||||
tokenStore *nbgrpc.OneTimeTokenStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, accountID string, targetID string) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func strPtr(s string) *string {
|
|
||||||
return &s
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
|
||||||
setup := setupIntegrationTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewProxyServiceClient(conn)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: "test-proxy-1",
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: "test.proxy.io",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Receive all mappings from the snapshot - server sends each mapping individually
|
|
||||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
msg, err := stream.Recv()
|
|
||||||
require.NoError(t, err)
|
|
||||||
for _, m := range msg.GetMapping() {
|
|
||||||
mappingsByID[m.GetId()] = m
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should receive 2 mappings total
|
|
||||||
assert.Len(t, mappingsByID, 2, "Should receive 2 reverse proxy mappings")
|
|
||||||
|
|
||||||
rp1 := mappingsByID["rp-1"]
|
|
||||||
require.NotNil(t, rp1)
|
|
||||||
assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain())
|
|
||||||
assert.Equal(t, "test-account-1", rp1.GetAccountId())
|
|
||||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType())
|
|
||||||
assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token for peer creation")
|
|
||||||
|
|
||||||
rp2 := mappingsByID["rp-2"]
|
|
||||||
require.NotNil(t, rp2)
|
|
||||||
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
|
||||||
setup := setupIntegrationTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewProxyServiceClient(conn)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
clusterAddress := "test.proxy.io"
|
|
||||||
|
|
||||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: "test-proxy-cluster",
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Receive all mappings - server sends each mapping individually
|
|
||||||
mappings := make([]*proto.ProxyMapping, 0)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
msg, err := stream.Recv()
|
|
||||||
require.NoError(t, err)
|
|
||||||
mappings = append(mappings, msg.GetMapping()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should receive the 2 mappings matching the cluster
|
|
||||||
assert.Len(t, mappings, 2, "Should receive mappings for the cluster")
|
|
||||||
|
|
||||||
for _, mapping := range mappings {
|
|
||||||
t.Logf("Received mapping: id=%s domain=%s", mapping.GetId(), mapping.GetDomain())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) {
|
|
||||||
setup := setupIntegrationTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewProxyServiceClient(conn)
|
|
||||||
|
|
||||||
clusterAddress := "test.proxy.io"
|
|
||||||
proxyID := "test-proxy-reconnect"
|
|
||||||
|
|
||||||
// Helper to receive all mappings from a stream
|
|
||||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
|
||||||
var mappings []*proto.ProxyMapping
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
msg, err := stream.Recv()
|
|
||||||
require.NoError(t, err)
|
|
||||||
mappings = append(mappings, msg.GetMapping()...)
|
|
||||||
}
|
|
||||||
return mappings
|
|
||||||
}
|
|
||||||
|
|
||||||
// First connection
|
|
||||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: proxyID,
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
firstMappings := receiveMappings(stream1, 2)
|
|
||||||
cancel1()
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// Second connection (simulating reconnect)
|
|
||||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel2()
|
|
||||||
|
|
||||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: proxyID,
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
secondMappings := receiveMappings(stream2, 2)
|
|
||||||
|
|
||||||
// Should receive the same mappings
|
|
||||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
|
||||||
"Should receive same number of mappings on reconnect")
|
|
||||||
|
|
||||||
firstIDs := make(map[string]bool)
|
|
||||||
for _, m := range firstMappings {
|
|
||||||
firstIDs[m.GetId()] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range secondMappings {
|
|
||||||
assert.True(t, firstIDs[m.GetId()],
|
|
||||||
"Mapping %s should be present in both connections", m.GetId())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T) {
|
|
||||||
setup := setupIntegrationTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewProxyServiceClient(conn)
|
|
||||||
|
|
||||||
// Use real auth middleware and proxy to verify idempotency
|
|
||||||
logger := log.New()
|
|
||||||
logger.SetLevel(log.WarnLevel)
|
|
||||||
|
|
||||||
authMw := auth.NewMiddleware(logger, nil)
|
|
||||||
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
|
|
||||||
|
|
||||||
clusterAddress := "test.proxy.io"
|
|
||||||
proxyID := "test-proxy-idempotent"
|
|
||||||
|
|
||||||
var addMappingCalls atomic.Int32
|
|
||||||
|
|
||||||
applyMappings := func(mappings []*proto.ProxyMapping) {
|
|
||||||
for _, mapping := range mappings {
|
|
||||||
if mapping.GetType() == proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
|
|
||||||
addMappingCalls.Add(1)
|
|
||||||
|
|
||||||
// Apply to real auth middleware (idempotent)
|
|
||||||
err := authMw.AddDomain(
|
|
||||||
mapping.GetDomain(),
|
|
||||||
nil,
|
|
||||||
"",
|
|
||||||
0,
|
|
||||||
mapping.GetAccountId(),
|
|
||||||
mapping.GetId(),
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Apply to real proxy (idempotent)
|
|
||||||
proxyHandler.AddMapping(proxy.Mapping{
|
|
||||||
Host: mapping.GetDomain(),
|
|
||||||
ID: mapping.GetId(),
|
|
||||||
AccountID: proxytypes.AccountID(mapping.GetAccountId()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to receive and apply all mappings
|
|
||||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
msg, err := stream.Recv()
|
|
||||||
require.NoError(t, err)
|
|
||||||
applyMappings(msg.GetMapping())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// First connection
|
|
||||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: proxyID,
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
receiveAndApply(stream1)
|
|
||||||
cancel1()
|
|
||||||
|
|
||||||
firstCallCount := addMappingCalls.Load()
|
|
||||||
t.Logf("First connection: applied %d mappings", firstCallCount)
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// Second connection
|
|
||||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: proxyID,
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
receiveAndApply(stream2)
|
|
||||||
cancel2()
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// Third connection
|
|
||||||
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel3()
|
|
||||||
|
|
||||||
stream3, err := client.GetMappingUpdate(ctx3, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: proxyID,
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
receiveAndApply(stream3)
|
|
||||||
|
|
||||||
totalCalls := addMappingCalls.Load()
|
|
||||||
t.Logf("After three connections: total applied %d mappings", totalCalls)
|
|
||||||
|
|
||||||
// Should have called addMapping 6 times (2 mappings x 3 connections)
|
|
||||||
// But internal state is NOT duplicated because auth and proxy use maps keyed by domain/host
|
|
||||||
assert.Equal(t, int32(6), totalCalls, "Should have 6 total calls (2 mappings x 3 connections)")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) {
|
|
||||||
setup := setupIntegrationTest(t)
|
|
||||||
defer setup.cleanup()
|
|
||||||
|
|
||||||
clusterAddress := "test.proxy.io"
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var mu sync.Mutex
|
|
||||||
receivedByProxy := make(map[string]int)
|
|
||||||
|
|
||||||
for i := 1; i <= 3; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(proxyNum int) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewProxyServiceClient(conn)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
proxyID := "test-proxy-" + string(rune('A'+proxyNum-1))
|
|
||||||
|
|
||||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
|
||||||
ProxyId: proxyID,
|
|
||||||
Version: "test-v1",
|
|
||||||
Address: clusterAddress,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Receive all mappings - server sends each mapping individually
|
|
||||||
count := 0
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
msg, err := stream.Recv()
|
|
||||||
require.NoError(t, err)
|
|
||||||
count += len(msg.GetMapping())
|
|
||||||
}
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
receivedByProxy[proxyID] = count
|
|
||||||
mu.Unlock()
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
for proxyID, count := range receivedByProxy {
|
|
||||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
proxyproto "github.com/pires/go-proxyproto"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
|
||||||
srv := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
|
||||||
ProxyProtocol: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer raw.Close()
|
|
||||||
|
|
||||||
ln := srv.wrapProxyProtocol(raw)
|
|
||||||
|
|
||||||
realClientIP := "203.0.113.50"
|
|
||||||
realClientPort := uint16(54321)
|
|
||||||
|
|
||||||
accepted := make(chan net.Conn, 1)
|
|
||||||
go func() {
|
|
||||||
conn, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
accepted <- conn
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Connect and send a PROXY v2 header.
|
|
||||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
header := &proxyproto.Header{
|
|
||||||
Version: 2,
|
|
||||||
Command: proxyproto.PROXY,
|
|
||||||
TransportProtocol: proxyproto.TCPv4,
|
|
||||||
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
|
||||||
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
|
||||||
}
|
|
||||||
_, err = header.WriteTo(conn)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case accepted := <-accepted:
|
|
||||||
defer accepted.Close()
|
|
||||||
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timed out waiting for connection")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
|
||||||
srv := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := proxyproto.ConnPolicyOptions{
|
|
||||||
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
|
||||||
}
|
|
||||||
policy, err := srv.proxyProtocolPolicy(opts)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
|
||||||
srv := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := proxyproto.ConnPolicyOptions{
|
|
||||||
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
|
||||||
}
|
|
||||||
policy, err := srv.proxyProtocolPolicy(opts)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
|
||||||
srv := &Server{
|
|
||||||
Logger: log.StandardLogger(),
|
|
||||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := proxyproto.ConnPolicyOptions{
|
|
||||||
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
|
||||||
}
|
|
||||||
policy, err := srv.proxyProtocolPolicy(opts)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
|
||||||
}
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDebugEndpointDisabledByDefault(t *testing.T) {
|
|
||||||
s := &Server{}
|
|
||||||
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDebugEndpointAddr(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty defaults to localhost",
|
|
||||||
input: "",
|
|
||||||
expected: "localhost:8444",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "explicit localhost preserved",
|
|
||||||
input: "localhost:9999",
|
|
||||||
expected: "localhost:9999",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "explicit address preserved",
|
|
||||||
input: "0.0.0.0:8444",
|
|
||||||
expected: "0.0.0.0:8444",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "127.0.0.1 preserved",
|
|
||||||
input: "127.0.0.1:8444",
|
|
||||||
expected: "127.0.0.1:8444",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result := debugEndpointAddr(tc.input)
|
|
||||||
assert.Equal(t, tc.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseTrustedProxies(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
raw string
|
|
||||||
want []netip.Prefix
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty string returns nil",
|
|
||||||
raw: "",
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single CIDR",
|
|
||||||
raw: "10.0.0.0/8",
|
|
||||||
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single bare IPv4",
|
|
||||||
raw: "1.2.3.4",
|
|
||||||
want: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/32")},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single bare IPv6",
|
|
||||||
raw: "::1",
|
|
||||||
want: []netip.Prefix{netip.MustParsePrefix("::1/128")},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "comma-separated CIDRs",
|
|
||||||
raw: "10.0.0.0/8, 192.168.1.0/24",
|
|
||||||
want: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
|
||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed CIDRs and bare IPs",
|
|
||||||
raw: "10.0.0.0/8, 1.2.3.4, fd00::/8",
|
|
||||||
want: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
|
||||||
netip.MustParsePrefix("1.2.3.4/32"),
|
|
||||||
netip.MustParsePrefix("fd00::/8"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "whitespace around entries",
|
|
||||||
raw: " 10.0.0.0/8 , 192.168.0.0/16 ",
|
|
||||||
want: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
|
||||||
netip.MustParsePrefix("192.168.0.0/16"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trailing comma produces no extra entry",
|
|
||||||
raw: "10.0.0.0/8,",
|
|
||||||
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid entry",
|
|
||||||
raw: "not-an-ip",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "partially invalid",
|
|
||||||
raw: "10.0.0.0/8, garbage",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := ParseTrustedProxies(tt.raw)
|
|
||||||
if tt.wantErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.want, got)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user