mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 07:06:41 +00:00
Compare commits
14 Commits
set-cmd
...
poc/prepro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51eee4c5ac | ||
|
|
8942c40fde | ||
|
|
fbb1b55beb | ||
|
|
77ec32dd6f | ||
|
|
8c09a55057 | ||
|
|
f603ddf35e | ||
|
|
996b8c600c | ||
|
|
c4ed11d447 | ||
|
|
9afbecb7ac | ||
|
|
2c81cf2c1e | ||
|
|
551cb4e467 | ||
|
|
57961afe95 | ||
|
|
22678bce7f | ||
|
|
6c633497bc |
@@ -15,6 +15,9 @@
|
|||||||
<a href="https://docs.netbird.io/slack-url">
|
<a href="https://docs.netbird.io/slack-url">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
|
<a href="https://forum.netbird.io">
|
||||||
|
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||||
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://gurubase.io/g/netbird">
|
<a href="https://gurubase.io/g/netbird">
|
||||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||||
@@ -29,13 +32,13 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||||
New: NetBird Kubernetes Operator
|
New: NetBird terraform provider
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|||||||
@@ -203,8 +203,10 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if routes[0].IsDynamic() {
|
r := routes[0]
|
||||||
continue
|
netStr := r.Network.String()
|
||||||
|
if r.IsDynamic() {
|
||||||
|
netStr = r.Domains.SafeString()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
@@ -214,7 +216,7 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
}
|
}
|
||||||
network := Network{
|
network := Network{
|
||||||
Name: string(id),
|
Name: string(id),
|
||||||
Network: routes[0].Network.String(),
|
Network: netStr,
|
||||||
Peer: peer.FQDN,
|
Peer: peer.FQDN,
|
||||||
Status: peer.ConnStatus.String(),
|
Status: peer.ConnStatus.String(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,6 +104,12 @@ type Manager struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
blockRule firewall.Rule
|
blockRule firewall.Rule
|
||||||
|
|
||||||
|
// Internal 1:1 DNAT
|
||||||
|
dnatEnabled atomic.Bool
|
||||||
|
dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
dnatMutex sync.RWMutex
|
||||||
|
dnatBiMap *biDNATMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -189,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
@@ -519,22 +526,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return nil, errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
|
||||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
@@ -581,14 +572,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// FilterOutBound filters outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData, size)
|
return m.filterOutbound(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// FilterInbound filters incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
||||||
return m.dropFilter(packetData, size)
|
return m.filterInbound(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@@ -596,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -618,8 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// for netflow we keep track even if the firewall is stateless
|
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
|
m.translateOutboundDNAT(packetData, d)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -723,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -747,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// Re-decode after translation to get original addresses
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
srcIP, dstIP = m.extractIPs(d)
|
||||||
|
}
|
||||||
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut, 0)
|
manager.filterOutbound(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn, 0)
|
manager.filterInbound(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
manager.filterOutbound(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx], 0)
|
manager.filterInbound(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn, 0)
|
manager.filterOutbound(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck, 0)
|
manager.filterInbound(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack, 0)
|
manager.filterOutbound(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request, 0)
|
manager.filterOutbound(p.request, 0)
|
||||||
manager.dropFilter(p.response, 0)
|
manager.filterInbound(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient, 0)
|
manager.filterOutbound(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer, 0)
|
manager.filterInbound(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer, 0)
|
manager.filterInbound(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient, 0)
|
manager.filterOutbound(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
manager.filterOutbound(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx], 0)
|
manager.filterInbound(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn, 0)
|
manager.filterOutbound(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck, 0)
|
manager.filterInbound(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack, 0)
|
manager.filterOutbound(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request, 0)
|
manager.filterOutbound(p.request, 0)
|
||||||
manager.dropFilter(p.response, 0)
|
manager.filterInbound(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient, 0)
|
manager.filterOutbound(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer, 0)
|
manager.filterInbound(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer, 0)
|
manager.filterInbound(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient, 0)
|
manager.filterOutbound(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet, 0)
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet, 0)
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), 0) {
|
if m.filterInbound(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
result := manager.filterOutbound(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
result = manager.filterOutbound(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
408
client/firewall/uspfilter/nat.go
Normal file
408
client/firewall/uspfilter/nat.go
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||||
|
|
||||||
|
func ipv4Checksum(header []byte) uint16 {
|
||||||
|
if len(header) < 20 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var sum1, sum2 uint32
|
||||||
|
|
||||||
|
// Parallel processing - unroll and compute two sums simultaneously
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||||
|
// Skip checksum field at [10:12]
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
||||||
|
|
||||||
|
sum := sum1 + sum2
|
||||||
|
|
||||||
|
// Handle remaining bytes for headers > 20 bytes
|
||||||
|
for i := 20; i < len(header)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(header)%2 == 1 {
|
||||||
|
sum += uint32(header[len(header)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimized carry fold - single iteration handles most cases
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func icmpChecksum(data []byte) uint16 {
|
||||||
|
var sum1, sum2, sum3, sum4 uint32
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
// Process 16 bytes at once with 4 parallel accumulators
|
||||||
|
for i <= len(data)-16 {
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||||
|
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
||||||
|
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
||||||
|
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
||||||
|
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
||||||
|
i += 16
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := sum1 + sum2 + sum3 + sum4
|
||||||
|
|
||||||
|
// Handle remaining bytes
|
||||||
|
for i < len(data)-1 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
i += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data)%2 == 1 {
|
||||||
|
sum += uint32(data[len(data)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
type biDNATMap struct {
|
||||||
|
forward map[netip.Addr]netip.Addr
|
||||||
|
reverse map[netip.Addr]netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBiDNATMap() *biDNATMap {
|
||||||
|
return &biDNATMap{
|
||||||
|
forward: make(map[netip.Addr]netip.Addr),
|
||||||
|
reverse: make(map[netip.Addr]netip.Addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||||
|
b.forward[original] = translated
|
||||||
|
b.reverse[translated] = original
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) delete(original netip.Addr) {
|
||||||
|
if translated, exists := b.forward[original]; exists {
|
||||||
|
delete(b.forward, original)
|
||||||
|
delete(b.reverse, translated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||||
|
translated, exists := b.forward[original]
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||||
|
original, exists := b.reverse[translated]
|
||||||
|
return original, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||||
|
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||||
|
return fmt.Errorf("invalid IP addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||||
|
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
// Initialize both maps together if either is nil
|
||||||
|
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||||
|
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||||
|
m.dnatBiMap = newBiDNATMap()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMappings[originalAddr] = translatedAddr
|
||||||
|
m.dnatBiMap.set(originalAddr, translatedAddr)
|
||||||
|
|
||||||
|
if len(m.dnatMappings) == 1 {
|
||||||
|
m.dnatEnabled.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||||
|
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||||
|
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.dnatMappings, originalAddr)
|
||||||
|
m.dnatBiMap.delete(originalAddr)
|
||||||
|
if len(m.dnatMappings) == 0 {
|
||||||
|
m.dnatEnabled.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDNATTranslation returns the translated address if a mapping exists
|
||||||
|
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return addr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
translated, exists := m.dnatBiMap.getTranslated(addr)
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// findReverseDNATMapping finds original address for return traffic
|
||||||
|
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return translatedAddr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return original, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||||
|
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||||
|
|
||||||
|
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||||
|
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||||
|
|
||||||
|
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketDestination replaces destination IP in the packet.
// It rewrites bytes 16..19 of the IPv4 header in place, recomputes the IP
// header checksum, and patches the transport checksum (which covers the IP
// pseudo-header). Only IPv4 packets and IPv4 target addresses are supported.
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
	if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
		return ErrIPv4Only
	}

	// Remember the old destination for the incremental transport-checksum
	// update below, then write the new one in place.
	var oldDst [4]byte
	copy(oldDst[:], packetData[16:20])
	newDst := newIP.As4()

	copy(packetData[16:20], newDst[:])

	// Header length comes from the decoded IHL field (32-bit words).
	ipHeaderLen := int(d.ip4.IHL) * 4
	if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
		return fmt.Errorf("invalid IP header length")
	}

	// Recompute the IPv4 header checksum from scratch: zero the field
	// (bytes 10..11), sum the header, then store the result.
	binary.BigEndian.PutUint16(packetData[10:12], 0)
	ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
	binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)

	// TCP/UDP checksums are updated incrementally from the address delta;
	// ICMP has no pseudo-header and is recomputed in full.
	if len(d.decoded) > 1 {
		switch d.decoded[1] {
		case layers.LayerTypeTCP:
			m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
		case layers.LayerTypeUDP:
			m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
		case layers.LayerTypeICMPv4:
			m.updateICMPChecksum(packetData, ipHeaderLen)
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
// rewritePacketSource replaces the source IP address in the packet.
// It rewrites bytes 12..15 of the IPv4 header in place, recomputes the IP
// header checksum, and patches the transport checksum (which covers the IP
// pseudo-header). Only IPv4 packets and IPv4 replacement addresses are
// supported.
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
	if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
		return ErrIPv4Only
	}

	// Remember the old source for the incremental transport-checksum
	// update below, then write the new one in place.
	var oldSrc [4]byte
	copy(oldSrc[:], packetData[12:16])
	newSrc := newIP.As4()

	copy(packetData[12:16], newSrc[:])

	// Header length comes from the decoded IHL field (32-bit words).
	ipHeaderLen := int(d.ip4.IHL) * 4
	if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
		return fmt.Errorf("invalid IP header length")
	}

	// Recompute the IPv4 header checksum from scratch: zero the field
	// (bytes 10..11), sum the header, then store the result.
	binary.BigEndian.PutUint16(packetData[10:12], 0)
	ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
	binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)

	// TCP/UDP checksums are updated incrementally from the address delta;
	// ICMP has no pseudo-header and is recomputed in full.
	if len(d.decoded) > 1 {
		switch d.decoded[1] {
		case layers.LayerTypeTCP:
			m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
		case layers.LayerTypeUDP:
			m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
		case layers.LayerTypeICMPv4:
			m.updateICMPChecksum(packetData, ipHeaderLen)
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
tcpStart := ipHeaderLen
|
||||||
|
if len(packetData) < tcpStart+18 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := tcpStart + 16
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
udpStart := ipHeaderLen
|
||||||
|
if len(packetData) < udpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := udpStart + 6
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
|
||||||
|
if oldChecksum == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||||
|
icmpStart := ipHeaderLen
|
||||||
|
if len(packetData) < icmpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpData := packetData[icmpStart:]
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||||
|
checksum := icmpChecksum(icmpData)
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// incrementalUpdate performs incremental checksum update per RFC 1624:
// HC' = ~(~HC + ~m + m'). The one's-complement sum is adjusted by removing
// the old bytes and adding the new ones without re-summing the whole
// covered data. A trailing odd byte is treated as the high half of a word.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
	acc := uint32(^oldChecksum)

	if len(oldBytes) == 4 && len(newBytes) == 4 {
		// Fast path for IPv4 addresses (two 16-bit words each) - most common case.
		acc += uint32(^binary.BigEndian.Uint16(oldBytes[:2]))
		acc += uint32(^binary.BigEndian.Uint16(oldBytes[2:]))
		acc += uint32(binary.BigEndian.Uint16(newBytes[:2]))
		acc += uint32(binary.BigEndian.Uint16(newBytes[2:]))
	} else {
		// Generic path for arbitrary-length byte strings.
		i := 0
		for ; i+1 < len(oldBytes); i += 2 {
			acc += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
		}
		if i < len(oldBytes) {
			acc += uint32(^oldBytes[i]) << 8
		}

		j := 0
		for ; j+1 < len(newBytes); j += 2 {
			acc += uint32(binary.BigEndian.Uint16(newBytes[j : j+2]))
		}
		if j < len(newBytes) {
			acc += uint32(newBytes[j]) << 8
		}
	}

	// Fold the carries back into 16 bits.
	acc = (acc & 0xFFFF) + (acc >> 16)
	if acc > 0xFFFF {
		acc++
	}

	return ^uint16(acc)
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkDNATTranslation measures the performance of DNAT operations for
// TCP, UDP, and ICMP, each with a mapping installed and without one
// (baseline), in both the outbound and inbound (reverse) directions.
func BenchmarkDNATTranslation(b *testing.B) {
	scenarios := []struct {
		name        string
		proto       layers.IPProtocol
		setupDNAT   bool
		description string
	}{
		{
			name:        "tcp_with_dnat",
			proto:       layers.IPProtocolTCP,
			setupDNAT:   true,
			description: "TCP packet with DNAT translation enabled",
		},
		{
			name:        "tcp_without_dnat",
			proto:       layers.IPProtocolTCP,
			setupDNAT:   false,
			description: "TCP packet without DNAT (baseline)",
		},
		{
			name:        "udp_with_dnat",
			proto:       layers.IPProtocolUDP,
			setupDNAT:   true,
			description: "UDP packet with DNAT translation enabled",
		},
		{
			name:        "udp_without_dnat",
			proto:       layers.IPProtocolUDP,
			setupDNAT:   false,
			description: "UDP packet without DNAT (baseline)",
		},
		{
			name:        "icmp_with_dnat",
			proto:       layers.IPProtocolICMPv4,
			setupDNAT:   true,
			description: "ICMP packet with DNAT translation enabled",
		},
		{
			name:        "icmp_without_dnat",
			proto:       layers.IPProtocolICMPv4,
			setupDNAT:   false,
			description: "ICMP packet without DNAT (baseline)",
		},
	}

	for _, sc := range scenarios {
		b.Run(sc.name, func(b *testing.B) {
			manager, err := Create(&IFaceMock{
				SetFilterFunc: func(device.PacketFilter) error { return nil },
			}, false, flowLogger)
			require.NoError(b, err)
			defer func() {
				require.NoError(b, manager.Close(nil))
			}()

			// Set logger to error level to reduce noise during benchmarking
			manager.SetLogLevel(log.ErrorLevel)
			defer func() {
				// Restore to info level after benchmark
				manager.SetLogLevel(log.InfoLevel)
			}()

			// Setup DNAT mapping if needed
			originalIP := netip.MustParseAddr("192.168.1.100")
			translatedIP := netip.MustParseAddr("10.0.0.100")

			if sc.setupDNAT {
				err := manager.AddInternalDNATMapping(originalIP, translatedIP)
				require.NoError(b, err)
			}

			// Create test packets
			srcIP := netip.MustParseAddr("172.16.0.1")
			outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)

			// Pre-establish connection for reverse DNAT test
			if sc.setupDNAT {
				manager.filterOutbound(outboundPacket, 0)
			}

			b.ResetTimer()

			// Benchmark outbound DNAT translation
			b.Run("outbound", func(b *testing.B) {
				for i := 0; i < b.N; i++ {
					// Create fresh packet each time since translation modifies it
					packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
					manager.filterOutbound(packet, 0)
				}
			})

			// Benchmark inbound reverse DNAT translation
			if sc.setupDNAT {
				b.Run("inbound_reverse", func(b *testing.B) {
					for i := 0; i < b.N; i++ {
						// Create fresh packet each time since translation modifies it
						packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
						manager.filterInbound(packet, 0)
					}
				})
			}
		})
	}
}
|
||||||
|
|
||||||
|
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load:
// 100 mappings are installed and packets are filtered from parallel
// goroutines in both directions.
func BenchmarkDNATConcurrency(b *testing.B) {
	manager, err := Create(&IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}, false, flowLogger)
	require.NoError(b, err)
	defer func() {
		require.NoError(b, manager.Close(nil))
	}()

	// Set logger to error level to reduce noise during benchmarking
	manager.SetLogLevel(log.ErrorLevel)
	defer func() {
		// Restore to info level after benchmark
		manager.SetLogLevel(log.InfoLevel)
	}()

	// Setup multiple DNAT mappings
	numMappings := 100
	originalIPs := make([]netip.Addr, numMappings)
	translatedIPs := make([]netip.Addr, numMappings)

	for i := 0; i < numMappings; i++ {
		// Spread addresses across /24s to stay within valid octet ranges.
		originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
		translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
		err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
		require.NoError(b, err)
	}

	srcIP := netip.MustParseAddr("172.16.0.1")

	// Pre-generate packets
	outboundPackets := make([][]byte, numMappings)
	inboundPackets := make([][]byte, numMappings)
	for i := 0; i < numMappings; i++ {
		outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
		inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
		// Establish connections
		manager.filterOutbound(outboundPackets[i], 0)
	}

	b.ResetTimer()

	b.Run("concurrent_outbound", func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			// Each goroutine cycles through the mappings independently.
			i := 0
			for pb.Next() {
				idx := i % numMappings
				packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
				manager.filterOutbound(packet, 0)
				i++
			}
		})
	})

	b.Run("concurrent_inbound", func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			i := 0
			for pb.Next() {
				idx := i % numMappings
				packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
				manager.filterInbound(packet, 0)
				i++
			}
		})
	})
}
|
||||||
|
|
||||||
|
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
// (1, 10, 100, 1000), always looking up the last mapping added.
func BenchmarkDNATScaling(b *testing.B) {
	mappingCounts := []int{1, 10, 100, 1000}

	for _, count := range mappingCounts {
		b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
			manager, err := Create(&IFaceMock{
				SetFilterFunc: func(device.PacketFilter) error { return nil },
			}, false, flowLogger)
			require.NoError(b, err)
			defer func() {
				require.NoError(b, manager.Close(nil))
			}()

			// Set logger to error level to reduce noise during benchmarking
			manager.SetLogLevel(log.ErrorLevel)
			defer func() {
				// Restore to info level after benchmark
				manager.SetLogLevel(log.InfoLevel)
			}()

			// Setup DNAT mappings
			for i := 0; i < count; i++ {
				originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
				translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
				err := manager.AddInternalDNATMapping(originalIP, translatedIP)
				require.NoError(b, err)
			}

			// Test with the last mapping added (worst case for lookup)
			srcIP := netip.MustParseAddr("172.16.0.1")
			lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))

			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
				manager.filterOutbound(packet, 0)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// generateDNATTestPacket creates a test packet for DNAT benchmarking.
// It serializes an IPv4 header plus a transport layer (TCP SYN, UDP, or
// ICMP echo request, selected by proto) and a small "test" payload, with
// lengths and checksums computed by gopacket.
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
	tb.Helper()

	ipv4 := &layers.IPv4{
		TTL:      64,
		Version:  4,
		SrcIP:    srcIP.AsSlice(),
		DstIP:    dstIP.AsSlice(),
		Protocol: proto,
	}

	var transportLayer gopacket.SerializableLayer
	switch proto {
	case layers.IPProtocolTCP:
		tcp := &layers.TCP{
			SrcPort: layers.TCPPort(srcPort),
			DstPort: layers.TCPPort(dstPort),
			SYN:     true,
		}
		// TCP/UDP checksums cover the IP pseudo-header, so the network
		// layer must be attached before serialization.
		require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
		transportLayer = tcp
	case layers.IPProtocolUDP:
		udp := &layers.UDP{
			SrcPort: layers.UDPPort(srcPort),
			DstPort: layers.UDPPort(dstPort),
		}
		require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
		transportLayer = udp
	case layers.IPProtocolICMPv4:
		icmp := &layers.ICMPv4{
			TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
		}
		transportLayer = icmp
	}

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
	err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
	require.NoError(tb, err)
	return buf.Bytes()
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance:
// full IPv4 header checksum, full ICMP checksum, and the RFC 1624 incremental update.
func BenchmarkChecksumUpdate(b *testing.B) {
	// Create test data for checksum calculations
	testData := make([]byte, 64) // Typical packet size for checksum testing
	for i := range testData {
		testData[i] = byte(i)
	}

	b.Run("ipv4_checksum", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
		}
	})

	b.Run("icmp_checksum", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_ = icmpChecksum(testData)
		}
	})

	b.Run("incremental_update", func(b *testing.B) {
		oldBytes := []byte{192, 168, 1, 100}
		newBytes := []byte{10, 0, 0, 100}
		oldChecksum := uint16(0x1234)

		for i := 0; i < b.N; i++ {
			_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
		}
	})
}
|
||||||
|
|
||||||
|
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations.
// It re-parses a fresh copy of the packet each iteration so the reported
// allocations include the decode path plus translateOutboundDNAT itself.
func BenchmarkDNATMemoryAllocations(b *testing.B) {
	manager, err := Create(&IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}, false, flowLogger)
	require.NoError(b, err)
	defer func() {
		require.NoError(b, manager.Close(nil))
	}()

	// Set logger to error level to reduce noise during benchmarking
	manager.SetLogLevel(log.ErrorLevel)
	defer func() {
		// Restore to info level after benchmark
		manager.SetLogLevel(log.InfoLevel)
	}()

	originalIP := netip.MustParseAddr("192.168.1.100")
	translatedIP := netip.MustParseAddr("10.0.0.100")
	srcIP := netip.MustParseAddr("172.16.0.1")

	err = manager.AddInternalDNATMapping(originalIP, translatedIP)
	require.NoError(b, err)

	packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)

	b.ResetTimer()
	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		// Create fresh packet each time to isolate allocation testing
		testPacket := make([]byte, len(packet))
		copy(testPacket, packet)

		// Parse the packet fresh each time to get a clean decoder
		d := &decoder{decoded: []gopacket.LayerType{}}
		d.parser = gopacket.NewDecodingLayerParser(
			layers.LayerTypeIPv4,
			&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
		)
		d.parser.IgnoreUnsupported = true
		err = d.parser.DecodeLayers(testPacket, &d.decoded)
		assert.NoError(b, err)

		manager.translateOutboundDNAT(testPacket, d)
	}
}
|
||||||
|
|
||||||
|
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction:
// reading the destination address straight from the IPv4 header bytes versus
// going through the decoder's parsed DstIP slice.
func BenchmarkDirectIPExtraction(b *testing.B) {
	// Create a test packet
	srcIP := netip.MustParseAddr("172.16.0.1")
	dstIP := netip.MustParseAddr("192.168.1.100")
	packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)

	b.Run("direct_byte_access", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			// Direct extraction from packet bytes
			_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
		}
	})

	b.Run("decoder_extraction", func(b *testing.B) {
		// Create decoder once for comparison
		d := &decoder{decoded: []gopacket.LayerType{}}
		d.parser = gopacket.NewDecodingLayerParser(
			layers.LayerTypeIPv4,
			&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
		)
		d.parser.IgnoreUnsupported = true
		err := d.parser.DecodeLayers(packet, &d.decoded)
		assert.NoError(b, err)

		for i := 0; i < b.N; i++ {
			// Extract using decoder (traditional method)
			dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
			_ = dst
		}
	})
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
// on a 20-byte IPv4 header and on the incremental (RFC 1624) address update.
func BenchmarkChecksumOptimizations(b *testing.B) {
	// Create test IPv4 header (20 bytes)
	header := make([]byte, 20)
	for i := range header {
		header[i] = byte(i)
	}
	// Clear checksum field
	header[10] = 0
	header[11] = 0

	b.Run("optimized_ipv4_checksum", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_ = ipv4Checksum(header)
		}
	})

	// Test incremental checksum updates
	oldIP := []byte{192, 168, 1, 100}
	newIP := []byte{10, 0, 0, 100}
	oldChecksum := uint16(0x1234)

	b.Run("optimized_incremental_update", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_ = incrementalUpdate(oldChecksum, oldIP, newIP)
		}
	})
}
|
||||||
145
client/firewall/uspfilter/nat_test.go
Normal file
145
client/firewall/uspfilter/nat_test.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDNATTranslationCorrectness verifies DNAT translation works correctly:
// for TCP, UDP, and ICMP it checks that the outbound destination is
// rewritten to the translated address, that inbound return traffic gets its
// source rewritten back to the original, and that the packet bytes (and thus
// checksums) are actually modified.
func TestDNATTranslationCorrectness(t *testing.T) {
	manager, err := Create(&IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}, false, flowLogger)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, manager.Close(nil))
	}()

	originalIP := netip.MustParseAddr("192.168.1.100")
	translatedIP := netip.MustParseAddr("10.0.0.100")
	srcIP := netip.MustParseAddr("172.16.0.1")

	// Add DNAT mapping
	err = manager.AddInternalDNATMapping(originalIP, translatedIP)
	require.NoError(t, err)

	testCases := []struct {
		name     string
		protocol layers.IPProtocol
		srcPort  uint16
		dstPort  uint16
	}{
		{"TCP", layers.IPProtocolTCP, 12345, 80},
		{"UDP", layers.IPProtocolUDP, 12345, 53},
		{"ICMP", layers.IPProtocolICMPv4, 0, 0},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Test outbound DNAT translation
			outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
			originalOutbound := make([]byte, len(outboundPacket))
			copy(originalOutbound, outboundPacket)

			// Process outbound packet (should translate destination)
			translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
			require.True(t, translated, "Outbound packet should be translated")

			// Verify destination IP was changed
			dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
			require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")

			// Test inbound reverse DNAT translation
			inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
			originalInbound := make([]byte, len(inboundPacket))
			copy(originalInbound, inboundPacket)

			// Process inbound packet (should reverse translate source)
			reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
			require.True(t, reversed, "Inbound packet should be reverse translated")

			// Verify source IP was changed back to original
			srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
			require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")

			// Test that checksums are recalculated correctly
			if tc.protocol != layers.IPProtocolICMPv4 {
				// For TCP/UDP, verify the transport checksum was updated
				require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
				require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
			}
		})
	}
}
|
||||||
|
|
||||||
|
// parsePacket helper to create a decoder for testing
|
||||||
|
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||||
|
t.Helper()
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
|
||||||
|
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNATMappingManagement tests adding/removing DNAT mappings: forward and
// reverse lookups after insertion, their disappearance after removal, and
// the error paths for invalid addresses and missing mappings.
func TestDNATMappingManagement(t *testing.T) {
	manager, err := Create(&IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}, false, flowLogger)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, manager.Close(nil))
	}()

	originalIP := netip.MustParseAddr("192.168.1.100")
	translatedIP := netip.MustParseAddr("10.0.0.100")

	// Test adding mapping
	err = manager.AddInternalDNATMapping(originalIP, translatedIP)
	require.NoError(t, err)

	// Verify mapping exists
	result, exists := manager.getDNATTranslation(originalIP)
	require.True(t, exists)
	require.Equal(t, translatedIP, result)

	// Test reverse lookup
	reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
	require.True(t, exists)
	require.Equal(t, originalIP, reverseResult)

	// Test removing mapping
	err = manager.RemoveInternalDNATMapping(originalIP)
	require.NoError(t, err)

	// Verify mapping no longer exists
	_, exists = manager.getDNATTranslation(originalIP)
	require.False(t, exists)

	_, exists = manager.findReverseDNATMapping(translatedIP)
	require.False(t, exists)

	// Test error cases
	err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
	require.Error(t, err, "Should reject invalid original IP")

	err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
	require.Error(t, err, "Should reject invalid translated IP")

	err = manager.RemoveInternalDNATMapping(originalIP)
	require.Error(t, err, "Should error when removing non-existent mapping")
}
|
||||||
@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData, 0)
|
dropped := m.filterOutbound(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
94
client/iface/bind/activity.go
Normal file
94
client/iface/bind/activity.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
saveFrequency = int64(5 * time.Second)
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerRecord struct {
|
||||||
|
Address netip.AddrPort
|
||||||
|
LastActivity atomic.Int64 // UnixNano timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
type ActivityRecorder struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
peers map[string]*PeerRecord // publicKey to PeerRecord map
|
||||||
|
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewActivityRecorder() *ActivityRecorder {
|
||||||
|
return &ActivityRecorder{
|
||||||
|
peers: make(map[string]*PeerRecord),
|
||||||
|
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastActivities returns a snapshot of peer last activity
|
||||||
|
func (r *ActivityRecorder) GetLastActivities() map[string]time.Time {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
activities := make(map[string]time.Time, len(r.peers))
|
||||||
|
for key, record := range r.peers {
|
||||||
|
unixNano := record.LastActivity.Load()
|
||||||
|
activities[key] = time.Unix(0, unixNano)
|
||||||
|
}
|
||||||
|
return activities
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertAddress adds or updates the address for a publicKey
|
||||||
|
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if pr, exists := r.peers[publicKey]; exists {
|
||||||
|
delete(r.addrToPeer, pr.Address)
|
||||||
|
pr.Address = address
|
||||||
|
} else {
|
||||||
|
record := &PeerRecord{
|
||||||
|
Address: address,
|
||||||
|
}
|
||||||
|
record.LastActivity.Store(monotime.Now())
|
||||||
|
r.peers[publicKey] = record
|
||||||
|
}
|
||||||
|
|
||||||
|
r.addrToPeer[address] = r.peers[publicKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ActivityRecorder) Remove(publicKey string) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if record, exists := r.peers[publicKey]; exists {
|
||||||
|
delete(r.addrToPeer, record.Address)
|
||||||
|
delete(r.peers, publicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record updates LastActivity for the given address using atomic store
|
||||||
|
func (r *ActivityRecorder) record(address netip.AddrPort) {
|
||||||
|
r.mu.RLock()
|
||||||
|
record, ok := r.addrToPeer[address]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("could not find record for address %s", address)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := monotime.Now()
|
||||||
|
last := record.LastActivity.Load()
|
||||||
|
if now-last < saveFrequency {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = record.LastActivity.CompareAndSwap(last, now)
|
||||||
|
}
|
||||||
27
client/iface/bind/activity_test.go
Normal file
27
client/iface/bind/activity_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
||||||
|
peer := "peer1"
|
||||||
|
ar := NewActivityRecorder()
|
||||||
|
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
|
||||||
|
activities := ar.GetLastActivities()
|
||||||
|
|
||||||
|
p, ok := activities[peer]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.IsZero() {
|
||||||
|
t.Fatalf("Expected activity for peer %s, but got zero", peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Before(time.Now().Add(-2 * time.Minute)) {
|
||||||
|
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -54,6 +55,7 @@ type ICEBind struct {
|
|||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *UniversalUDPMuxDefault
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
|
activityRecorder *ActivityRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||||
@@ -67,6 +69,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
|
|||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
address: address,
|
address: address,
|
||||||
|
activityRecorder: NewActivityRecorder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := receiverCreator{
|
rc := receiverCreator{
|
||||||
@@ -100,6 +103,10 @@ func (s *ICEBind) Close() error {
|
|||||||
return s.StdNetBind.Close()
|
return s.StdNetBind.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||||
|
return s.activityRecorder
|
||||||
|
}
|
||||||
|
|
||||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
@@ -199,6 +206,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
|
||||||
|
if isTransportPkg(msg.Buffers, msg.N) {
|
||||||
|
s.activityRecorder.record(addrPort)
|
||||||
|
}
|
||||||
|
|
||||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
eps[i] = ep
|
eps[i] = ep
|
||||||
@@ -257,6 +269,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
copy(buffs[0], msg.Buffer)
|
copy(buffs[0], msg.Buffer)
|
||||||
sizes[0] = len(msg.Buffer)
|
sizes[0] = len(msg.Buffer)
|
||||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||||
|
|
||||||
|
if isTransportPkg(buffs, sizes[0]) {
|
||||||
|
if ep, ok := eps[0].(*Endpoint); ok {
|
||||||
|
c.activityRecorder.record(ep.AddrPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -272,3 +291,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
|||||||
}
|
}
|
||||||
msgsPool.Put(msgs)
|
msgsPool.Put(msgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTransportPkg(buffers [][]byte, n int) bool {
|
||||||
|
// The first buffer should contain at least 4 bytes for type
|
||||||
|
if len(buffers[0]) < 4 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// WireGuard packet type is a little-endian uint32 at start
|
||||||
|
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
|
||||||
|
|
||||||
|
// Check if packetType matches known WireGuard message types
|
||||||
|
if packetType == 4 && n > 32 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -276,3 +276,7 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
|||||||
}
|
}
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) LastActivities() map[string]time.Time {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,14 +39,16 @@ var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
|||||||
type WGUSPConfigurer struct {
|
type WGUSPConfigurer struct {
|
||||||
device *device.Device
|
device *device.Device
|
||||||
deviceName string
|
deviceName string
|
||||||
|
activityRecorder *bind.ActivityRecorder
|
||||||
|
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||||
wgCfg := &WGUSPConfigurer{
|
wgCfg := &WGUSPConfigurer{
|
||||||
device: device,
|
device: device,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
|
activityRecorder: activityRecorder,
|
||||||
}
|
}
|
||||||
wgCfg.startUAPI()
|
wgCfg.startUAPI()
|
||||||
return wgCfg
|
return wgCfg
|
||||||
@@ -87,7 +90,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
Peers: []wgtypes.PeerConfig{peer},
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||||
|
return ipcErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if endpoint != nil {
|
||||||
|
addr, err := netip.ParseAddr(endpoint.IP.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||||
|
}
|
||||||
|
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||||
|
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||||
@@ -104,7 +119,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
|||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
}
|
}
|
||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
|
||||||
|
|
||||||
|
c.activityRecorder.Remove(peerKey)
|
||||||
|
return ipcErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
@@ -205,6 +223,10 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
|||||||
return parseStatus(c.deviceName, ipcStr)
|
return parseStatus(c.deviceName, ipcStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) LastActivities() map[string]time.Time {
|
||||||
|
return c.activityRecorder.GetLastActivities()
|
||||||
|
}
|
||||||
|
|
||||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||||
func (t *WGUSPConfigurer) startUAPI() {
|
func (t *WGUSPConfigurer) startUAPI() {
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
|
|
||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// FilterOutbound filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte, size int) bool
|
FilterOutbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// FilterInbound filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte, size int) bool
|
FilterInbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -19,4 +19,5 @@ type WGConfigurer interface {
|
|||||||
Close()
|
Close()
|
||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
|
LastActivities() map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -217,6 +217,14 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
|||||||
return w.configurer.GetStats()
|
return w.configurer.GetStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) LastActivities() map[string]time.Time {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
return w.configurer.LastActivities()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||||
return w.configurer.FullStats()
|
return w.configurer.FullStats()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// FilterInbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// FilterOutbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
|
|||||||
@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// FilterInbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// FilterOutbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
// SetNetwork mocks base method.
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -29,8 +28,8 @@ type ConnMgr struct {
|
|||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
iface lazyconn.WGIface
|
iface lazyconn.WGIface
|
||||||
dispatcher *dispatcher.ConnectionDispatcher
|
|
||||||
enabledLocally bool
|
enabledLocally bool
|
||||||
|
rosenpassEnabled bool
|
||||||
|
|
||||||
lazyConnMgr *manager.Manager
|
lazyConnMgr *manager.Manager
|
||||||
|
|
||||||
@@ -39,12 +38,12 @@ type ConnMgr struct {
|
|||||||
lazyCtxCancel context.CancelFunc
|
lazyCtxCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
|
||||||
e := &ConnMgr{
|
e := &ConnMgr{
|
||||||
peerStore: peerStore,
|
peerStore: peerStore,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
iface: iface,
|
iface: iface,
|
||||||
dispatcher: dispatcher,
|
rosenpassEnabled: engineConfig.RosenpassEnabled,
|
||||||
}
|
}
|
||||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||||
e.enabledLocally = true
|
e.enabledLocally = true
|
||||||
@@ -64,6 +63,11 @@ func (e *ConnMgr) Start(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.rosenpassEnabled {
|
||||||
|
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
e.initLazyManager(ctx)
|
e.initLazyManager(ctx)
|
||||||
e.statusRecorder.UpdateLazyConnection(true)
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
}
|
}
|
||||||
@@ -83,7 +87,12 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("lazy connection manager is enabled by management feature flag")
|
if e.rosenpassEnabled {
|
||||||
|
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("lazy connection manager is enabled by management feature flag")
|
||||||
e.initLazyManager(ctx)
|
e.initLazyManager(ctx)
|
||||||
e.statusRecorder.UpdateLazyConnection(true)
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
return e.addPeersToLazyConnManager()
|
return e.addPeersToLazyConnManager()
|
||||||
@@ -133,7 +142,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
|
|||||||
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
|
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
|
||||||
for _, peerID := range added {
|
for _, peerID := range added {
|
||||||
var peerConn *peer.Conn
|
var peerConn *peer.Conn
|
||||||
var exists bool
|
var exists bool
|
||||||
@@ -175,7 +184,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
|
|||||||
PeerConnID: conn.ConnID(),
|
PeerConnID: conn.ConnID(),
|
||||||
Log: conn.Log,
|
Log: conn.Log,
|
||||||
}
|
}
|
||||||
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
|
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||||
if err := conn.Open(ctx); err != nil {
|
if err := conn.Open(ctx); err != nil {
|
||||||
@@ -201,7 +210,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close(false)
|
||||||
|
|
||||||
if !e.isStartedWithLazyMgr() {
|
if !e.isStartedWithLazyMgr() {
|
||||||
return
|
return
|
||||||
@@ -211,23 +220,28 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
|||||||
conn.Log.Infof("removed peer from lazy conn manager")
|
conn.Log.Infof("removed peer from lazy conn manager")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
|
||||||
conn, ok := e.peerStore.PeerConn(peerKey)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !e.isStartedWithLazyMgr() {
|
if !e.isStartedWithLazyMgr() {
|
||||||
return conn, true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
|
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
|
||||||
conn.Log.Infof("activated peer from inactive state")
|
conn.Log.Infof("activated peer from inactive state")
|
||||||
if err := conn.Open(ctx); err != nil {
|
if err := conn.Open(ctx); err != nil {
|
||||||
conn.Log.Errorf("failed to open connection: %v", err)
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return conn, true
|
}
|
||||||
|
|
||||||
|
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
|
||||||
|
// If locally the lazy connection is disabled, we force the peer connection open.
|
||||||
|
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
|
||||||
|
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConnMgr) Close() {
|
func (e *ConnMgr) Close() {
|
||||||
@@ -244,7 +258,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
|
|||||||
cfg := manager.Config{
|
cfg := manager.Config{
|
||||||
InactivityThreshold: inactivityThresholdEnv(),
|
InactivityThreshold: inactivityThresholdEnv(),
|
||||||
}
|
}
|
||||||
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
|
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
|
||||||
|
|
||||||
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
|
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
@@ -275,7 +289,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
|
|||||||
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
|
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConnMgr) closeManager(ctx context.Context) {
|
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||||
|
|||||||
@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ import (
|
|||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
@@ -176,7 +175,6 @@ type Engine struct {
|
|||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
|
||||||
|
|
||||||
firewall firewallManager.Manager
|
firewall firewallManager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
@@ -383,7 +381,13 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.stateManager.Start()
|
e.stateManager.Start()
|
||||||
|
|
||||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||||
|
if err != nil {
|
||||||
|
e.close()
|
||||||
|
return fmt.Errorf("read initial settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("create dns server: %w", err)
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
@@ -400,6 +404,7 @@ func (e *Engine) Start() error {
|
|||||||
InitialRoutes: initialRoutes,
|
InitialRoutes: initialRoutes,
|
||||||
StateManager: e.stateManager,
|
StateManager: e.stateManager,
|
||||||
DNSServer: dnsServer,
|
DNSServer: dnsServer,
|
||||||
|
DNSFeatureFlag: dnsFeatureFlag,
|
||||||
PeerStore: e.peerStore,
|
PeerStore: e.peerStore,
|
||||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||||
@@ -451,9 +456,7 @@ func (e *Engine) Start() error {
|
|||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
}
|
}
|
||||||
|
|
||||||
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||||
|
|
||||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
|
||||||
e.connMgr.Start(e.ctx)
|
e.connMgr.Start(e.ctx)
|
||||||
|
|
||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
@@ -488,9 +491,9 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) initFirewall() error {
|
func (e *Engine) initFirewall() error {
|
||||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
return fmt.Errorf("set firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.config.BlockLANAccess {
|
if e.config.BlockLANAccess {
|
||||||
@@ -1009,8 +1012,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
|
||||||
|
|
||||||
// apply routes first, route related actions might depend on routing being enabled
|
// apply routes first, route related actions might depend on routing being enabled
|
||||||
routes := toRoutes(networkMap.GetRoutes())
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||||
@@ -1021,6 +1022,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("failed to update routes: %v", err)
|
log.Errorf("failed to update routes: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1255,7 +1257,7 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||||
conn.Close()
|
conn.Close(false)
|
||||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1308,7 +1310,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
RelayManager: e.relayManager,
|
RelayManager: e.relayManager,
|
||||||
SrWatcher: e.srWatcher,
|
SrWatcher: e.srWatcher,
|
||||||
Semaphore: e.connSemaphore,
|
Semaphore: e.connSemaphore,
|
||||||
PeerConnDispatcher: e.peerConnDispatcher,
|
|
||||||
}
|
}
|
||||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1331,11 +1332,16 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
msgType := msg.GetBody().GetType()
|
||||||
|
if msgType != sProto.Body_GO_IDLE {
|
||||||
|
e.connMgr.ActivatePeer(e.ctx, conn)
|
||||||
|
}
|
||||||
|
|
||||||
switch msg.GetBody().Type {
|
switch msg.GetBody().Type {
|
||||||
case sProto.Body_OFFER:
|
case sProto.Body_OFFER:
|
||||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||||
@@ -1392,6 +1398,8 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
|
|
||||||
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||||
case sProto.Body_MODE:
|
case sProto.Body_MODE:
|
||||||
|
case sProto.Body_GO_IDLE:
|
||||||
|
e.connMgr.DeactivatePeer(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1489,7 +1497,12 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
// nolint:nilnil
|
||||||
|
return nil, nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
info := system.GetInfo(e.ctx)
|
info := system.GetInfo(e.ctx)
|
||||||
info.SetFlags(
|
info.SetFlags(
|
||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
@@ -1506,11 +1519,12 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
|||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, false, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||||
return routes, &dnsCfg, nil
|
dnsFeatureFlag := toDNSFeatureFlag(netMap)
|
||||||
|
return routes, &dnsCfg, dnsFeatureFlag, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||||
@@ -1558,18 +1572,14 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||||
// due to tests where we are using a mocked version of the DNS server
|
// due to tests where we are using a mocked version of the DNS server
|
||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
return nil, e.dnsServer, nil
|
return e.dnsServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "android":
|
case "android":
|
||||||
routes, dnsConfig, err := e.readInitialSettings()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
||||||
e.ctx,
|
e.ctx,
|
||||||
e.wgInterface,
|
e.wgInterface,
|
||||||
@@ -1580,19 +1590,19 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
)
|
)
|
||||||
go e.mobileDep.DnsReadyListener.OnReady()
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
return routes, dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||||
return nil, dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, dnsServer, nil
|
return dnsServer, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
@@ -97,6 +96,7 @@ type MockWGIface struct {
|
|||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
GetNetFunc func() *netstack.Net
|
GetNetFunc func() *netstack.Net
|
||||||
|
LastActivitiesFunc func() map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||||
@@ -187,6 +187,13 @@ func (m *MockWGIface) GetNet() *netstack.Net {
|
|||||||
return m.GetNetFunc()
|
return m.GetNetFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) LastActivities() map[string]time.Time {
|
||||||
|
if m.LastActivitiesFunc != nil {
|
||||||
|
return m.LastActivitiesFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
_ = util.InitLog("debug", "console")
|
_ = util.InitLog("debug", "console")
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
@@ -404,7 +411,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
||||||
engine.connMgr.Start(ctx)
|
engine.connMgr.Start(ctx)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -793,7 +800,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
|
|
||||||
engine.routeManager = mockRouteManager
|
engine.routeManager = mockRouteManager
|
||||||
engine.dnsServer = &dns.MockServer{}
|
engine.dnsServer = &dns.MockServer{}
|
||||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
|
||||||
engine.connMgr.Start(ctx)
|
engine.connMgr.Start(ctx)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -991,7 +998,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
engine.dnsServer = mockDNSServer
|
engine.dnsServer = mockDNSServer
|
||||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
|
||||||
engine.connMgr.Start(ctx)
|
engine.connMgr.Start(ctx)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -38,4 +38,5 @@ type wgIfaceBase interface {
|
|||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
|
LastActivities() map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
wgIface lazyconn.WGIface
|
wgIface WgInterface
|
||||||
peerCfg lazyconn.PeerConfig
|
peerCfg lazyconn.PeerConfig
|
||||||
conn *net.UDPConn
|
conn *net.UDPConn
|
||||||
endpoint *net.UDPAddr
|
endpoint *net.UDPAddr
|
||||||
@@ -22,7 +22,7 @@ type Listener struct {
|
|||||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||||
d := &Listener{
|
d := &Listener{
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
peerCfg: cfg,
|
peerCfg: cfg,
|
||||||
|
|||||||
@@ -1,18 +1,27 @@
|
|||||||
package activity
|
package activity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type WgInterface interface {
|
||||||
|
RemovePeer(peerKey string) error
|
||||||
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
OnActivityChan chan peerid.ConnID
|
OnActivityChan chan peerid.ConnID
|
||||||
|
|
||||||
wgIface lazyconn.WGIface
|
wgIface WgInterface
|
||||||
|
|
||||||
peers map[peerid.ConnID]*Listener
|
peers map[peerid.ConnID]*Listener
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
@@ -20,7 +29,7 @@ type Manager struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(wgIface lazyconn.WGIface) *Manager {
|
func NewManager(wgIface WgInterface) *Manager {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
package inactivity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
peer "github.com/netbirdio/netbird/client/internal/peer/id"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
|
|
||||||
MinimumInactivityThreshold = 3 * time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
type Monitor struct {
|
|
||||||
id peer.ConnID
|
|
||||||
timer *time.Timer
|
|
||||||
cancel context.CancelFunc
|
|
||||||
inactivityThreshold time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
|
|
||||||
i := &Monitor{
|
|
||||||
id: peerID,
|
|
||||||
timer: time.NewTimer(0),
|
|
||||||
inactivityThreshold: threshold,
|
|
||||||
}
|
|
||||||
i.timer.Stop()
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
|
|
||||||
i.timer.Reset(i.inactivityThreshold)
|
|
||||||
defer i.timer.Stop()
|
|
||||||
|
|
||||||
ctx, i.cancel = context.WithCancel(ctx)
|
|
||||||
defer func() {
|
|
||||||
defer i.cancel()
|
|
||||||
select {
|
|
||||||
case <-i.timer.C:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-i.timer.C:
|
|
||||||
select {
|
|
||||||
case timeoutChan <- i.id:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Monitor) Stop() {
|
|
||||||
if i.cancel == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
i.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Monitor) PauseTimer() {
|
|
||||||
i.timer.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Monitor) ResetTimer() {
|
|
||||||
i.timer.Reset(i.inactivityThreshold)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
|
|
||||||
i.Stop()
|
|
||||||
go i.Start(ctx, timeoutChan)
|
|
||||||
}
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
package inactivity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MocPeer struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
|
||||||
return peerid.ConnID(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInactivityMonitor(t *testing.T) {
|
|
||||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer testTimeoutCancel()
|
|
||||||
|
|
||||||
p := &MocPeer{}
|
|
||||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
|
||||||
|
|
||||||
timeoutChan := make(chan peerid.ConnID)
|
|
||||||
|
|
||||||
exitChan := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(exitChan)
|
|
||||||
im.Start(tCtx, timeoutChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-timeoutChan:
|
|
||||||
case <-tCtx.Done():
|
|
||||||
t.Fatal("timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-exitChan:
|
|
||||||
case <-tCtx.Done():
|
|
||||||
t.Fatal("timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReuseInactivityMonitor(t *testing.T) {
|
|
||||||
p := &MocPeer{}
|
|
||||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
|
||||||
|
|
||||||
timeoutChan := make(chan peerid.ConnID)
|
|
||||||
|
|
||||||
for i := 2; i > 0; i-- {
|
|
||||||
exitChan := make(chan struct{})
|
|
||||||
|
|
||||||
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(exitChan)
|
|
||||||
im.Start(testTimeoutCtx, timeoutChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-timeoutChan:
|
|
||||||
case <-testTimeoutCtx.Done():
|
|
||||||
t.Fatal("timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-exitChan:
|
|
||||||
case <-testTimeoutCtx.Done():
|
|
||||||
t.Fatal("timeout")
|
|
||||||
}
|
|
||||||
testTimeoutCancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStopInactivityMonitor(t *testing.T) {
|
|
||||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer testTimeoutCancel()
|
|
||||||
|
|
||||||
p := &MocPeer{}
|
|
||||||
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
|
|
||||||
|
|
||||||
timeoutChan := make(chan peerid.ConnID)
|
|
||||||
|
|
||||||
exitChan := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(exitChan)
|
|
||||||
im.Start(tCtx, timeoutChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
im.Stop()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-timeoutChan:
|
|
||||||
t.Fatal("unexpected timeout")
|
|
||||||
case <-exitChan:
|
|
||||||
case <-tCtx.Done():
|
|
||||||
t.Fatal("timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPauseInactivityMonitor(t *testing.T) {
|
|
||||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
defer testTimeoutCancel()
|
|
||||||
|
|
||||||
p := &MocPeer{}
|
|
||||||
trashHold := time.Second * 3
|
|
||||||
im := NewInactivityMonitor(p.ConnID(), trashHold)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
timeoutChan := make(chan peerid.ConnID)
|
|
||||||
|
|
||||||
exitChan := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(exitChan)
|
|
||||||
im.Start(ctx, timeoutChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(1 * time.Second) // grant time to start the monitor
|
|
||||||
im.PauseTimer()
|
|
||||||
|
|
||||||
// check to do not receive timeout
|
|
||||||
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
|
|
||||||
defer thresholdCancel()
|
|
||||||
select {
|
|
||||||
case <-exitChan:
|
|
||||||
t.Fatal("unexpected exit")
|
|
||||||
case <-timeoutChan:
|
|
||||||
t.Fatal("unexpected timeout")
|
|
||||||
case <-thresholdCtx.Done():
|
|
||||||
// test ok
|
|
||||||
case <-tCtx.Done():
|
|
||||||
t.Fatal("test timed out")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test reset timer
|
|
||||||
im.ResetTimer()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-tCtx.Done():
|
|
||||||
t.Fatal("test timed out")
|
|
||||||
case <-exitChan:
|
|
||||||
t.Fatal("unexpected exit")
|
|
||||||
case <-timeoutChan:
|
|
||||||
// expected timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
152
client/internal/lazyconn/inactivity/manager.go
Normal file
152
client/internal/lazyconn/inactivity/manager.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package inactivity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
checkInterval = 1 * time.Minute
|
||||||
|
|
||||||
|
DefaultInactivityThreshold = 15 * time.Minute
|
||||||
|
MinimumInactivityThreshold = 1 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type WgInterface interface {
|
||||||
|
LastActivities() map[string]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
inactivePeersChan chan map[string]struct{}
|
||||||
|
|
||||||
|
iface WgInterface
|
||||||
|
interestedPeers map[string]*lazyconn.PeerConfig
|
||||||
|
inactivityThreshold time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
|
||||||
|
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
|
||||||
|
if err != nil {
|
||||||
|
inactivityThreshold = DefaultInactivityThreshold
|
||||||
|
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
|
||||||
|
return &Manager{
|
||||||
|
inactivePeersChan: make(chan map[string]struct{}, 1),
|
||||||
|
iface: iface,
|
||||||
|
interestedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||||
|
inactivityThreshold: inactivityThreshold,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
|
||||||
|
if m == nil {
|
||||||
|
// return a nil channel that blocks forever
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.inactivePeersChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peerCfg.Log.Infof("adding peer to inactivity manager")
|
||||||
|
m.interestedPeers[peerCfg.PublicKey] = peerCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RemovePeer(peer string) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pi, ok := m.interestedPeers[peer]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pi.Log.Debugf("remove peer from inactivity manager")
|
||||||
|
delete(m.interestedPeers, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(ctx context.Context) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := newTicker(checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C():
|
||||||
|
idlePeers, err := m.checkStats()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error checking stats: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(idlePeers) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
m.notifyInactivePeers(ctx, idlePeers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
|
||||||
|
select {
|
||||||
|
case m.inactivePeersChan <- inactivePeers:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkStats() (map[string]struct{}, error) {
|
||||||
|
lastActivities := m.iface.LastActivities()
|
||||||
|
|
||||||
|
idlePeers := make(map[string]struct{})
|
||||||
|
|
||||||
|
for peerID, peerCfg := range m.interestedPeers {
|
||||||
|
lastActive, ok := lastActivities[peerID]
|
||||||
|
if !ok {
|
||||||
|
// when peer is in connecting state
|
||||||
|
peerCfg.Log.Warnf("peer not found in wg stats")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Since(lastActive) > m.inactivityThreshold {
|
||||||
|
peerCfg.Log.Infof("peer is inactive since: %v", lastActive)
|
||||||
|
idlePeers[peerID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return idlePeers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
|
||||||
|
if configuredThreshold == nil {
|
||||||
|
return DefaultInactivityThreshold, nil
|
||||||
|
}
|
||||||
|
if *configuredThreshold < MinimumInactivityThreshold {
|
||||||
|
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
|
||||||
|
}
|
||||||
|
return *configuredThreshold, nil
|
||||||
|
}
|
||||||
113
client/internal/lazyconn/inactivity/manager_test.go
Normal file
113
client/internal/lazyconn/inactivity/manager_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package inactivity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockWgInterface struct {
|
||||||
|
lastActivities map[string]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockWgInterface) LastActivities() map[string]time.Time {
|
||||||
|
return m.lastActivities
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerTriggersInactivity(t *testing.T) {
|
||||||
|
peerID := "peer1"
|
||||||
|
|
||||||
|
wgMock := &mockWgInterface{
|
||||||
|
lastActivities: map[string]time.Time{
|
||||||
|
peerID: time.Now().Add(-20 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeTick := make(chan time.Time, 1)
|
||||||
|
newTicker = func(d time.Duration) Ticker {
|
||||||
|
return &fakeTickerMock{CChan: fakeTick}
|
||||||
|
}
|
||||||
|
|
||||||
|
peerLog := log.WithField("peer", peerID)
|
||||||
|
peerCfg := &lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
Log: peerLog,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := NewManager(wgMock, nil)
|
||||||
|
manager.AddPeer(peerCfg)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start the manager in a goroutine
|
||||||
|
go manager.Start(ctx)
|
||||||
|
|
||||||
|
// Send a tick to simulate time passage
|
||||||
|
fakeTick <- time.Now()
|
||||||
|
|
||||||
|
// Check if peer appears on inactivePeersChan
|
||||||
|
select {
|
||||||
|
case inactivePeers := <-manager.inactivePeersChan:
|
||||||
|
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("expected inactivity event, but none received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerTriggersActivity(t *testing.T) {
|
||||||
|
peerID := "peer1"
|
||||||
|
|
||||||
|
wgMock := &mockWgInterface{
|
||||||
|
lastActivities: map[string]time.Time{
|
||||||
|
peerID: time.Now().Add(-5 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeTick := make(chan time.Time, 1)
|
||||||
|
newTicker = func(d time.Duration) Ticker {
|
||||||
|
return &fakeTickerMock{CChan: fakeTick}
|
||||||
|
}
|
||||||
|
|
||||||
|
peerLog := log.WithField("peer", peerID)
|
||||||
|
peerCfg := &lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
Log: peerLog,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := NewManager(wgMock, nil)
|
||||||
|
manager.AddPeer(peerCfg)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start the manager in a goroutine
|
||||||
|
go manager.Start(ctx)
|
||||||
|
|
||||||
|
// Send a tick to simulate time passage
|
||||||
|
fakeTick <- time.Now()
|
||||||
|
|
||||||
|
// Check if peer appears on inactivePeersChan
|
||||||
|
select {
|
||||||
|
case <-manager.inactivePeersChan:
|
||||||
|
t.Fatal("expected inactive peer to be marked inactive")
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
// No inactivity event should be received
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeTickerMock implements Ticker interface for testing
|
||||||
|
type fakeTickerMock struct {
|
||||||
|
CChan chan time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTickerMock) C() <-chan time.Time {
|
||||||
|
return f.CChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTickerMock) Stop() {}
|
||||||
24
client/internal/lazyconn/inactivity/ticker.go
Normal file
24
client/internal/lazyconn/inactivity/ticker.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package inactivity
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
var newTicker = func(d time.Duration) Ticker {
|
||||||
|
return &realTicker{t: time.NewTicker(d)}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Ticker interface {
|
||||||
|
C() <-chan time.Time
|
||||||
|
Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
type realTicker struct {
|
||||||
|
t *time.Ticker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *realTicker) C() <-chan time.Time {
|
||||||
|
return r.t.C
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *realTicker) Stop() {
|
||||||
|
r.t.Stop()
|
||||||
|
}
|
||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -43,59 +42,45 @@ type Config struct {
|
|||||||
type Manager struct {
|
type Manager struct {
|
||||||
engineCtx context.Context
|
engineCtx context.Context
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
connStateDispatcher *dispatcher.ConnectionDispatcher
|
|
||||||
inactivityThreshold time.Duration
|
inactivityThreshold time.Duration
|
||||||
|
|
||||||
connStateListener *dispatcher.ConnectionListener
|
|
||||||
managedPeers map[string]*lazyconn.PeerConfig
|
managedPeers map[string]*lazyconn.PeerConfig
|
||||||
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
||||||
excludes map[string]lazyconn.PeerConfig
|
excludes map[string]lazyconn.PeerConfig
|
||||||
managedPeersMu sync.Mutex
|
managedPeersMu sync.Mutex
|
||||||
|
|
||||||
activityManager *activity.Manager
|
activityManager *activity.Manager
|
||||||
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
|
inactivityManager *inactivity.Manager
|
||||||
|
|
||||||
// Route HA group management
|
// Route HA group management
|
||||||
|
// If any peer in the same HA group is active, all peers in that group should prevent going idle
|
||||||
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
|
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
|
||||||
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
|
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
|
||||||
routesMu sync.RWMutex
|
routesMu sync.RWMutex
|
||||||
|
|
||||||
onInactive chan peerid.ConnID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new lazy connection manager
|
// NewManager creates a new lazy connection manager
|
||||||
// engineCtx is the context for creating peer Connection
|
// engineCtx is the context for creating peer Connection
|
||||||
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
|
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
|
||||||
log.Infof("setup lazy connection service")
|
log.Infof("setup lazy connection service")
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
engineCtx: engineCtx,
|
engineCtx: engineCtx,
|
||||||
peerStore: peerStore,
|
peerStore: peerStore,
|
||||||
connStateDispatcher: connStateDispatcher,
|
|
||||||
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
||||||
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||||
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
||||||
excludes: make(map[string]lazyconn.PeerConfig),
|
excludes: make(map[string]lazyconn.PeerConfig),
|
||||||
activityManager: activity.NewManager(wgIface),
|
activityManager: activity.NewManager(wgIface),
|
||||||
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
|
|
||||||
peerToHAGroups: make(map[string][]route.HAUniqueID),
|
peerToHAGroups: make(map[string][]route.HAUniqueID),
|
||||||
haGroupToPeers: make(map[route.HAUniqueID][]string),
|
haGroupToPeers: make(map[route.HAUniqueID][]string),
|
||||||
onInactive: make(chan peerid.ConnID),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.InactivityThreshold != nil {
|
if wgIface.IsUserspaceBind() {
|
||||||
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
|
m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
|
||||||
m.inactivityThreshold = *config.InactivityThreshold
|
|
||||||
} else {
|
} else {
|
||||||
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
|
log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
m.connStateListener = &dispatcher.ConnectionListener{
|
|
||||||
OnConnected: m.onPeerConnected,
|
|
||||||
OnDisconnected: m.onPeerDisconnected,
|
|
||||||
}
|
|
||||||
|
|
||||||
connStateDispatcher.AddListener(m.connStateListener)
|
|
||||||
|
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
@@ -131,24 +116,28 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
|
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
|
||||||
len(m.haGroupToPeers), len(m.peerToHAGroups))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the manager and listens for peer activity and inactivity events
|
// Start starts the manager and listens for peer activity and inactivity events
|
||||||
func (m *Manager) Start(ctx context.Context) {
|
func (m *Manager) Start(ctx context.Context) {
|
||||||
defer m.close()
|
defer m.close()
|
||||||
|
|
||||||
|
if m.inactivityManager != nil {
|
||||||
|
go m.inactivityManager.Start(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||||
m.onPeerActivity(ctx, peerConnID)
|
m.onPeerActivity(peerConnID)
|
||||||
case peerConnID := <-m.onInactive:
|
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
||||||
m.onPeerInactivityTimedOut(ctx, peerConnID)
|
m.onPeerInactivityTimedOut(peerIDs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExcludePeer marks peers for a permanent connection
|
// ExcludePeer marks peers for a permanent connection
|
||||||
@@ -156,7 +145,7 @@ func (m *Manager) Start(ctx context.Context) {
|
|||||||
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
||||||
// this case, we suppose that the connection status is connected or connecting.
|
// this case, we suppose that the connection status is connected or connecting.
|
||||||
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
||||||
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
|
func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
@@ -187,7 +176,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
|
|||||||
|
|
||||||
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
||||||
|
|
||||||
if err := m.addActivePeer(ctx, peerCfg); err != nil {
|
if err := m.addActivePeer(&peerCfg); err != nil {
|
||||||
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -197,7 +186,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
|
|||||||
return added
|
return added
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
|
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
@@ -217,9 +206,6 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
|
||||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
|
||||||
|
|
||||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||||
peerCfg: &peerCfg,
|
peerCfg: &peerCfg,
|
||||||
@@ -229,7 +215,7 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo
|
|||||||
// Check if this peer should be activated because its HA group peers are active
|
// Check if this peer should be activated because its HA group peers are active
|
||||||
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
|
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
|
||||||
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
|
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
|
||||||
m.activateNewPeerInActiveGroup(ctx, peerCfg)
|
m.activateNewPeerInActiveGroup(peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
@@ -237,7 +223,7 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo
|
|||||||
|
|
||||||
// AddActivePeers adds a list of peers to the lazy connection manager
|
// AddActivePeers adds a list of peers to the lazy connection manager
|
||||||
// suppose these peers was in connected or in connecting states
|
// suppose these peers was in connected or in connecting states
|
||||||
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
|
func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
@@ -247,7 +233,7 @@ func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerCon
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.addActivePeer(ctx, cfg); err != nil {
|
if err := m.addActivePeer(&cfg); err != nil {
|
||||||
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -264,7 +250,7 @@ func (m *Manager) RemovePeer(peerID string) {
|
|||||||
|
|
||||||
// ActivatePeer activates a peer connection when a signal message is received
|
// ActivatePeer activates a peer connection when a signal message is received
|
||||||
// Also activates all peers in the same HA groups as this peer
|
// Also activates all peers in the same HA groups as this peer
|
||||||
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
|
func (m *Manager) ActivatePeer(peerID string) (found bool) {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
cfg, mp := m.getPeerForActivation(peerID)
|
cfg, mp := m.getPeerForActivation(peerID)
|
||||||
@@ -272,15 +258,42 @@ func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool)
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.activateSinglePeer(ctx, cfg, mp) {
|
if !m.activateSinglePeer(cfg, mp) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
m.activateHAGroupPeers(ctx, peerID)
|
m.activateHAGroupPeers(cfg)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[peerID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if mp.expectedWatcher != watcherInactivity {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("start activity monitor")
|
||||||
|
|
||||||
|
mp.expectedWatcher = watcherActivity
|
||||||
|
|
||||||
|
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||||
|
|
||||||
|
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||||
|
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
|
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
|
||||||
// Returns nil values if the peer should be skipped
|
// Returns nil values if the peer should be skipped
|
||||||
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
|
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
|
||||||
@@ -302,41 +315,36 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
|
|||||||
return cfg, mp
|
return cfg, mp
|
||||||
}
|
}
|
||||||
|
|
||||||
// activateSinglePeer activates a single peer (internal method)
|
// activateSinglePeer activates a single peer
|
||||||
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
|
// return true if the peer was activated, false if it was already active
|
||||||
mp.expectedWatcher = watcherInactivity
|
func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
|
||||||
|
if mp.expectedWatcher == watcherInactivity {
|
||||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
|
||||||
|
|
||||||
im, ok := m.inactivityMonitors[cfg.PeerConnID]
|
|
||||||
if !ok {
|
|
||||||
cfg.Log.Errorf("inactivity monitor not found for peer")
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Log.Infof("starting inactivity monitor")
|
mp.expectedWatcher = watcherInactivity
|
||||||
go im.Start(ctx, m.onInactive)
|
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||||
|
m.inactivityManager.AddPeer(cfg)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
|
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
|
||||||
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
|
func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
|
||||||
var peersToActivate []string
|
var peersToActivate []string
|
||||||
|
|
||||||
m.routesMu.RLock()
|
m.routesMu.RLock()
|
||||||
haGroups := m.peerToHAGroups[triggerPeerID]
|
haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
|
||||||
|
|
||||||
if len(haGroups) == 0 {
|
if len(haGroups) == 0 {
|
||||||
m.routesMu.RUnlock()
|
m.routesMu.RUnlock()
|
||||||
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
|
triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, haGroup := range haGroups {
|
for _, haGroup := range haGroups {
|
||||||
peers := m.haGroupToPeers[haGroup]
|
peers := m.haGroupToPeers[haGroup]
|
||||||
for _, peerID := range peers {
|
for _, peerID := range peers {
|
||||||
if peerID != triggerPeerID {
|
if peerID != triggeredPeerCfg.PublicKey {
|
||||||
peersToActivate = append(peersToActivate, peerID)
|
peersToActivate = append(peersToActivate, peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -350,16 +358,16 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.activateSinglePeer(ctx, cfg, mp) {
|
if m.activateSinglePeer(cfg, mp) {
|
||||||
activatedCount++
|
activatedCount++
|
||||||
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
|
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
|
||||||
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if activatedCount > 0 {
|
if activatedCount > 0 {
|
||||||
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
|
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
|
||||||
activatedCount, triggerPeerID, haGroups)
|
activatedCount, triggeredPeerCfg.PublicKey, haGroups)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,13 +402,13 @@ func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
|
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
|
||||||
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
|
func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
|
||||||
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
|
if !m.activateSinglePeer(&peerCfg, mp) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -408,23 +416,19 @@ func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazy
|
|||||||
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
|
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
|
||||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||||
peerCfg.Log.Warnf("peer already managed")
|
peerCfg.Log.Warnf("peer already managed")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
m.managedPeers[peerCfg.PublicKey] = peerCfg
|
||||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
|
||||||
|
|
||||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
|
||||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||||
peerCfg: &peerCfg,
|
peerCfg: peerCfg,
|
||||||
expectedWatcher: watcherInactivity,
|
expectedWatcher: watcherInactivity,
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
|
m.inactivityManager.AddPeer(peerCfg)
|
||||||
go im.Start(ctx, m.onInactive)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -436,12 +440,7 @@ func (m *Manager) removePeer(peerID string) {
|
|||||||
|
|
||||||
cfg.Log.Infof("removing lazy peer")
|
cfg.Log.Infof("removing lazy peer")
|
||||||
|
|
||||||
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
|
m.inactivityManager.RemovePeer(cfg.PublicKey)
|
||||||
im.Stop()
|
|
||||||
delete(m.inactivityMonitors, cfg.PeerConnID)
|
|
||||||
cfg.Log.Debugf("inactivity monitor stopped")
|
|
||||||
}
|
|
||||||
|
|
||||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||||
delete(m.managedPeers, peerID)
|
delete(m.managedPeers, peerID)
|
||||||
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
||||||
@@ -451,12 +450,8 @@ func (m *Manager) close() {
|
|||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
m.connStateDispatcher.RemoveListener(m.connStateListener)
|
|
||||||
m.activityManager.Close()
|
m.activityManager.Close()
|
||||||
for _, iw := range m.inactivityMonitors {
|
|
||||||
iw.Stop()
|
|
||||||
}
|
|
||||||
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
|
|
||||||
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
||||||
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
||||||
|
|
||||||
@@ -470,7 +465,7 @@ func (m *Manager) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
|
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
|
||||||
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
|
func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
|
||||||
m.routesMu.RLock()
|
m.routesMu.RLock()
|
||||||
defer m.routesMu.RUnlock()
|
defer m.routesMu.RUnlock()
|
||||||
|
|
||||||
@@ -480,9 +475,18 @@ func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, haGroup := range haGroups {
|
for _, haGroup := range haGroups {
|
||||||
groupPeers := m.haGroupToPeers[haGroup]
|
if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
|
||||||
|
groupPeers := m.haGroupToPeers[haGroup]
|
||||||
for _, groupPeerID := range groupPeers {
|
for _, groupPeerID := range groupPeers {
|
||||||
|
|
||||||
if groupPeerID == peerID {
|
if groupPeerID == peerID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -501,17 +505,15 @@ func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Other member is still connected, defer idle
|
// If any peer in the group is active, do defer idle
|
||||||
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
|
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
@@ -528,100 +530,56 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
|
|||||||
|
|
||||||
mp.peerCfg.Log.Infof("detected peer activity")
|
mp.peerCfg.Log.Infof("detected peer activity")
|
||||||
|
|
||||||
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
|
if !m.activateSinglePeer(mp.peerCfg, mp) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
|
m.activateHAGroupPeers(mp.peerCfg)
|
||||||
|
|
||||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
|
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
for peerID := range peerIDs {
|
||||||
|
peerCfg, ok := m.managedPeers[peerID]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("peer not found by id: %v", peerConnID)
|
log.Errorf("peer not found by peerId: %v", peerID)
|
||||||
return
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if mp.expectedWatcher != watcherInactivity {
|
if mp.expectedWatcher != watcherInactivity {
|
||||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||||
return
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
|
if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
|
||||||
iw, ok := m.inactivityMonitors[peerConnID]
|
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
|
||||||
if ok {
|
continue
|
||||||
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
|
|
||||||
iw.ResetMonitor(ctx, m.onInactive)
|
|
||||||
} else {
|
|
||||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mp.peerCfg.Log.Infof("connection timed out")
|
mp.peerCfg.Log.Infof("connection timed out")
|
||||||
|
|
||||||
// this is blocking operation, potentially can be optimized
|
// this is blocking operation, potentially can be optimized
|
||||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
|
||||||
|
|
||||||
mp.peerCfg.Log.Infof("start activity monitor")
|
mp.peerCfg.Log.Infof("start activity monitor")
|
||||||
|
|
||||||
mp.expectedWatcher = watcherActivity
|
mp.expectedWatcher = watcherActivity
|
||||||
|
|
||||||
// just in case free up
|
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||||
m.inactivityMonitors[peerConnID].PauseTimer()
|
|
||||||
|
|
||||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||||
return
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
|
||||||
m.managedPeersMu.Lock()
|
|
||||||
defer m.managedPeersMu.Unlock()
|
|
||||||
|
|
||||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if mp.expectedWatcher != watcherInactivity {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
|
||||||
if !ok {
|
|
||||||
mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
|
|
||||||
iw.PauseTimer()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
|
|
||||||
m.managedPeersMu.Lock()
|
|
||||||
defer m.managedPeersMu.Unlock()
|
|
||||||
|
|
||||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if mp.expectedWatcher != watcherInactivity {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
|
|
||||||
iw.ResetTimer()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,4 +11,6 @@ import (
|
|||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
LastActivities() map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
)
|
)
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
log.Errorf("failed registering peer %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,6 @@ type Conn struct {
|
|||||||
|
|
||||||
guard *guard.Guard
|
guard *guard.Guard
|
||||||
semaphore *semaphoregroup.SemaphoreGroup
|
semaphore *semaphoregroup.SemaphoreGroup
|
||||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// debug purpose
|
// debug purpose
|
||||||
@@ -144,7 +143,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
relayManager: services.RelayManager,
|
relayManager: services.RelayManager,
|
||||||
srWatcher: services.SrWatcher,
|
srWatcher: services.SrWatcher,
|
||||||
semaphore: services.Semaphore,
|
semaphore: services.Semaphore,
|
||||||
peerConnDispatcher: services.PeerConnDispatcher,
|
|
||||||
statusRelay: worker.NewAtomicStatus(),
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
statusICE: worker.NewAtomicStatus(),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||||
@@ -226,7 +224,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||||
func (conn *Conn) Close() {
|
func (conn *Conn) Close(signalToRemote bool) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.wgWatcherWg.Wait()
|
defer conn.wgWatcherWg.Wait()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
@@ -236,6 +234,12 @@ func (conn *Conn) Close() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if signalToRemote {
|
||||||
|
if err := conn.signaler.SignalIdle(conn.config.Key); err != nil {
|
||||||
|
conn.Log.Errorf("failed to signal idle state to peer: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
conn.Log.Infof("close peer connection")
|
conn.Log.Infof("close peer connection")
|
||||||
conn.ctxCancel()
|
conn.ctxCancel()
|
||||||
|
|
||||||
@@ -404,15 +408,10 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
oldState := conn.currentConnPriority
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.SetConnected()
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
|
|
||||||
if oldState == conntype.None {
|
|
||||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onICEStateDisconnected() {
|
func (conn *Conn) onICEStateDisconnected() {
|
||||||
@@ -450,7 +449,6 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
} else {
|
} else {
|
||||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||||
conn.currentConnPriority = conntype.None
|
conn.currentConnPriority = conntype.None
|
||||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||||
@@ -530,7 +528,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.Log.Infof("start to communicate with peer via relay")
|
conn.Log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onRelayDisconnected() {
|
func (conn *Conn) onRelayDisconnected() {
|
||||||
@@ -545,11 +542,7 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
|
|
||||||
if conn.currentConnPriority == conntype.Relay {
|
if conn.currentConnPriority == conntype.Relay {
|
||||||
conn.Log.Debugf("clean up WireGuard config")
|
conn.Log.Debugf("clean up WireGuard config")
|
||||||
if err := conn.removeWgPeer(); err != nil {
|
|
||||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
|
||||||
}
|
|
||||||
conn.currentConnPriority = conntype.None
|
conn.currentConnPriority = conntype.None
|
||||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
|
|||||||
@@ -68,3 +68,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Signaler) SignalIdle(remoteKey string) error {
|
||||||
|
return s.signal.Send(&sProto.Message{
|
||||||
|
Key: s.wgPrivateKey.PublicKey().String(),
|
||||||
|
RemoteKey: remoteKey,
|
||||||
|
Body: &sProto.Body{
|
||||||
|
Type: sProto.Body_GO_IDLE,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -95,6 +95,17 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeerConnIdle(pubKey string) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.Close(true)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) PeerConnClose(pubKey string) {
|
func (s *Store) PeerConnClose(pubKey string) {
|
||||||
s.peerConnsMu.RLock()
|
s.peerConnsMu.RLock()
|
||||||
defer s.peerConnsMu.RUnlock()
|
defer s.peerConnsMu.RUnlock()
|
||||||
@@ -103,7 +114,7 @@ func (s *Store) PeerConnClose(pubKey string) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.Close()
|
p.Close(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) PeersPubKey() []string {
|
func (s *Store) PeersPubKey() []string {
|
||||||
|
|||||||
@@ -10,11 +10,10 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -553,41 +552,16 @@ func (w *Watcher) Stop() {
|
|||||||
w.currentChosenStatus = nil
|
w.currentChosenStatus = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandlerFromRoute(
|
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||||
rt *route.Route,
|
switch handlerType(params.Route, params.UseNewDNSRoute) {
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
dnsRouterInteval time.Duration,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
useNewDNSRoute bool,
|
|
||||||
) RouteHandler {
|
|
||||||
switch handlerType(rt, useNewDNSRoute) {
|
|
||||||
case handlerTypeDnsInterceptor:
|
case handlerTypeDnsInterceptor:
|
||||||
return dnsinterceptor.New(
|
return dnsinterceptor.New(params)
|
||||||
rt,
|
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
statusRecorder,
|
|
||||||
dnsServer,
|
|
||||||
wgInterface,
|
|
||||||
peerStore,
|
|
||||||
)
|
|
||||||
case handlerTypeDynamic:
|
case handlerTypeDynamic:
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||||
return dynamic.NewRoute(
|
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||||
rt,
|
return dynamic.NewRoute(params, dnsAddr)
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
dnsRouterInteval,
|
|
||||||
statusRecorder,
|
|
||||||
wgInterface,
|
|
||||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
|
||||||
)
|
|
||||||
default:
|
default:
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
return static.NewRoute(params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetBestrouteFromStatuses(t *testing.T) {
|
func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
statuses map[route.ID]routerPeerStatus
|
statuses map[route.ID]routerPeerStatus
|
||||||
@@ -811,9 +811,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
currentRoute = tc.existingRoutes[tc.currentRoute]
|
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params := common.HandlerParams{
|
||||||
|
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||||
|
}
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &Watcher{
|
client := &Watcher{
|
||||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
handler: static.NewRoute(params),
|
||||||
routes: tc.existingRoutes,
|
routes: tc.existingRoutes,
|
||||||
currentChosen: currentRoute,
|
currentChosen: currentRoute,
|
||||||
}
|
}
|
||||||
|
|||||||
28
client/internal/routemanager/common/params.go
Normal file
28
client/internal/routemanager/common/params.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerParams struct {
|
||||||
|
Route *route.Route
|
||||||
|
RouteRefCounter *refcounter.RouteRefCounter
|
||||||
|
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
|
DnsRouterInterval time.Duration
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
WgInterface iface.WGIface
|
||||||
|
DnsServer dns.Server
|
||||||
|
PeerStore *peerstore.Store
|
||||||
|
UseNewDNSRoute bool
|
||||||
|
Firewall manager.Manager
|
||||||
|
FakeIPManager *fakeip.Manager
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -12,11 +13,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -24,6 +28,11 @@ import (
|
|||||||
|
|
||||||
type domainMap map[domain.Domain][]netip.Prefix
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
|
type internalDNATer interface {
|
||||||
|
RemoveInternalDNATMapping(netip.Addr) error
|
||||||
|
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
type wgInterface interface {
|
type wgInterface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
@@ -40,26 +49,22 @@ type DnsInterceptor struct {
|
|||||||
interceptedDomains domainMap
|
interceptedDomains domainMap
|
||||||
wgInterface wgInterface
|
wgInterface wgInterface
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
|
firewall firewall.Manager
|
||||||
|
fakeIPManager *fakeip.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(params common.HandlerParams) *DnsInterceptor {
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
wgInterface wgInterface,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
) *DnsInterceptor {
|
|
||||||
return &DnsInterceptor{
|
return &DnsInterceptor{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: params.StatusRecorder,
|
||||||
dnsServer: dnsServer,
|
dnsServer: params.DnsServer,
|
||||||
wgInterface: wgInterface,
|
wgInterface: params.WgInterface,
|
||||||
|
peerStore: params.PeerStore,
|
||||||
|
firewall: params.Firewall,
|
||||||
|
fakeIPManager: params.FakeIPManager,
|
||||||
interceptedDomains: make(domainMap),
|
interceptedDomains: make(domainMap),
|
||||||
peerStore: peerStore,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for domain, prefixes := range d.interceptedDomains {
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
// Routes should use fake IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||||
|
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedIPs should use real IPs
|
||||||
if d.currentPeerKey != "" {
|
if d.currentPeerKey != "" {
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
@@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.cleanupDNATMappings()
|
||||||
|
|
||||||
for _, domain := range d.route.Domains {
|
for _, domain := range d.route.Domains {
|
||||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||||
}
|
}
|
||||||
@@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
|
||||||
|
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
|
||||||
|
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
|
||||||
|
return realPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
|
||||||
|
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
|
||||||
|
}
|
||||||
|
|
||||||
|
return realPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
|
||||||
|
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
|
||||||
|
// AllowedIPs always use real IPs
|
||||||
|
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ref.Count > 1 && ref.Out != peerKey {
|
||||||
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
|
realPrefix.Addr(),
|
||||||
|
domain.SafeString(),
|
||||||
|
ref.Out,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
|
||||||
|
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
|
||||||
|
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
|
||||||
|
routePrefix := d.transformRealToFakePrefix(realPrefix)
|
||||||
|
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
|
||||||
|
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to AllowedIPs if we have a current peer (uses real IPs)
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
|
||||||
|
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowedIPs use real IPs
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
|
||||||
|
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
@@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for domain, prefixes := range d.interceptedDomains {
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
// AllowedIPs use real IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
|
||||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
merr = multierror.Append(merr, err)
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
|
||||||
prefix.Addr(),
|
|
||||||
domain.SafeString(),
|
|
||||||
ref.Out,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, prefixes := range d.interceptedDomains {
|
for _, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
|
// AllowedIPs use real IPs
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
}
|
}
|
||||||
@@ -287,6 +356,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||||
log.Errorf("failed to update domain prefixes: %v", err)
|
log.Errorf("failed to update domain prefixes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,61 +368,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
// logPrefixChanges handles the logging for prefix changes
|
||||||
d.mu.Lock()
|
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||||
defer d.mu.Unlock()
|
|
||||||
|
|
||||||
oldPrefixes := d.interceptedDomains[resolvedDomain]
|
|
||||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
// Add new prefixes
|
|
||||||
for _, prefix := range toAdd {
|
|
||||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.currentPeerKey == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
|
||||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
|
||||||
prefix.Addr(),
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
ref.Out,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !d.route.KeepRoute {
|
|
||||||
// Remove old prefixes
|
|
||||||
for _, prefix := range toRemove {
|
|
||||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
|
||||||
}
|
|
||||||
if d.currentPeerKey != "" {
|
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update domain prefixes using resolved domain as key
|
|
||||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
|
||||||
if d.route.KeepRoute {
|
|
||||||
// replace stored prefixes with old + added
|
|
||||||
// nolint:gocritic
|
|
||||||
newPrefixes = append(oldPrefixes, toAdd...)
|
|
||||||
}
|
|
||||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
|
||||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
|
||||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
|
||||||
|
|
||||||
if len(toAdd) > 0 {
|
if len(toAdd) > 0 {
|
||||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
resolvedDomain.SafeString(),
|
resolvedDomain.SafeString(),
|
||||||
@@ -364,11 +382,173 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
originalDomain.SafeString(),
|
originalDomain.SafeString(),
|
||||||
toRemove)
|
toRemove)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
oldPrefixes := d.interceptedDomains[resolvedDomain]
|
||||||
|
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
var dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
|
||||||
|
// Handle DNAT mappings for new prefixes
|
||||||
|
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
|
||||||
|
dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||||
|
for _, prefix := range toAdd {
|
||||||
|
realIP := prefix.Addr()
|
||||||
|
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||||
|
dnatMappings[fakeIP] = realIP
|
||||||
|
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new prefixes
|
||||||
|
for _, prefix := range toAdd {
|
||||||
|
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.addDNATMappings(dnatMappings)
|
||||||
|
|
||||||
|
if !d.route.KeepRoute {
|
||||||
|
// Remove old prefixes
|
||||||
|
for _, prefix := range toRemove {
|
||||||
|
// Routes use fake IPs
|
||||||
|
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||||
|
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
|
||||||
|
}
|
||||||
|
// AllowedIPs use real IPs
|
||||||
|
if err := d.removeAllowedIP(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.removeDNATMappings(toRemove)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update domain prefixes using resolved domain as key - store real IPs
|
||||||
|
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||||
|
if d.route.KeepRoute {
|
||||||
|
// nolint:gocritic
|
||||||
|
newPrefixes = append(oldPrefixes, toAdd...)
|
||||||
|
}
|
||||||
|
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||||
|
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||||
|
|
||||||
|
// Store real IPs for status (user-facing), not fake IPs
|
||||||
|
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||||
|
|
||||||
|
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||||
|
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||||
|
if len(realPrefixes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range realPrefixes {
|
||||||
|
realIP := prefix.Addr()
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||||
|
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// internalDnatFw checks if the firewall supports internal DNAT
|
||||||
|
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||||
|
if d.firewall == nil || runtime.GOOS != "android" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
fw, ok := d.firewall.(internalDNATer)
|
||||||
|
return fw, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// addDNATMappings adds DNAT mappings to the firewall
|
||||||
|
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for fakeIP, realIP := range mappings {
|
||||||
|
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||||
|
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupDNATMappings removes all DNAT mappings for this interceptor
|
||||||
|
func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||||
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefixes := range d.interceptedDomains {
|
||||||
|
d.removeDNATMappings(prefixes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||||
|
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||||
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace A and AAAA records with fake IPs
|
||||||
|
for _, answer := range reply.Answer {
|
||||||
|
switch rr := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
realIP, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
rr.A = fakeIP.AsSlice()
|
||||||
|
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
case *dns.AAAA:
|
||||||
|
realIP, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
rr.AAAA = fakeIP.AsSlice()
|
||||||
|
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||||
prefixSet := make(map[netip.Prefix]bool)
|
prefixSet := make(map[netip.Prefix]bool)
|
||||||
for _, prefix := range oldPrefixes {
|
for _, prefix := range oldPrefixes {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
@@ -52,24 +53,16 @@ type Route struct {
|
|||||||
resolverAddr string
|
resolverAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(
|
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
interval time.Duration,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
resolverAddr string,
|
|
||||||
) *Route {
|
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
interval: interval,
|
interval: params.DnsRouterInterval,
|
||||||
dynamicDomains: domainMap{},
|
statusRecorder: params.StatusRecorder,
|
||||||
statusRecorder: statusRecorder,
|
wgInterface: params.WgInterface,
|
||||||
wgInterface: wgInterface,
|
|
||||||
resolverAddr: resolverAddr,
|
resolverAddr: resolverAddr,
|
||||||
|
dynamicDomains: domainMap{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package fakeip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
nextIP netip.Addr // Next IP to allocate
|
||||||
|
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||||
|
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||||
|
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||||
|
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||||
|
func NewManager() *Manager {
|
||||||
|
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||||
|
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||||
|
|
||||||
|
return &Manager{
|
||||||
|
nextIP: baseIP,
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: baseIP,
|
||||||
|
maxIP: maxIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||||
|
// Returns the fake IP, or existing fake IP if already allocated
|
||||||
|
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||||
|
if !realIP.Is4() {
|
||||||
|
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if fakeIP, exists := m.allocated[realIP]; exists {
|
||||||
|
return fakeIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
startIP := m.nextIP
|
||||||
|
for {
|
||||||
|
currentIP := m.nextIP
|
||||||
|
|
||||||
|
// Advance to next IP, wrapping at boundary
|
||||||
|
if m.nextIP.Compare(m.maxIP) >= 0 {
|
||||||
|
m.nextIP = m.baseIP
|
||||||
|
} else {
|
||||||
|
m.nextIP = m.nextIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if current IP is available
|
||||||
|
if _, inUse := m.fakeToReal[currentIP]; !inUse {
|
||||||
|
m.allocated[realIP] = currentIP
|
||||||
|
m.fakeToReal[currentIP] = realIP
|
||||||
|
return currentIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent infinite loop if all IPs exhausted
|
||||||
|
if m.nextIP.Compare(startIP) == 0 {
|
||||||
|
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||||
|
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
fakeIP, exists := m.allocated[realIP]
|
||||||
|
return fakeIP, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||||
|
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
realIP, exists := m.fakeToReal[fakeIP]
|
||||||
|
return realIP, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||||
|
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
||||||
|
return netip.MustParsePrefix("240.0.0.0/8")
|
||||||
|
}
|
||||||
240
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
240
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
package fakeip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
if manager.baseIP.String() != "240.0.0.1" {
|
||||||
|
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.maxIP.String() != "240.255.255.254" {
|
||||||
|
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||||
|
t.Errorf("Expected nextIP to start at baseIP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocateFakeIP(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("8.8.8.8")
|
||||||
|
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fakeIP.Is4() {
|
||||||
|
t.Error("Fake IP should be IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check it's in the correct range
|
||||||
|
if fakeIP.As4()[0] != 240 {
|
||||||
|
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return same fake IP for same real IP
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get existing fake IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(fakeIP2) != 0 {
|
||||||
|
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||||
|
|
||||||
|
_, err := manager.AllocateFakeIP(realIPv6)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for IPv6 address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFakeIP(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
|
// Should not exist initially
|
||||||
|
_, exists := manager.GetFakeIP(realIP)
|
||||||
|
if exists {
|
||||||
|
t.Error("Fake IP should not exist before allocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate and check
|
||||||
|
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP, exists := manager.GetFakeIP(realIP)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Fake IP should exist after allocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(expectedFakeIP) != 0 {
|
||||||
|
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleAllocations(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
allocations := make(map[netip.Addr]netip.Addr)
|
||||||
|
|
||||||
|
// Allocate multiple IPs
|
||||||
|
for i := 1; i <= 100; i++ {
|
||||||
|
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
for _, existingFake := range allocations {
|
||||||
|
if fakeIP.Compare(existingFake) == 0 {
|
||||||
|
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allocations[realIP] = fakeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all allocations can be retrieved
|
||||||
|
for realIP, expectedFake := range allocations {
|
||||||
|
actualFake, exists := manager.GetFakeIP(realIP)
|
||||||
|
if !exists {
|
||||||
|
t.Errorf("Missing allocation for %s", realIP.String())
|
||||||
|
}
|
||||||
|
if actualFake.Compare(expectedFake) != 0 {
|
||||||
|
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFakeIPBlock(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
block := manager.GetFakeIPBlock()
|
||||||
|
|
||||||
|
expected := "240.0.0.0/8"
|
||||||
|
if block.String() != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
const numGoroutines = 50
|
||||||
|
const allocationsPerGoroutine = 10
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||||
|
|
||||||
|
// Concurrent allocations
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(goroutineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < allocationsPerGoroutine; j++ {
|
||||||
|
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results <- fakeIP
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
seen := make(map[netip.Addr]bool)
|
||||||
|
count := 0
|
||||||
|
for fakeIP := range results {
|
||||||
|
if seen[fakeIP] {
|
||||||
|
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
seen[fakeIP] = true
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != numGoroutines*allocationsPerGoroutine {
|
||||||
|
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPExhaustion(t *testing.T) {
|
||||||
|
// Create a manager with limited range for testing
|
||||||
|
manager := &Manager{
|
||||||
|
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate all available IPs
|
||||||
|
realIPs := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.0.0.1"),
|
||||||
|
netip.MustParseAddr("1.0.0.2"),
|
||||||
|
netip.MustParseAddr("1.0.0.3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, realIP := range realIPs {
|
||||||
|
_, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to allocate one more - should fail
|
||||||
|
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected exhaustion error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapAround(t *testing.T) {
|
||||||
|
// Create manager starting near the end of range
|
||||||
|
manager := &Manager{
|
||||||
|
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the last IP
|
||||||
|
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP1.String() != "240.0.0.254" {
|
||||||
|
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next allocation should wrap around to the beginning
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP2.String() != "240.0.0.1" {
|
||||||
|
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,9 +8,11 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@@ -24,6 +26,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
@@ -49,7 +53,7 @@ type Manager interface {
|
|||||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
SetFirewall(firewall.Manager) error
|
||||||
Stop(stateManager *statemanager.Manager)
|
Stop(stateManager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,6 +67,7 @@ type ManagerConfig struct {
|
|||||||
InitialRoutes []*route.Route
|
InitialRoutes []*route.Route
|
||||||
StateManager *statemanager.Manager
|
StateManager *statemanager.Manager
|
||||||
DNSServer dns.Server
|
DNSServer dns.Server
|
||||||
|
DNSFeatureFlag bool
|
||||||
PeerStore *peerstore.Store
|
PeerStore *peerstore.Store
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -89,11 +94,13 @@ type DefaultManager struct {
|
|||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
clientRoutes route.HAMap
|
clientRoutes route.HAMap
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
firewall firewall.Manager
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
useNewDNSRoute bool
|
useNewDNSRoute bool
|
||||||
disableClientRoutes bool
|
disableClientRoutes bool
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||||
|
fakeIPManager *fakeip.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
@@ -129,11 +136,31 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
dm.setupAndroidRoutes(config)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
|
||||||
}
|
}
|
||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||||
|
cr := m.initialClientRoutes(config.InitialRoutes)
|
||||||
|
|
||||||
|
routesForComparison := slices.Clone(cr)
|
||||||
|
|
||||||
|
if config.DNSFeatureFlag {
|
||||||
|
m.fakeIPManager = fakeip.NewManager()
|
||||||
|
|
||||||
|
id := uuid.NewString()
|
||||||
|
fakeIPRoute := &route.Route{
|
||||||
|
ID: route.ID(id),
|
||||||
|
Network: m.fakeIPManager.GetFakeIPBlock(),
|
||||||
|
NetID: route.NetID(id),
|
||||||
|
Peer: m.pubKey,
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
}
|
||||||
|
cr = append(cr, fakeIPRoute)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
@@ -222,16 +249,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
|||||||
return routeselector.NewRouteSelector()
|
return routeselector.NewRouteSelector()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
// SetFirewall sets the firewall manager for the DefaultManager
|
||||||
if m.disableServerRoutes {
|
// Not thread-safe, should be called before starting the manager
|
||||||
|
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||||
|
m.firewall = firewall
|
||||||
|
|
||||||
|
if m.disableServerRoutes || firewall == nil {
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if firewall == nil {
|
|
||||||
return errors.New("firewall manager is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -299,17 +326,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for id, route := range toAdd {
|
for id, route := range toAdd {
|
||||||
handler := client.HandlerFromRoute(
|
params := common.HandlerParams{
|
||||||
route,
|
Route: route,
|
||||||
m.routeRefCounter,
|
RouteRefCounter: m.routeRefCounter,
|
||||||
m.allowedIPsRefCounter,
|
AllowedIPsRefCounter: m.allowedIPsRefCounter,
|
||||||
m.dnsRouteInterval,
|
DnsRouterInterval: m.dnsRouteInterval,
|
||||||
m.statusRecorder,
|
StatusRecorder: m.statusRecorder,
|
||||||
m.wgInterface,
|
WgInterface: m.wgInterface,
|
||||||
m.dnsServer,
|
DnsServer: m.dnsServer,
|
||||||
m.peerStore,
|
PeerStore: m.peerStore,
|
||||||
m.useNewDNSRoute,
|
UseNewDNSRoute: m.useNewDNSRoute,
|
||||||
)
|
Firewall: m.firewall,
|
||||||
|
FakeIPManager: m.fakeIPManager,
|
||||||
|
}
|
||||||
|
handler := client.HandlerFromRoute(params)
|
||||||
if err := handler.AddRoute(m.ctx); err != nil {
|
if err := handler.AddRoute(m.ctx); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||||
continue
|
continue
|
||||||
@@ -517,6 +547,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
|||||||
for _, routes := range crMap {
|
for _, routes := range crMap {
|
||||||
rs = append(rs, routes...)
|
rs = append(rs, routes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,124 +0,0 @@
|
|||||||
package notifier
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Notifier struct {
|
|
||||||
initialRouteRanges []string
|
|
||||||
routeRanges []string
|
|
||||||
|
|
||||||
listener listener.NetworkChangeListener
|
|
||||||
listenerMux sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewNotifier() *Notifier {
|
|
||||||
return &Notifier{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|
||||||
n.listenerMux.Lock()
|
|
||||||
defer n.listenerMux.Unlock()
|
|
||||||
n.listener = listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
|
||||||
nets := make([]string, 0)
|
|
||||||
for _, r := range clientRoutes {
|
|
||||||
if r.IsDynamic() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nets = append(nets, r.Network.String())
|
|
||||||
}
|
|
||||||
sort.Strings(nets)
|
|
||||||
n.initialRouteRanges = nets
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var newNets []string
|
|
||||||
for _, routes := range idMap {
|
|
||||||
for _, r := range routes {
|
|
||||||
if r.IsDynamic() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newNets = append(newNets, r.Network.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(newNets)
|
|
||||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRanges = newNets
|
|
||||||
n.notify()
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnNewPrefixes is called from iOS only
|
|
||||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
|
||||||
newNets := make([]string, 0)
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
newNets = append(newNets, prefix.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(newNets)
|
|
||||||
if !n.hasDiff(n.routeRanges, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRanges = newNets
|
|
||||||
n.notify()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) notify() {
|
|
||||||
n.listenerMux.Lock()
|
|
||||||
defer n.listenerMux.Unlock()
|
|
||||||
if n.listener == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(l listener.NetworkChangeListener) {
|
|
||||||
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
|
|
||||||
}(n.listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) hasDiff(a []string, b []string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for i, v := range a {
|
|
||||||
if v != b[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
|
||||||
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
|
|
||||||
func addIPv6RangeIfNeeded(inputRanges []string) []string {
|
|
||||||
ranges := inputRanges
|
|
||||||
for _, r := range inputRanges {
|
|
||||||
// we are intentionally adding the ipv6 default range in case of ipv4 default range
|
|
||||||
// to ensure that all traffic is managed by the tunnel interface on android
|
|
||||||
if r == "0.0.0.0/0" {
|
|
||||||
ranges = append(ranges, "::/0")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ranges
|
|
||||||
}
|
|
||||||
127
client/internal/routemanager/notifier/notifier_android.go
Normal file
127
client/internal/routemanager/notifier/notifier_android.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct {
|
||||||
|
initialRoutes []*route.Route
|
||||||
|
currentRoutes []*route.Route
|
||||||
|
|
||||||
|
listener listener.NetworkChangeListener
|
||||||
|
listenerMux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
n.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||||
|
// initialRoutes contains fake IP block for interface configuration
|
||||||
|
filteredInitial := make([]*route.Route, 0)
|
||||||
|
for _, r := range initialRoutes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredInitial = append(filteredInitial, r)
|
||||||
|
}
|
||||||
|
n.initialRoutes = filteredInitial
|
||||||
|
|
||||||
|
// routesForComparison excludes fake IP block for comparison with new routes
|
||||||
|
filteredComparison := make([]*route.Route, 0)
|
||||||
|
for _, r := range routesForComparison {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredComparison = append(filteredComparison, r)
|
||||||
|
}
|
||||||
|
n.currentRoutes = filteredComparison
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
|
var newRoutes []*route.Route
|
||||||
|
for _, routes := range idMap {
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newRoutes = append(newRoutes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.currentRoutes = newRoutes
|
||||||
|
n.notify()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
|
||||||
|
// Not used on Android
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) notify() {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
routeStrings := n.routesToStrings(n.currentRoutes)
|
||||||
|
sort.Strings(routeStrings)
|
||||||
|
go func(l listener.NetworkChangeListener) {
|
||||||
|
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
|
||||||
|
}(n.listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
||||||
|
nets := make([]string, 0, len(routes))
|
||||||
|
for _, r := range routes {
|
||||||
|
nets = append(nets, r.NetString())
|
||||||
|
}
|
||||||
|
return nets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
|
||||||
|
slices.SortFunc(a, func(x, y *route.Route) int {
|
||||||
|
return strings.Compare(x.NetString(), y.NetString())
|
||||||
|
})
|
||||||
|
slices.SortFunc(b, func(x, y *route.Route) int {
|
||||||
|
return strings.Compare(x.NetString(), y.NetString())
|
||||||
|
})
|
||||||
|
|
||||||
|
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
|
||||||
|
return x.NetString() == y.NetString()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
|
initialStrings := n.routesToStrings(n.initialRoutes)
|
||||||
|
sort.Strings(initialStrings)
|
||||||
|
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
|
||||||
|
return append(slices.Clone(inputRanges), "::/0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inputRanges
|
||||||
|
}
|
||||||
80
client/internal/routemanager/notifier/notifier_ios.go
Normal file
80
client/internal/routemanager/notifier/notifier_ios.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct {
|
||||||
|
currentPrefixes []string
|
||||||
|
|
||||||
|
listener listener.NetworkChangeListener
|
||||||
|
listenerMux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
n.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||||
|
// iOS doesn't care about initial routes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||||
|
// Not used on iOS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||||
|
newNets := make([]string, 0)
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
newNets = append(newNets, prefix.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(newNets)
|
||||||
|
|
||||||
|
if slices.Equal(n.currentPrefixes, newNets) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.currentPrefixes = newNets
|
||||||
|
n.notify()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) notify() {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(l listener.NetworkChangeListener) {
|
||||||
|
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
|
||||||
|
}(n.listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||||
|
for _, r := range inputRanges {
|
||||||
|
if r == "0.0.0.0/0" {
|
||||||
|
return append(slices.Clone(inputRanges), "::/0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inputRanges
|
||||||
|
}
|
||||||
36
client/internal/routemanager/notifier/notifier_other.go
Normal file
36
client/internal/routemanager/notifier/notifier_other.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct{}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -16,11 +17,11 @@ type Route struct {
|
|||||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
func NewRoute(params common.HandlerParams) *Route {
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
@@ -52,6 +53,9 @@ type SysOps struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// notifier is used to notify the system of route changes (also used on mobile)
|
// notifier is used to notify the system of route changes (also used on mobile)
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
|
// seq is an atomic counter for generating unique sequence numbers for route messages
|
||||||
|
//nolint:unused // only used on BSD systems
|
||||||
|
seq atomic.Uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
@@ -61,6 +65,11 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:unused // only used on BSD systems
|
||||||
|
func (r *SysOps) getSeq() int {
|
||||||
|
return int(r.seq.Add(1))
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||||
addr := prefix.Addr()
|
addr := prefix.Addr()
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
|||||||
Type: action,
|
Type: action,
|
||||||
Flags: unix.RTF_UP,
|
Flags: unix.RTF_UP,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
Seq: 1,
|
Seq: r.getSeq(),
|
||||||
}
|
}
|
||||||
|
|
||||||
const numAddrs = unix.RTAX_NETMASK + 1
|
const numAddrs = unix.RTAX_NETMASK + 1
|
||||||
|
|||||||
5
go.mod
5
go.mod
@@ -110,7 +110,7 @@ require (
|
|||||||
gorm.io/driver/mysql v1.5.7
|
gorm.io/driver/mysql v1.5.7
|
||||||
gorm.io/driver/postgres v1.5.7
|
gorm.io/driver/postgres v1.5.7
|
||||||
gorm.io/driver/sqlite v1.5.7
|
gorm.io/driver/sqlite v1.5.7
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.30.0
|
||||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ require (
|
|||||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
|
||||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||||
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||||
@@ -247,6 +247,7 @@ require (
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
|
gorm.io/datatypes v1.2.6 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -395,6 +395,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
|||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
||||||
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||||
@@ -1195,6 +1197,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
|
|||||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/datatypes v1.2.6 h1:KafLdXvFUhzNeL2ncm03Gl3eTLONQfNKZ+wJ+9Y4Nck=
|
||||||
|
gorm.io/datatypes v1.2.6/go.mod h1:M2iO+6S3hhi4nAyYe444Pcb0dcIiOMJ7QHaUXxyiNZY=
|
||||||
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
|
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
|
||||||
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||||
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
|
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
|
||||||
@@ -1204,6 +1208,8 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
|
|||||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
|
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||||
|
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||||
|
|||||||
@@ -106,6 +106,18 @@ type DefaultAccountManager struct {
|
|||||||
disableDefaultPolicy bool
|
disableDefaultPolicy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isUniqueConstraintError(err error) bool {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||||
|
strings.Contains(err.Error(), "Error 1062 (23000)"),
|
||||||
|
strings.Contains(err.Error(), "UNIQUE constraint failed"):
|
||||||
|
return true
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||||
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
||||||
// newly groups to create and an error if any occurred.
|
// newly groups to create and an error if any occurred.
|
||||||
@@ -1192,6 +1204,71 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
|
|||||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
|
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||||
|
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||||
|
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||||
|
log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if onboarding == nil {
|
||||||
|
onboarding = &types.AccountOnboarding{
|
||||||
|
AccountID: accountID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return onboarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||||
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
return nil, status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||||
|
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||||
|
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldOnboarding == nil {
|
||||||
|
oldOnboarding = &types.AccountOnboarding{
|
||||||
|
AccountID: accountID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if newOnboarding == nil {
|
||||||
|
return oldOnboarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldOnboarding.IsEqual(*newOnboarding) {
|
||||||
|
log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID)
|
||||||
|
return oldOnboarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newOnboarding.AccountID = accountID
|
||||||
|
err = am.Store.SaveAccountOnboarding(ctx, newOnboarding)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to update account onboarding: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newOnboarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||||
if userAuth.UserId == "" {
|
if userAuth.UserId == "" {
|
||||||
return "", "", errors.New(emptyUserID)
|
return "", "", errors.New(emptyUserID)
|
||||||
@@ -1661,25 +1738,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
|
|
||||||
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
labelMap := ConvertSliceToMap(existingLabels)
|
|
||||||
newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newLabel == "" {
|
|
||||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newLabel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1733,6 +1791,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
|
|||||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||||
RoutingPeerDNSResolutionEnabled: true,
|
RoutingPeerDNSResolutionEnabled: true,
|
||||||
},
|
},
|
||||||
|
Onboarding: types.AccountOnboarding{
|
||||||
|
OnboardingFlowPending: true,
|
||||||
|
SignupFormPending: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type Manager interface {
|
|||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||||
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
|
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
|
||||||
|
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
||||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||||
@@ -89,6 +90,7 @@ type Manager interface {
|
|||||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||||
|
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||||
|
|||||||
@@ -1208,6 +1208,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
|||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
|
// Ensure that we do not receive an update message before the policy is deleted
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
select {
|
||||||
|
case <-updMsg:
|
||||||
|
t.Logf("received addPeer update message before policy deletion")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -2615,11 +2623,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
account := &types.Account{
|
account := &types.Account{
|
||||||
Id: "accountID",
|
Id: "accountID",
|
||||||
Peers: map[string]*nbpeer.Peer{
|
Peers: map[string]*nbpeer.Peer{
|
||||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
|
||||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
|
||||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
|
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
|
||||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
|
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
|
||||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
|
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
|
||||||
},
|
},
|
||||||
Groups: map[string]*types.Group{
|
Groups: map[string]*types.Group{
|
||||||
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
||||||
@@ -3139,11 +3147,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
|||||||
minMsPerOpCICD float64
|
minMsPerOpCICD float64
|
||||||
maxMsPerOpCICD float64
|
maxMsPerOpCICD float64
|
||||||
}{
|
}{
|
||||||
{"Small", 50, 5, 7, 20, 10, 80},
|
{"Small", 50, 5, 7, 20, 5, 80},
|
||||||
{"Medium", 500, 100, 5, 40, 30, 140},
|
{"Medium", 500, 100, 5, 40, 30, 140},
|
||||||
{"Large", 5000, 200, 80, 120, 140, 390},
|
{"Large", 5000, 200, 80, 120, 140, 390},
|
||||||
{"Small single", 50, 10, 7, 20, 10, 80},
|
{"Small single", 50, 10, 7, 20, 6, 80},
|
||||||
{"Medium single", 500, 10, 5, 40, 20, 85},
|
{"Medium single", 500, 10, 5, 40, 15, 85},
|
||||||
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3335,11 +3343,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
|
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
|
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -3440,3 +3448,74 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
|
||||||
|
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, onboarding)
|
||||||
|
assert.Equal(t, account.Id, onboarding.AccountID)
|
||||||
|
assert.Equal(t, true, onboarding.OnboardingFlowPending)
|
||||||
|
assert.Equal(t, true, onboarding.SignupFormPending)
|
||||||
|
if onboarding.UpdatedAt.IsZero() {
|
||||||
|
t.Errorf("Onboarding was not retrieved from the store")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
|
||||||
|
account.Id = "with-zero-onboarding"
|
||||||
|
account.Onboarding = types.AccountOnboarding{}
|
||||||
|
err = manager.Store.SaveAccount(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, onboarding)
|
||||||
|
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
|
||||||
|
require.Error(t, err, "should return error when onboarding is not set")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
onboarding := &types.AccountOnboarding{
|
||||||
|
OnboardingFlowPending: true,
|
||||||
|
SignupFormPending: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("update onboarding with no change", func(t *testing.T) {
|
||||||
|
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||||
|
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||||
|
if updated.UpdatedAt.IsZero() {
|
||||||
|
t.Errorf("Onboarding was updated in the store")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
onboarding.OnboardingFlowPending = false
|
||||||
|
onboarding.SignupFormPending = false
|
||||||
|
|
||||||
|
t.Run("update onboarding", func(t *testing.T) {
|
||||||
|
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updated)
|
||||||
|
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||||
|
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update onboarding with no onboarding", func(t *testing.T) {
|
||||||
|
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -60,6 +60,8 @@ components:
|
|||||||
description: Account creator
|
description: Account creator
|
||||||
type: string
|
type: string
|
||||||
example: google-oauth2|277474792786460067937
|
example: google-oauth2|277474792786460067937
|
||||||
|
onboarding:
|
||||||
|
$ref: '#/components/schemas/AccountOnboarding'
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- settings
|
- settings
|
||||||
@@ -67,6 +69,21 @@ components:
|
|||||||
- domain_category
|
- domain_category
|
||||||
- created_at
|
- created_at
|
||||||
- created_by
|
- created_by
|
||||||
|
- onboarding
|
||||||
|
AccountOnboarding:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
signup_form_pending:
|
||||||
|
description: Indicates whether the account signup form is pending
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
onboarding_flow_pending:
|
||||||
|
description: Indicates whether the account onboarding flow is pending
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
required:
|
||||||
|
- signup_form_pending
|
||||||
|
- onboarding_flow_pending
|
||||||
AccountSettings:
|
AccountSettings:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -153,6 +170,8 @@ components:
|
|||||||
properties:
|
properties:
|
||||||
settings:
|
settings:
|
||||||
$ref: '#/components/schemas/AccountSettings'
|
$ref: '#/components/schemas/AccountSettings'
|
||||||
|
onboarding:
|
||||||
|
$ref: '#/components/schemas/AccountOnboarding'
|
||||||
required:
|
required:
|
||||||
- settings
|
- settings
|
||||||
User:
|
User:
|
||||||
|
|||||||
@@ -251,6 +251,7 @@ type Account struct {
|
|||||||
|
|
||||||
// Id Account ID
|
// Id Account ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
Onboarding AccountOnboarding `json:"onboarding"`
|
||||||
Settings AccountSettings `json:"settings"`
|
Settings AccountSettings `json:"settings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,8 +267,18 @@ type AccountExtraSettings struct {
|
|||||||
PeerApprovalEnabled bool `json:"peer_approval_enabled"`
|
PeerApprovalEnabled bool `json:"peer_approval_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountOnboarding defines model for AccountOnboarding.
|
||||||
|
type AccountOnboarding struct {
|
||||||
|
// OnboardingFlowPending Indicates whether the account onboarding flow is pending
|
||||||
|
OnboardingFlowPending bool `json:"onboarding_flow_pending"`
|
||||||
|
|
||||||
|
// SignupFormPending Indicates whether the account signup form is pending
|
||||||
|
SignupFormPending bool `json:"signup_form_pending"`
|
||||||
|
}
|
||||||
|
|
||||||
// AccountRequest defines model for AccountRequest.
|
// AccountRequest defines model for AccountRequest.
|
||||||
type AccountRequest struct {
|
type AccountRequest struct {
|
||||||
|
Onboarding *AccountOnboarding `json:"onboarding,omitempty"`
|
||||||
Settings AccountSettings `json:"settings"`
|
Settings AccountSettings `json:"settings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(accountID, settings, meta)
|
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,6 +132,20 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var onboarding *types.AccountOnboarding
|
||||||
|
if req.Onboarding != nil {
|
||||||
|
onboarding = &types.AccountOnboarding{
|
||||||
|
OnboardingFlowPending: req.Onboarding.OnboardingFlowPending,
|
||||||
|
SignupFormPending: req.Onboarding.SignupFormPending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
@@ -138,7 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(accountID, updatedSettings, meta)
|
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
@@ -167,7 +187,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account {
|
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||||
jwtAllowGroups := settings.JWTAllowGroups
|
jwtAllowGroups := settings.JWTAllowGroups
|
||||||
if jwtAllowGroups == nil {
|
if jwtAllowGroups == nil {
|
||||||
jwtAllowGroups = []string{}
|
jwtAllowGroups = []string{}
|
||||||
@@ -188,6 +208,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
DnsDomain: &settings.DNSDomain,
|
DnsDomain: &settings.DNSDomain,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiOnboarding := api.AccountOnboarding{
|
||||||
|
OnboardingFlowPending: onboarding.OnboardingFlowPending,
|
||||||
|
SignupFormPending: onboarding.SignupFormPending,
|
||||||
|
}
|
||||||
|
|
||||||
if settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
apiSettings.Extra = &api.AccountExtraSettings{
|
apiSettings.Extra = &api.AccountExtraSettings{
|
||||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||||
@@ -203,5 +228,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
CreatedBy: meta.CreatedBy,
|
CreatedBy: meta.CreatedBy,
|
||||||
Domain: meta.Domain,
|
Domain: meta.Domain,
|
||||||
DomainCategory: meta.DomainCategory,
|
DomainCategory: meta.DomainCategory,
|
||||||
|
Onboarding: apiOnboarding,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
|||||||
GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
||||||
return account.GetMeta(), nil
|
return account.GetMeta(), nil
|
||||||
},
|
},
|
||||||
|
GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||||
|
return &types.AccountOnboarding{
|
||||||
|
OnboardingFlowPending: true,
|
||||||
|
SignupFormPending: true,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||||
|
return &types.AccountOnboarding{
|
||||||
|
OnboardingFlowPending: true,
|
||||||
|
SignupFormPending: true,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
settingsManager: settingsMockManager,
|
settingsManager: settingsMockManager,
|
||||||
}
|
}
|
||||||
@@ -117,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/accounts/" + accountID,
|
requestPath: "/api/accounts/" + accountID,
|
||||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"),
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: 15552000,
|
PeerLoginExpiration: 15552000,
|
||||||
@@ -139,7 +151,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/accounts/" + accountID,
|
requestPath: "/api/accounts/" + accountID,
|
||||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: 15552000,
|
PeerLoginExpiration: 15552000,
|
||||||
@@ -161,7 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/accounts/" + accountID,
|
requestPath: "/api/accounts/" + accountID,
|
||||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"),
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedSettings: api.AccountSettings{
|
expectedSettings: api.AccountSettings{
|
||||||
PeerLoginExpiration: 554400,
|
PeerLoginExpiration: 554400,
|
||||||
@@ -178,12 +190,34 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "PutAccount OK without onboarding",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/accounts/" + accountID,
|
||||||
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedSettings: api.AccountSettings{
|
||||||
|
PeerLoginExpiration: 15552000,
|
||||||
|
PeerLoginExpirationEnabled: false,
|
||||||
|
GroupsPropagationEnabled: br(false),
|
||||||
|
JwtGroupsClaimName: sr("roles"),
|
||||||
|
JwtGroupsEnabled: br(true),
|
||||||
|
JwtAllowGroups: &[]string{"test"},
|
||||||
|
RegularUsersViewBlocked: true,
|
||||||
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
|
LazyConnectionEnabled: br(false),
|
||||||
|
DnsDomain: sr(""),
|
||||||
|
},
|
||||||
|
expectedArray: false,
|
||||||
|
expectedID: accountID,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Update account failure with high peer_login_expiration more than 180 days",
|
name: "Update account failure with high peer_login_expiration more than 180 days",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/accounts/" + accountID,
|
requestPath: "/api/accounts/" + accountID,
|
||||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"),
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
},
|
},
|
||||||
@@ -192,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/accounts/" + accountID,
|
requestPath: "/api/accounts/" + accountID,
|
||||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"),
|
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -373,3 +373,42 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
|
|||||||
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
||||||
|
var model T
|
||||||
|
|
||||||
|
stmt := &gorm.Statement{DB: db}
|
||||||
|
if err := stmt.Parse(&model); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse model schema: %w", err)
|
||||||
|
}
|
||||||
|
tableName := stmt.Schema.Table
|
||||||
|
dialect := db.Dialector.Name()
|
||||||
|
|
||||||
|
var columnClause string
|
||||||
|
if dialect == "mysql" {
|
||||||
|
var withLength []string
|
||||||
|
for _, col := range columns {
|
||||||
|
if col == "ip" || col == "dns_label" {
|
||||||
|
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
|
||||||
|
} else {
|
||||||
|
withLength = append(withLength, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
columnClause = strings.Join(withLength, ", ")
|
||||||
|
} else {
|
||||||
|
columnClause = strings.Join(columns, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
|
||||||
|
if dialect == "postgres" || dialect == "sqlite" {
|
||||||
|
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
|
||||||
|
if err := db.Exec(createStmt).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to create index %s: %w", indexName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -117,7 +117,8 @@ type MockAccountManager struct {
|
|||||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||||
|
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
||||||
|
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -814,6 +815,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||||
|
if am.GetAccountOnboardingFunc != nil {
|
||||||
|
return am.GetAccountOnboardingFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||||
|
if am.UpdateAccountOnboardingFunc != nil {
|
||||||
|
return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// GetUserByID mocks GetUserByID of the AccountManager interface
|
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||||
if am.GetUserByIDFunc != nil {
|
if am.GetUserByIDFunc != nil {
|
||||||
|
|||||||
@@ -15,13 +15,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -32,6 +33,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Declare sqlStore and ok at the top so they are in scope for all usages
|
||||||
|
var sqlStore *store.SqlStore
|
||||||
|
var ok bool
|
||||||
|
|
||||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||||
// the current user is not an admin.
|
// the current user is not an admin.
|
||||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||||
@@ -234,14 +239,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peer.Name != update.Name {
|
if peer.Name != update.Name {
|
||||||
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
|
var newLabel string
|
||||||
|
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.Name = update.Name
|
peer.Name = update.Name
|
||||||
@@ -410,6 +411,24 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
|||||||
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to serve precomputed network map from DB if up-to-date
|
||||||
|
sqlStore, ok = am.Store.(*store.SqlStore)
|
||||||
|
if ok {
|
||||||
|
db := sqlStore.GetDB()
|
||||||
|
var record *types.NetworkMapRecord
|
||||||
|
var err error
|
||||||
|
record, err = types.GetNetworkMapRecord(db, peer.ID)
|
||||||
|
if err == nil && record.Serial == account.Network.CurrentSerial() {
|
||||||
|
var nm *types.NetworkMap
|
||||||
|
nm, err = types.DeserializeNetworkMap(record.MapJSON)
|
||||||
|
if err == nil {
|
||||||
|
log.WithContext(ctx).Debugf("serving precomputed network map for peer %s from DB", peer.ID)
|
||||||
|
return nm, nil
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Warnf("failed to deserialize precomputed network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
groups := make(map[string][]string)
|
groups := make(map[string][]string)
|
||||||
for groupID, group := range account.Groups {
|
for groupID, group := range account.Groups {
|
||||||
groups[groupID] = group.Peers
|
groups[groupID] = group.Peers
|
||||||
@@ -427,13 +446,34 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
var proxyNetworkMap *types.NetworkMap
|
||||||
|
networkMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok = proxyNetworkMaps[peerID]
|
||||||
if ok {
|
if ok {
|
||||||
networkMap.Merge(proxyNetworkMap)
|
networkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// After generating the network map, store it as a precomputed blob in the DB
|
||||||
|
sqlStore, ok = am.Store.(*store.SqlStore)
|
||||||
|
if ok {
|
||||||
|
db := sqlStore.GetDB()
|
||||||
|
data, err := types.SerializeNetworkMap(networkMap)
|
||||||
|
if err == nil {
|
||||||
|
record := &types.NetworkMapRecord{
|
||||||
|
PeerID: peer.ID,
|
||||||
|
AccountID: account.Id,
|
||||||
|
MapJSON: data,
|
||||||
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
err = types.SaveNetworkMapRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to store precomputed network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Warnf("failed to serialize network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return networkMap, nil
|
return networkMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -463,67 +503,50 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
upperKey := strings.ToUpper(setupKey)
|
upperKey := strings.ToUpper(setupKey)
|
||||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
var accountID string
|
addedByUser := len(userID) > 0
|
||||||
var err error
|
|
||||||
addedByUser := false
|
|
||||||
if len(userID) > 0 {
|
|
||||||
addedByUser = true
|
|
||||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
|
|
||||||
} else {
|
|
||||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer func() {
|
|
||||||
if unlock != nil {
|
|
||||||
unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||||
// and the peer disconnects with a timeout and tries to register again.
|
// and the peer disconnects with a timeout and tries to register again.
|
||||||
// We just check if this machine has been registered before and reject the second registration.
|
// We just check if this machine has been registered before and reject the second registration.
|
||||||
// The connecting peer should be able to recover with a retry.
|
// The connecting peer should be able to recover with a retry.
|
||||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
|
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
opEvent := &activity.Event{
|
opEvent := &activity.Event{
|
||||||
Timestamp: time.Now().UTC(),
|
Timestamp: time.Now().UTC(),
|
||||||
AccountID: accountID,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPeer *nbpeer.Peer
|
var newPeer *nbpeer.Peer
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
||||||
var setupKeyID string
|
var setupKeyID string
|
||||||
var setupKeyName string
|
var setupKeyName string
|
||||||
var ephemeral bool
|
var ephemeral bool
|
||||||
var groupsToAdd []string
|
var groupsToAdd []string
|
||||||
var allowExtraDNSLabels bool
|
var allowExtraDNSLabels bool
|
||||||
|
var accountID string
|
||||||
if addedByUser {
|
if addedByUser {
|
||||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user groups: %w", err)
|
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||||
}
|
}
|
||||||
groupsToAdd = user.AutoGroups
|
groupsToAdd = user.AutoGroups
|
||||||
opEvent.InitiatorID = userID
|
opEvent.InitiatorID = userID
|
||||||
opEvent.Activity = activity.PeerAddedByUser
|
opEvent.Activity = activity.PeerAddedByUser
|
||||||
|
accountID = user.AccountID
|
||||||
} else {
|
} else {
|
||||||
// Validate the setup key
|
// Validate the setup key
|
||||||
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get setup key: %w", err)
|
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we will check key twice for early return
|
||||||
if !sk.IsValid() {
|
if !sk.IsValid() {
|
||||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
return nil, nil, nil, status.Errorf(status.NotFound, "couldn't add peer: setup key is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
opEvent.InitiatorID = sk.Id
|
opEvent.InitiatorID = sk.Id
|
||||||
@@ -533,11 +556,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
setupKeyID = sk.Id
|
setupKeyID = sk.Id
|
||||||
setupKeyName = sk.Name
|
setupKeyName = sk.Name
|
||||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||||
|
accountID = sk.AccountID
|
||||||
|
|
||||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
opEvent.AccountID = accountID
|
||||||
|
|
||||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||||
if am.idpManager != nil {
|
if am.idpManager != nil {
|
||||||
@@ -548,18 +573,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
freeIP, err := getFreeIP(ctx, transaction, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get free IP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationTime := time.Now().UTC()
|
registrationTime := time.Now().UTC()
|
||||||
@@ -567,10 +582,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Key: peer.Key,
|
Key: peer.Key,
|
||||||
IP: freeIP,
|
|
||||||
Meta: peer.Meta,
|
Meta: peer.Meta,
|
||||||
Name: peer.Meta.Hostname,
|
Name: peer.Meta.Hostname,
|
||||||
DNSLabel: freeLabel,
|
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||||
SSHEnabled: false,
|
SSHEnabled: false,
|
||||||
@@ -584,15 +597,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||||
}
|
}
|
||||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get account settings: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
opEvent.TargetID = newPeer.ID
|
|
||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
|
||||||
if !addedByUser {
|
|
||||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||||
@@ -608,6 +615,41 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
|
|
||||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||||
|
|
||||||
|
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAttempts := 10
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
var freeIP net.IP
|
||||||
|
freeIP, err = types.AllocateRandomPeerIP(network.Net)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var freeLabel string
|
||||||
|
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPeer.DNSLabel = freeLabel
|
||||||
|
newPeer.IP = freeIP
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer func() {
|
||||||
|
if unlock != nil {
|
||||||
|
unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||||
@@ -622,9 +664,26 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
if addedByUser {
|
||||||
|
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get setup key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// we validate at the end to not block the setup key for too long
|
||||||
|
if !sk.IsValid() {
|
||||||
|
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||||
@@ -632,39 +691,44 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addedByUser {
|
|
||||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
if err == nil {
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUniqueConstraintError(err) {
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
updateAccountPeers = true
|
||||||
|
}
|
||||||
|
|
||||||
if newPeer == nil {
|
if newPeer == nil {
|
||||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
opEvent.TargetID = newPeer.ID
|
||||||
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
||||||
|
if !addedByUser {
|
||||||
|
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||||
|
}
|
||||||
|
|
||||||
unlock()
|
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||||
unlock = nil
|
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||||
@@ -673,23 +737,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
|
func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) {
|
||||||
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
|
ip = ip.To4()
|
||||||
|
|
||||||
|
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
|
_, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
//nolint:nilerr
|
||||||
|
return dnsName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
|
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nextIp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
@@ -838,7 +900,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
if login.UserID != "" {
|
if login.UserID != "" {
|
||||||
if peer.UserID != login.UserID {
|
if peer.UserID != login.UserID {
|
||||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||||
return status.Errorf(status.Unauthenticated, "invalid user")
|
return status.NewPeerLoginMismatchError()
|
||||||
}
|
}
|
||||||
|
|
||||||
changed, err := am.handleUserPeer(ctx, transaction, peer, settings)
|
changed, err := am.handleUserPeer(ctx, transaction, peer, settings)
|
||||||
@@ -1034,13 +1096,47 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
var proxyNetworkMap *types.NetworkMap
|
||||||
|
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok = proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
networkMap.Merge(proxyNetworkMap)
|
networkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// After generating the network map, store it as a precomputed blob in the DB
|
||||||
|
sqlStore, ok = am.Store.(*store.SqlStore)
|
||||||
|
if ok {
|
||||||
|
db := sqlStore.GetDB()
|
||||||
|
data, err := types.SerializeNetworkMap(networkMap)
|
||||||
|
if err == nil {
|
||||||
|
record := &types.NetworkMapRecord{
|
||||||
|
PeerID: peer.ID,
|
||||||
|
AccountID: account.Id,
|
||||||
|
MapJSON: data,
|
||||||
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
err = types.SaveNetworkMapRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to store precomputed network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Warnf("failed to serialize network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time.Now()
|
||||||
|
update := toSyncResponse(ctx, nil, peer, nil, nil, networkMap, am.GetDNSDomain(account.Settings), postureChecks, nil, account.Settings, extraSetting)
|
||||||
|
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||||
|
|
||||||
|
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: networkMap})
|
||||||
|
|
||||||
return peer, networkMap, postureChecks, nil
|
return peer, networkMap, postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1087,7 +1183,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error
|
|||||||
}
|
}
|
||||||
if peer.UserID != loginUserID {
|
if peer.UserID != loginUserID {
|
||||||
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
||||||
return status.Errorf(status.Unauthenticated, "can't login with this credentials")
|
return status.NewPeerLoginMismatchError()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1220,12 +1316,35 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
|||||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
var proxyNetworkMap *types.NetworkMap
|
||||||
|
proxyNetworkMap, ok = proxyNetworkMaps[p.ID]
|
||||||
if ok {
|
if ok {
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||||
}
|
}
|
||||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
|
// Store the precomputed network map in the DB
|
||||||
|
sqlStore, ok = am.Store.(*store.SqlStore)
|
||||||
|
if ok {
|
||||||
|
db := sqlStore.GetDB()
|
||||||
|
data, err := types.SerializeNetworkMap(remotePeerNetworkMap)
|
||||||
|
if err == nil {
|
||||||
|
record := &types.NetworkMapRecord{
|
||||||
|
PeerID: p.ID,
|
||||||
|
AccountID: account.Id,
|
||||||
|
MapJSON: data,
|
||||||
|
Serial: remotePeerNetworkMap.Network.CurrentSerial(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
err = types.SaveNetworkMapRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to store precomputed network map for peer %s: %v", p.ID, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Warnf("failed to serialize network map for peer %s: %v", p.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
|
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
|
||||||
@@ -1240,8 +1359,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
|||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
if am.metrics != nil {
|
if am.metrics != nil {
|
||||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||||
@@ -1307,11 +1424,32 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
var proxyNetworkMap *types.NetworkMap
|
||||||
|
proxyNetworkMap, ok = proxyNetworkMaps[peer.ID]
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
|
||||||
if ok {
|
if ok {
|
||||||
|
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||||
|
|
||||||
|
// Store the precomputed network map in the DB
|
||||||
|
sqlStore, ok = am.Store.(*store.SqlStore)
|
||||||
|
if ok {
|
||||||
|
db := sqlStore.GetDB()
|
||||||
|
data, err := types.SerializeNetworkMap(remotePeerNetworkMap)
|
||||||
|
if err == nil {
|
||||||
|
record := &types.NetworkMapRecord{
|
||||||
|
PeerID: peer.ID,
|
||||||
|
AccountID: account.Id,
|
||||||
|
MapJSON: data,
|
||||||
|
Serial: remotePeerNetworkMap.Network.CurrentSerial(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
err = types.SaveNetworkMapRecord(db, record)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to store precomputed network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Warnf("failed to serialize network map for peer %s: %v", peer.ID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||||
@@ -1322,6 +1460,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
|||||||
|
|
||||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings)
|
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings)
|
||||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||||
@@ -1477,19 +1616,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
|||||||
return groupIDs, err
|
return groupIDs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
|
|
||||||
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
existingLabels := make(types.LookupMap)
|
|
||||||
for _, label := range dnsLabels {
|
|
||||||
existingLabels[label] = struct{}{}
|
|
||||||
}
|
|
||||||
return existingLabels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
// in an active DNS, route, or ACL configuration.
|
// in an active DNS, route, or ACL configuration.
|
||||||
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
||||||
|
|||||||
@@ -20,14 +20,14 @@ type Peer struct {
|
|||||||
// WireGuard public key
|
// WireGuard public key
|
||||||
Key string `gorm:"index"`
|
Key string `gorm:"index"`
|
||||||
// IP address of the Peer
|
// IP address of the Peer
|
||||||
IP net.IP `gorm:"serializer:json"`
|
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
||||||
// Meta is a Peer system meta data
|
// Meta is a Peer system meta data
|
||||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
// Name is peer's name (machine name)
|
// Name is peer's name (machine name)
|
||||||
Name string
|
Name string
|
||||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||||
DNSLabel string
|
DNSLabel string // uniqueness index per accountID (check migrations)
|
||||||
// Status peer's management connection status
|
// Status peer's management connection status
|
||||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||||
// The user ID that registered the peer
|
// The user ID that registered the peer
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,11 +21,13 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
|
||||||
@@ -1373,6 +1377,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
existingSetupKeyID string
|
existingSetupKeyID string
|
||||||
expectedGroupIDsInAccount []string
|
expectedGroupIDsInAccount []string
|
||||||
expectAddPeerError bool
|
expectAddPeerError bool
|
||||||
|
errorType status.Type
|
||||||
expectedErrorMsgSubstring string
|
expectedErrorMsgSubstring string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -1385,13 +1390,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
name: "Failed registration with setup key not allowing extra DNS labels",
|
name: "Failed registration with setup key not allowing extra DNS labels",
|
||||||
existingSetupKeyID: "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
existingSetupKeyID: "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||||
expectAddPeerError: true,
|
expectAddPeerError: true,
|
||||||
|
errorType: status.PreconditionFailed,
|
||||||
expectedErrorMsgSubstring: "setup key doesn't allow extra DNS labels",
|
expectedErrorMsgSubstring: "setup key doesn't allow extra DNS labels",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Absent setup key",
|
name: "Absent setup key",
|
||||||
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
|
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
|
||||||
expectAddPeerError: true,
|
expectAddPeerError: true,
|
||||||
expectedErrorMsgSubstring: "failed adding new peer: account not found",
|
errorType: status.NotFound,
|
||||||
|
expectedErrorMsgSubstring: "couldn't add peer: setup key is invalid",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1416,6 +1423,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
if tc.expectAddPeerError {
|
if tc.expectAddPeerError {
|
||||||
require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID)
|
require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID)
|
||||||
assert.Contains(t, err.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch")
|
assert.Contains(t, err.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch")
|
||||||
|
e, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Failed to map error")
|
||||||
|
}
|
||||||
|
assert.Equal(t, e.Type(), tc.errorType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2057,10 +2069,14 @@ func Test_DeletePeer(t *testing.T) {
|
|||||||
"peer1": {
|
"peer1": {
|
||||||
ID: "peer1",
|
ID: "peer1",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
DNSLabel: "peer1.test",
|
||||||
},
|
},
|
||||||
"peer2": {
|
"peer2": {
|
||||||
ID: "peer2",
|
ID: "peer2",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
|
IP: net.IP{2, 2, 2, 2},
|
||||||
|
DNSLabel: "peer2.test",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
account.Groups = map[string]*types.Group{
|
account.Groups = map[string]*types.Group{
|
||||||
@@ -2090,3 +2106,138 @@ func Test_DeletePeer(t *testing.T) {
|
|||||||
assert.NotContains(t, group.Peers, "peer1")
|
assert.NotContains(t, group.Peers, "peer1")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_IsUniqueConstraintError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
engine types.Engine
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "PostgreSQL uniqueness error",
|
||||||
|
engine: types.PostgresStoreEngine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MySQL uniqueness error",
|
||||||
|
engine: types.MysqlStoreEngine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SQLite uniqueness error",
|
||||||
|
engine: types.SqliteStoreEngine,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: "test-peer-id",
|
||||||
|
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
DNSLabel: "test-peer-dns-label",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine))
|
||||||
|
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error when creating store: %s", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||||
|
result := isUniqueConstraintError(err)
|
||||||
|
assert.True(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AddPeer(t *testing.T) {
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := "testaccount"
|
||||||
|
userID := "testuser"
|
||||||
|
|
||||||
|
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
|
||||||
|
const differentHostnames = 50
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errs := make(chan error, totalPeers+differentHostnames)
|
||||||
|
start := make(chan struct{})
|
||||||
|
for i := 0; i < totalPeers; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
hostNameID := i % differentHostnames
|
||||||
|
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
Key: "key" + strconv.Itoa(i),
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
|
||||||
|
}
|
||||||
|
|
||||||
|
<-start
|
||||||
|
|
||||||
|
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
|
||||||
|
if err != nil {
|
||||||
|
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
t.Logf("time since start: %s", time.Since(startTime))
|
||||||
|
|
||||||
|
for err := range errs {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||||
|
|
||||||
|
seenIP := make(map[string]bool)
|
||||||
|
for _, p := range account.Peers {
|
||||||
|
ipStr := p.IP.String()
|
||||||
|
if seenIP[ipStr] {
|
||||||
|
t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr)
|
||||||
|
}
|
||||||
|
seenIP[ipStr] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
seenLabel := make(map[string]bool)
|
||||||
|
for _, p := range account.Peers {
|
||||||
|
if seenLabel[p.DNSLabel] {
|
||||||
|
t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel)
|
||||||
|
}
|
||||||
|
seenLabel[p.DNSLabel] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
|
||||||
|
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
||||||
|
}
|
||||||
|
|||||||
@@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error {
|
|||||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding
|
||||||
|
func NewAccountOnboardingNotFoundError(accountKey string) error {
|
||||||
|
return Errorf(NotFound, "account onboarding not found: %s", accountKey)
|
||||||
|
}
|
||||||
|
|
||||||
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
||||||
func NewPeerNotPartOfAccountError() error {
|
func NewPeerNotPartOfAccountError() error {
|
||||||
return Errorf(PermissionDenied, "peer is not part of this account")
|
return Errorf(PermissionDenied, "peer is not part of this account")
|
||||||
@@ -105,11 +110,16 @@ func NewUserBlockedError() error {
|
|||||||
return Errorf(PermissionDenied, "user is blocked")
|
return Errorf(PermissionDenied, "user is blocked")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
// NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer
|
||||||
func NewPeerNotRegisteredError() error {
|
func NewPeerNotRegisteredError() error {
|
||||||
return Errorf(Unauthenticated, "peer is not registered")
|
return Errorf(Unauthenticated, "peer is not registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPeerLoginMismatchError creates a new Error with Unauthenticated type for a peer that is already registered for another user
|
||||||
|
func NewPeerLoginMismatchError() error {
|
||||||
|
return Errorf(Unauthenticated, "peer is already registered by a different User or a Setup Key")
|
||||||
|
}
|
||||||
|
|
||||||
// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer
|
// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer
|
||||||
func NewPeerLoginExpiredError() error {
|
func NewPeerLoginExpiredError() error {
|
||||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err)
|
||||||
// if the All group didn't exist we probably don't have routes to update
|
// if the All group didn't exist we probably don't have routes to update
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,17 +92,21 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrate(ctx, db); err != nil {
|
if err := migratePreAuto(ctx, db); err != nil {
|
||||||
return nil, fmt.Errorf("migrate: %w", err)
|
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(
|
err = db.AutoMigrate(
|
||||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
||||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||||
|
&types.NetworkMapRecord{}, // <-- Added for precomputed network maps
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||||
|
}
|
||||||
|
if err := migratePostAuto(ctx, db); err != nil {
|
||||||
|
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||||
@@ -725,6 +729,32 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren
|
|||||||
return &accountMeta, nil
|
return &accountMeta, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||||
|
func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) {
|
||||||
|
var accountOnboarding types.AccountOnboarding
|
||||||
|
result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewAccountOnboardingNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
|
||||||
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &accountOnboarding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAccountOnboarding updates the onboarding information for a specific account.
|
||||||
|
func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error {
|
||||||
|
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
||||||
|
return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -967,7 +997,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
|||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) {
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
@@ -975,7 +1005,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
|||||||
|
|
||||||
var labels []string
|
var labels []string
|
||||||
result := tx.Model(&nbpeer.Peer{}).
|
result := tx.Model(&nbpeer.Peer{}).
|
||||||
Where("account_id = ?", accountID).
|
Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%").
|
||||||
Pluck("dns_label", &labels)
|
Pluck("dns_label", &labels)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -1254,7 +1284,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
|||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewSetupKeyNotFoundError(key)
|
return nil, status.Errorf(status.PreconditionFailed, "setup key not found")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
||||||
@@ -1410,7 +1440,11 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
|||||||
// GetAccountPeers retrieves peers for an account.
|
// GetAccountPeers retrieves peers for an account.
|
||||||
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||||
var peers []*nbpeer.Peer
|
var peers []*nbpeer.Peer
|
||||||
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID)
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
query := tx.Where(accountIDCondition, accountID)
|
||||||
|
|
||||||
if nameFilter != "" {
|
if nameFilter != "" {
|
||||||
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
|
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
|
||||||
@@ -2546,6 +2580,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
|||||||
return &peer, nil
|
return &peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
|
||||||
|
tx := s.db.WithContext(ctx)
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var peerID string
|
||||||
|
result := tx.Model(&nbpeer.Peer{}).
|
||||||
|
Select("id").
|
||||||
|
// Where(" = ?", hostname).
|
||||||
|
Where("account_id = ? AND dns_label = ?", accountID, hostname).
|
||||||
|
Limit(1).
|
||||||
|
Scan(&peerID)
|
||||||
|
|
||||||
|
if peerID == "" {
|
||||||
|
return "", gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerID, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
result := s.db.Model(&types.Account{}).
|
result := s.db.Model(&types.Account{}).
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -353,9 +354,16 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
|||||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
o, err := store.GetAccountOnboarding(context.Background(), account.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, o.AccountID, account.Id)
|
||||||
|
|
||||||
err = store.DeleteAccount(context.Background(), account)
|
err = store.DeleteAccount(context.Background(), account)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = store.GetAccountOnboarding(context.Background(), account.Id)
|
||||||
|
require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding")
|
||||||
|
|
||||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||||
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
|
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
|
||||||
}
|
}
|
||||||
@@ -413,12 +421,21 @@ func Test_GetAccount(t *testing.T) {
|
|||||||
account, err := store.GetAccount(context.Background(), id)
|
account, err := store.GetAccount(context.Background(), id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, account.Id, "account id should match")
|
require.Equal(t, id, account.Id, "account id should match")
|
||||||
|
require.Equal(t, false, account.Onboarding.OnboardingFlowPending)
|
||||||
|
|
||||||
|
id = "9439-34653001fc3b-bf1c8084-ba50-4ce7"
|
||||||
|
|
||||||
|
account, err = store.GetAccount(context.Background(), id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, id, account.Id, "account id should match")
|
||||||
|
require.Equal(t, true, account.Onboarding.OnboardingFlowPending)
|
||||||
|
|
||||||
_, err = store.GetAccount(context.Background(), "non-existing-account")
|
_, err = store.GetAccount(context.Background(), "non-existing-account")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
parsedErr, ok := status.FromError(err)
|
parsedErr, ok := status.FromError(err)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -630,7 +647,7 @@ func TestMigrate(t *testing.T) {
|
|||||||
t.Cleanup(cleanUp)
|
t.Cleanup(cleanUp)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on empty db")
|
require.NoError(t, err, "Migration should not fail on empty db")
|
||||||
|
|
||||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||||
@@ -685,10 +702,10 @@ func TestMigrate(t *testing.T) {
|
|||||||
err = store.(*SqlStore).db.Save(rt).Error
|
err = store.(*SqlStore).db.Save(rt).Error
|
||||||
require.NoError(t, err, "Failed to insert Gob data")
|
require.NoError(t, err, "Failed to insert Gob data")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on gob populated db")
|
require.NoError(t, err, "Migration should not fail on gob populated db")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||||
|
|
||||||
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
|
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
|
||||||
@@ -704,10 +721,10 @@ func TestMigrate(t *testing.T) {
|
|||||||
err = store.(*SqlStore).db.Save(nRT).Error
|
err = store.(*SqlStore).db.Save(nRT).Error
|
||||||
require.NoError(t, err, "Failed to insert json nil slice data")
|
require.NoError(t, err, "Failed to insert json nil slice data")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
|
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
|
||||||
|
|
||||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -950,6 +967,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
|||||||
peer1 := &nbpeer.Peer{
|
peer1 := &nbpeer.Peer{
|
||||||
ID: "peer1",
|
ID: "peer1",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1",
|
||||||
IP: net.IP{1, 1, 1, 1},
|
IP: net.IP{1, 1, 1, 1},
|
||||||
}
|
}
|
||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
@@ -961,8 +979,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
|||||||
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{
|
peer2 := &nbpeer.Peer{
|
||||||
ID: "peer2",
|
ID: "peer1second",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1-1",
|
||||||
IP: net.IP{2, 2, 2, 2},
|
IP: net.IP{2, 2, 2, 2},
|
||||||
}
|
}
|
||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
@@ -972,26 +991,59 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ip2 := net.IP{2, 2, 2, 2}.To16()
|
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||||
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
peerHostname := "peer1"
|
||||||
|
|
||||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{}, labels)
|
assert.Equal(t, []string{}, labels)
|
||||||
|
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1",
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"peer1"}, labels)
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer1second",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
DNSLabel: "peer1-1",
|
||||||
|
IP: net.IP{2, 2, 2, 2},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expected := []string{"peer1", "peer1-1"}
|
||||||
|
sort.Strings(expected)
|
||||||
|
sort.Strings(labels)
|
||||||
|
assert.Equal(t, expected, labels)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AddPeerWithSameDnsLabel(t *testing.T) {
|
||||||
|
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
peer1 := &nbpeer.Peer{
|
peer1 := &nbpeer.Peer{
|
||||||
ID: "peer1",
|
ID: "peer1",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
@@ -1000,21 +1052,39 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
|||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
|
||||||
|
|
||||||
peer2 := &nbpeer.Peer{
|
peer2 := &nbpeer.Peer{
|
||||||
ID: "peer2",
|
ID: "peer1second",
|
||||||
AccountID: existingAccountID,
|
AccountID: existingAccountID,
|
||||||
DNSLabel: "peer2.domain.test",
|
DNSLabel: "peer1.domain.test",
|
||||||
}
|
}
|
||||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AddPeerWithSameIP(t *testing.T) {
|
||||||
|
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||||
|
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer1second",
|
||||||
|
AccountID: existingAccountID,
|
||||||
|
IP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||||
@@ -2042,6 +2112,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
|||||||
PeerInactivityExpirationEnabled: false,
|
PeerInactivityExpirationEnabled: false,
|
||||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||||
},
|
},
|
||||||
|
Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := acc.AddAllGroup(false); err != nil {
|
if err := acc.AddAllGroup(false); err != nil {
|
||||||
@@ -3386,6 +3457,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) {
|
|||||||
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
|
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetAccountOnboarding(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
|
||||||
|
a, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("Onboarding: %+v", a.Onboarding)
|
||||||
|
err = store.SaveAccount(context.Background(), a)
|
||||||
|
require.NoError(t, err)
|
||||||
|
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, onboarding)
|
||||||
|
require.Equal(t, accountID, onboarding.AccountID)
|
||||||
|
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveAccountOnboarding(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Run("New onboarding should be saved correctly", func(t *testing.T) {
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
onboarding := &types.AccountOnboarding{
|
||||||
|
AccountID: accountID,
|
||||||
|
SignupFormPending: true,
|
||||||
|
OnboardingFlowPending: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.SaveAccountOnboarding(context.Background(), onboarding)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
|
||||||
|
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Existing onboarding should be updated correctly", func(t *testing.T) {
|
||||||
|
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
|
||||||
|
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending
|
||||||
|
onboarding.SignupFormPending = !onboarding.SignupFormPending
|
||||||
|
|
||||||
|
err = store.SaveAccountOnboarding(context.Background(), onboarding)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
|
||||||
|
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlStore_GetAnyAccountID(t *testing.T) {
|
func TestSqlStore_GetAnyAccountID(t *testing.T) {
|
||||||
t.Run("should return account ID when accounts exist", func(t *testing.T) {
|
t.Run("should return account ID when accounts exist", func(t *testing.T) {
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ type Store interface {
|
|||||||
GetAllAccounts(ctx context.Context) []*types.Account
|
GetAllAccounts(ctx context.Context) []*types.Account
|
||||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||||
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
|
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
|
||||||
|
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
|
||||||
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||||
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
||||||
@@ -74,6 +75,7 @@ type Store interface {
|
|||||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
|
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
|
||||||
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
||||||
|
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
|
||||||
|
|
||||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||||
@@ -117,7 +119,7 @@ type Store interface {
|
|||||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||||
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
||||||
@@ -193,6 +195,7 @@ type Store interface {
|
|||||||
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
||||||
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
||||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||||
|
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -234,9 +237,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type
|
|||||||
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
||||||
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
||||||
|
|
||||||
// Attempt to migrate from JSON store to SQLite
|
// Attempt to migratePreAuto from JSON store to SQLite
|
||||||
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
|
log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err)
|
||||||
kind = types.FileStoreEngine
|
kind = types.FileStoreEngine
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -280,9 +283,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrate migrates the SQLite database to the latest schema
|
// migratePreAuto migrates the SQLite database to the latest schema
|
||||||
func migrate(ctx context.Context, db *gorm.DB) error {
|
func migratePreAuto(ctx context.Context, db *gorm.DB) error {
|
||||||
migrations := getMigrations(ctx)
|
migrations := getMigrationsPreAuto(ctx)
|
||||||
|
|
||||||
for _, m := range migrations {
|
for _, m := range migrations {
|
||||||
if err := m(db); err != nil {
|
if err := m(db); err != nil {
|
||||||
@@ -293,7 +296,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMigrations(ctx context.Context) []migrationFunc {
|
func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||||
return []migrationFunc{
|
return []migrationFunc{
|
||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
||||||
@@ -329,6 +332,28 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
|||||||
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
} // migratePostAuto migrates the SQLite database to the latest schema
|
||||||
|
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
||||||
|
migrations := getMigrationsPostAuto(ctx)
|
||||||
|
|
||||||
|
for _, m := range migrations {
|
||||||
|
if err := m(db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||||
|
return []migrationFunc{
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip")
|
||||||
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
||||||
@@ -577,7 +602,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
|
|||||||
|
|
||||||
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
||||||
if fsStoreAccounts != sqliteStoreAccounts {
|
if fsStoreAccounts != sqliteStoreAccounts {
|
||||||
return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d",
|
return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d",
|
||||||
fsStoreAccounts, sqliteStoreAccounts)
|
fsStoreAccounts, sqliteStoreAccounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
6
management/server/testdata/store.sql
vendored
6
management/server/testdata/store.sql
vendored
@@ -1,4 +1,5 @@
|
|||||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||||
|
CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`));
|
||||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
@@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`);
|
|||||||
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
||||||
|
|
||||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
|
INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||||
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||||
@@ -52,4 +54,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
|
|||||||
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
|
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
|
||||||
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
|
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
|
||||||
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
|
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
|
||||||
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
|
|||||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||||
INSERT INTO installations VALUES(1,'');
|
INSERT INTO installations VALUES(1,'');
|
||||||
|
|||||||
@@ -83,10 +83,10 @@ type Account struct {
|
|||||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||||
// Settings is a dictionary of Account settings
|
// Settings is a dictionary of Account settings
|
||||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||||
|
|
||||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||||
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
|
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
|
||||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subclass used in gorm to only load network and not whole account
|
// Subclass used in gorm to only load network and not whole account
|
||||||
@@ -104,6 +104,20 @@ type AccountSettings struct {
|
|||||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccountOnboarding struct {
|
||||||
|
AccountID string `gorm:"primaryKey"`
|
||||||
|
OnboardingFlowPending bool
|
||||||
|
SignupFormPending bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEqual compares two AccountOnboarding objects and returns true if they are equal
|
||||||
|
func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||||
|
return o.OnboardingFlowPending == onboarding.OnboardingFlowPending &&
|
||||||
|
o.SignupFormPending == onboarding.SignupFormPending
|
||||||
|
}
|
||||||
|
|
||||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||||
@@ -866,6 +880,7 @@ func (a *Account) Copy() *Account {
|
|||||||
Networks: nets,
|
Networks: nets,
|
||||||
NetworkRouters: networkRouters,
|
NetworkRouters: networkRouters,
|
||||||
NetworkResources: networkResources,
|
NetworkResources: networkResources,
|
||||||
|
Onboarding: a.Onboarding,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -161,24 +162,65 @@ func (n *Network) Copy() *Network {
|
|||||||
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
||||||
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
||||||
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||||
takenIPMap := make(map[string]struct{})
|
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||||
takenIPMap[ipNet.IP.String()] = struct{}{}
|
totalIPs := uint32(1 << SubnetSize)
|
||||||
|
|
||||||
|
taken := make(map[uint32]struct{}, len(takenIps)+1)
|
||||||
|
taken[baseIP] = struct{}{} // reserve network IP
|
||||||
|
taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP
|
||||||
|
|
||||||
for _, ip := range takenIps {
|
for _, ip := range takenIps {
|
||||||
takenIPMap[ip.String()] = struct{}{}
|
taken[ipToUint32(ip)] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, _ := generateIPs(&ipNet, takenIPMap)
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
maxAttempts := (int(totalIPs) - len(taken)) / 100
|
||||||
|
|
||||||
if len(ips) == 0 {
|
for i := 0; i < maxAttempts; i++ {
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
|
||||||
|
candidate := baseIP + offset
|
||||||
|
if _, exists := taken[candidate]; !exists {
|
||||||
|
return uint32ToIP(candidate), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pick a random IP
|
for offset := uint32(1); offset < totalIPs-1; offset++ {
|
||||||
s := rand.NewSource(time.Now().Unix())
|
candidate := baseIP + offset
|
||||||
r := rand.New(s)
|
if _, exists := taken[candidate]; !exists {
|
||||||
intn := r.Intn(len(ips))
|
return uint32ToIP(candidate), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ips[intn], nil
|
return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) {
|
||||||
|
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||||
|
|
||||||
|
ones, bits := ipNet.Mask.Size()
|
||||||
|
hostBits := bits - ones
|
||||||
|
|
||||||
|
totalIPs := uint32(1 << hostBits)
|
||||||
|
|
||||||
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
|
||||||
|
|
||||||
|
candidate := baseIP + offset
|
||||||
|
return uint32ToIP(candidate), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipToUint32(ip net.IP) uint32 {
|
||||||
|
ip = ip.To4()
|
||||||
|
if len(ip) < 4 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return binary.BigEndian.Uint32(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uint32ToIP(n uint32) net.IP {
|
||||||
|
ip := make(net.IP, 4)
|
||||||
|
binary.BigEndian.PutUint32(ip, n)
|
||||||
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
||||||
|
|||||||
17
management/server/types/network_map_helpers.go
Normal file
17
management/server/types/network_map_helpers.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SerializeNetworkMap serializes a NetworkMap to JSON
|
||||||
|
func SerializeNetworkMap(nm *NetworkMap) ([]byte, error) {
|
||||||
|
return json.Marshal(nm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeserializeNetworkMap deserializes JSON data into a NetworkMap
|
||||||
|
func DeserializeNetworkMap(data []byte) (*NetworkMap, error) {
|
||||||
|
var nm NetworkMap
|
||||||
|
err := json.Unmarshal(data, &nm)
|
||||||
|
return &nm, err
|
||||||
|
}
|
||||||
39
management/server/types/network_map_record.go
Normal file
39
management/server/types/network_map_record.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/datatypes"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetworkMapRecord stores a precomputed network map for a peer
|
||||||
|
// MapJSON is stored as jsonb (Postgres), json (MySQL), or text (SQLite)
|
||||||
|
type NetworkMapRecord struct {
|
||||||
|
PeerID string `gorm:"primaryKey"`
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
MapJSON datatypes.JSON `gorm:"type:jsonb"` // GORM will use the right type for your DB
|
||||||
|
Serial uint64
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName sets the table name for GORM
|
||||||
|
// This ensures the table is named consistently across all supported databases.
|
||||||
|
func (NetworkMapRecord) TableName() string {
|
||||||
|
return "network_map_records"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveNetworkMapRecord stores or updates a NetworkMapRecord in the database
|
||||||
|
func SaveNetworkMapRecord(db *gorm.DB, record *NetworkMapRecord) error {
|
||||||
|
return db.Save(record).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNetworkMapRecord retrieves a NetworkMapRecord by peer ID
|
||||||
|
func GetNetworkMapRecord(db *gorm.DB, peerID string) (*NetworkMapRecord, error) {
|
||||||
|
var record NetworkMapRecord
|
||||||
|
err := db.First(&record, "peer_id = ?", peerID).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &record, nil
|
||||||
|
}
|
||||||
102
management/server/types/network_map_record_test.go
Normal file
102
management/server/types/network_map_record_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNetworkMapRecordCRUD(t *testing.T) {
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, db.AutoMigrate(&NetworkMapRecord{}))
|
||||||
|
|
||||||
|
record := &NetworkMapRecord{
|
||||||
|
PeerID: "peer1",
|
||||||
|
AccountID: "account1",
|
||||||
|
MapJSON: []byte(`{"Peers":[],"Network":null}`),
|
||||||
|
Serial: 1,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
require.NoError(t, SaveNetworkMapRecord(db, record))
|
||||||
|
|
||||||
|
fetched, err := GetNetworkMapRecord(db, "peer1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, record.PeerID, fetched.PeerID)
|
||||||
|
require.Equal(t, record.AccountID, fetched.AccountID)
|
||||||
|
require.Equal(t, record.Serial, fetched.Serial)
|
||||||
|
require.Equal(t, record.MapJSON, fetched.MapJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizedPeer mimics one row of a fully normalized peer table.
// It exists only so the benchmarks can compare the JSON-blob storage
// model against a row-per-peer model.
type NormalizedPeer struct {
	ID        string
	AccountID string
	Name      string
	IP        string
}

// NormalizedNetworkMap groups normalized peers for one peer's map.
// In a real schema this would be split across multiple tables; a
// single struct is sufficient for benchmarking here.
type NormalizedNetworkMap struct {
	PeerID    string
	Peers     []NormalizedPeer
	Serial    uint64
	UpdatedAt time.Time
}
|
||||||
|
|
||||||
|
func BenchmarkNetworkMapRecord_StoreAndRetrieve_JSON(b *testing.B) {
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
db.AutoMigrate(&NetworkMapRecord{})
|
||||||
|
|
||||||
|
record := &NetworkMapRecord{
|
||||||
|
PeerID: "peer1",
|
||||||
|
AccountID: "account1",
|
||||||
|
MapJSON: []byte(`{"Peers":[{"ID":"p1","AccountID":"account1","Name":"peer1","IP":"10.0.0.1"}],"Network":null}`),
|
||||||
|
Serial: 1,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
record.Serial = uint64(i)
|
||||||
|
record.UpdatedAt = time.Now()
|
||||||
|
if err := SaveNetworkMapRecord(db, record); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err := GetNetworkMapRecord(db, "peer1")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNetworkMapRecord_StoreAndRetrieve_Normalized(b *testing.B) {
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
db.AutoMigrate(&NormalizedPeer{})
|
||||||
|
|
||||||
|
peers := []NormalizedPeer{{ID: "p1", AccountID: "account1", Name: "peer1", IP: "10.0.0.1"}}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for _, peer := range peers {
|
||||||
|
if err := db.Save(&peer).Error; err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var fetched []NormalizedPeer
|
||||||
|
if err := db.Find(&fetched, "account_id = ?", "account1").Error; err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -550,7 +550,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process user update: %w", err)
|
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
||||||
}
|
}
|
||||||
usersToSave = append(usersToSave, updatedUser)
|
usersToSave = append(usersToSave, updatedUser)
|
||||||
addUserEvents = append(addUserEvents, userEvents...)
|
addUserEvents = append(addUserEvents, userEvents...)
|
||||||
|
|||||||
29
monotime/time.go
Normal file
29
monotime/time.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package monotime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
baseWallTime time.Time
|
||||||
|
baseWallNano int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
baseWallTime = time.Now()
|
||||||
|
baseWallNano = baseWallTime.UnixNano()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now returns the current time as Unix nanoseconds (int64).
|
||||||
|
// It uses monotonic time measurement from the base time to ensure
|
||||||
|
// the returned value increases monotonically and is not affected
|
||||||
|
// by system clock adjustments.
|
||||||
|
//
|
||||||
|
// Performance optimization: By capturing the base wall time once at startup
|
||||||
|
// and using time.Since() for elapsed calculation, this avoids repeated
|
||||||
|
// time.Now() calls and leverages Go's internal monotonic clock for
|
||||||
|
// efficient duration measurement.
|
||||||
|
func Now() int64 {
|
||||||
|
elapsed := time.Since(baseWallTime)
|
||||||
|
return baseWallNano + int64(elapsed)
|
||||||
|
}
|
||||||
20
monotime/time_test.go
Normal file
20
monotime/time_test.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package monotime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkMonotimeNow(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkTimeNow measures time.Now as the baseline to compare
// against BenchmarkMonotimeNow.
func BenchmarkTimeNow(b *testing.B) {
	b.ReportAllocs()
	for n := 0; n < b.N; n++ {
		_ = time.Now()
	}
}
|
||||||
@@ -29,6 +29,7 @@ const (
|
|||||||
Body_ANSWER Body_Type = 1
|
Body_ANSWER Body_Type = 1
|
||||||
Body_CANDIDATE Body_Type = 2
|
Body_CANDIDATE Body_Type = 2
|
||||||
Body_MODE Body_Type = 4
|
Body_MODE Body_Type = 4
|
||||||
|
Body_GO_IDLE Body_Type = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
// Enum value maps for Body_Type.
|
// Enum value maps for Body_Type.
|
||||||
@@ -38,12 +39,14 @@ var (
|
|||||||
1: "ANSWER",
|
1: "ANSWER",
|
||||||
2: "CANDIDATE",
|
2: "CANDIDATE",
|
||||||
4: "MODE",
|
4: "MODE",
|
||||||
|
5: "GO_IDLE",
|
||||||
}
|
}
|
||||||
Body_Type_value = map[string]int32{
|
Body_Type_value = map[string]int32{
|
||||||
"OFFER": 0,
|
"OFFER": 0,
|
||||||
"ANSWER": 1,
|
"ANSWER": 1,
|
||||||
"CANDIDATE": 2,
|
"CANDIDATE": 2,
|
||||||
"MODE": 4,
|
"MODE": 4,
|
||||||
|
"GO_IDLE": 5,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,7 +228,7 @@ type Body struct {
|
|||||||
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
|
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
|
||||||
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
||||||
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
|
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
|
||||||
// relayServerAddress is an IP:port of the relay server
|
// relayServerAddress is url of the relay server
|
||||||
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
|
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -440,7 +443,7 @@ var file_signalexchange_proto_rawDesc = []byte{
|
|||||||
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
|
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
|
||||||
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
|
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
|
||||||
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
|
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
|
||||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa6, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||||
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
|
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
|
||||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
|
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
|
||||||
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
|
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
|
||||||
@@ -463,33 +466,34 @@ var file_signalexchange_proto_rawDesc = []byte{
|
|||||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
|
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
|
||||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
|
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
|
||||||
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
|
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
|
||||||
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
|
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
|
||||||
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
|
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
|
||||||
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
|
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
|
||||||
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
|
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b,
|
||||||
0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
|
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d,
|
||||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
|
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20,
|
||||||
0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d,
|
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01,
|
||||||
0x0a, 0x0f, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69,
|
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52,
|
||||||
0x67, 0x12, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75,
|
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28,
|
||||||
0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65,
|
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65,
|
||||||
0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72,
|
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61,
|
||||||
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
|
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65,
|
||||||
0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
|
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18,
|
||||||
0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01,
|
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||||
0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53,
|
||||||
0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
|
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a,
|
||||||
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
|
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||||
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67,
|
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
|
||||||
0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72,
|
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
|
||||||
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59,
|
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
|
||||||
0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12,
|
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43,
|
||||||
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73,
|
||||||
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
|
||||||
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
|
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
|
||||||
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
|
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
|
||||||
0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
|
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||||
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||||
|
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ message Body {
|
|||||||
ANSWER = 1;
|
ANSWER = 1;
|
||||||
CANDIDATE = 2;
|
CANDIDATE = 2;
|
||||||
MODE = 4;
|
MODE = 4;
|
||||||
|
GO_IDLE = 5;
|
||||||
}
|
}
|
||||||
Type type = 1;
|
Type type = 1;
|
||||||
string payload = 2;
|
string payload = 2;
|
||||||
|
|||||||
Reference in New Issue
Block a user