Compare commits

...

7 Commits

72 changed files with 2698 additions and 682 deletions

57
client/cmd/logout.go Normal file
View File

@@ -0,0 +1,57 @@
package cmd
import (
"context"
"fmt"
"os/user"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
var logoutCmd = &cobra.Command{
Use: "logout",
Short: "logout from the Netbird Management Service and delete peer",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %v", err)
}
defer conn.Close()
daemonClient := proto.NewDaemonServiceClient(conn)
req := &proto.LogoutRequest{}
if profileName != "" {
req.ProfileName = &profileName
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
username := currUser.Username
req.Username = &username
}
if _, err := daemonClient.Logout(ctx, req); err != nil {
return fmt.Errorf("logout: %v", err)
}
cmd.Println("Logged out successfully")
return nil
},
}
func init() {
logoutCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
}

View File

@@ -3,9 +3,8 @@ package cmd
import (
"context"
"fmt"
"time"
"os/user"
"time"
"github.com/spf13/cobra"
@@ -22,10 +21,11 @@ var profileCmd = &cobra.Command{
}
var profileListCmd = &cobra.Command{
Use: "list",
Short: "list all profiles",
Long: `List all available profiles in the Netbird client.`,
RunE: listProfilesFunc,
Use: "list",
Short: "list all profiles",
Long: `List all available profiles in the Netbird client.`,
Aliases: []string{"ls"},
RunE: listProfilesFunc,
}
var profileAddCmd = &cobra.Command{

View File

@@ -133,6 +133,7 @@ func init() {
rootCmd.AddCommand(downCmd)
rootCmd.AddCommand(statusCmd)
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)

View File

@@ -221,7 +221,7 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -243,7 +243,7 @@ func (t *ICMPTracker) track(
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -294,7 +294,7 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}

View File

@@ -211,7 +211,7 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.logger.Trace2("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
@@ -240,7 +240,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
@@ -262,7 +262,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, p
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
@@ -340,17 +340,17 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, p
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
@@ -438,7 +438,7 @@ func (t *TCPTracker) cleanup() {
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change

View File

@@ -116,7 +116,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.logger.Trace2("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}

View File

@@ -601,7 +601,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
return false
}
@@ -727,13 +727,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
return true
}
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
@@ -741,7 +741,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
@@ -766,7 +766,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
@@ -807,7 +807,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
}
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
m.logger.Error1("Failed to inject local packet: %v", err)
}
// don't process this packet further
@@ -819,7 +819,7 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
@@ -835,7 +835,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
@@ -863,7 +863,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject routed packet: %v", err)
m.logger.Error1("Failed to inject routed packet: %v", err)
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
}
}
@@ -901,7 +901,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace("couldn't decode packet, err: %s", err)
m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false
}

View File

@@ -57,7 +57,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
e.logger.Error1("CreateOutboundPacket: %v", err)
continue
}
written++

View File

@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
}()
@@ -52,11 +52,11 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
@@ -72,7 +72,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
@@ -80,7 +80,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
f.logger.Error1("forwarder: Failed to read ICMP response: %v", err)
}
return 0
}
@@ -101,12 +101,12 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)

View File

@@ -38,7 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
return
}
@@ -47,9 +47,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
f.logger.Error1("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
@@ -61,7 +61,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
@@ -75,10 +75,10 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
@@ -127,7 +127,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}

View File

@@ -78,10 +78,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP conn close error for %v: %v", epID(id), err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
conn.ep.Close()
@@ -112,10 +112,10 @@ func (f *udpForwarder) cleanup() {
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
f.logger.Debug2("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
}
idle.conn.ep.Close()
@@ -124,7 +124,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
@@ -143,7 +143,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return
}
@@ -160,7 +160,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
}
@@ -169,9 +169,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
f.logger.Debug1("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
@@ -194,10 +194,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
@@ -205,7 +205,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
@@ -220,10 +220,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
pConn.cancel()
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
f.logger.Error2("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
f.logger.Error2("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
}
var rxPackets, txPackets uint64
@@ -263,7 +263,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)

View File

@@ -44,7 +44,12 @@ var levelStrings = map[Level]string{
type logMessage struct {
level Level
format string
args []any
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
}
// Logger is a high-performance, non-blocking logger
@@ -89,62 +94,198 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelError, format: format}:
default:
}
}
}
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
func (l *Logger) Warn(format string) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format}:
default:
}
}
}
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
func (l *Logger) Info(format string) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelInfo, format: format}:
default:
}
}
}
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
func (l *Logger) Debug(format string) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format}:
default:
}
}
}
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
func (l *Logger) Trace(format string) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
default:
}
}
}
func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
default:
}
}
}
func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
default:
}
}
}
func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default:
}
}
}
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
}
}
}
}
}
}
*buf = append(*buf, msg...)
var formatted string
switch argCount {
case 0:
formatted = msg.format
case 1:
formatted = fmt.Sprintf(msg.format, msg.arg1)
case 2:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2)
case 3:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3)
case 4:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4)
case 5:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
}
*buf = append(*buf, formatted...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
@@ -157,7 +298,7 @@ func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
l.formatMessage(bufp, msg)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
@@ -249,4 +390,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done:
return nil
}
}
}

View File

@@ -19,22 +19,17 @@ func (d *discard) Write(p []byte) (n int, err error) {
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
@@ -52,7 +47,7 @@ func BenchmarkLogger(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state)
}
})
@@ -62,7 +57,7 @@ func BenchmarkLogger(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
logger.Trace6("Complex trace: proto=%s dir=%s flags=%d seq=%d ack=%d size=%d", protocol, direction, flags, sequence, acknowledged, 1460)
}
})
}
@@ -72,7 +67,6 @@ func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
@@ -82,7 +76,7 @@ func BenchmarkLoggerParallel(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state)
}
})
}
@@ -92,7 +86,6 @@ func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
@@ -102,7 +95,7 @@ func BenchmarkLoggerBurst(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
logger.Trace5("TCP connection %s:%d → %s:%d state %d", srcIP, srcPort, dstIP, dstPort, state)
}
}
}

View File

@@ -211,11 +211,11 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
m.logger.Error1("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
@@ -237,11 +237,11 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
m.logger.Error1("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}

View File

@@ -171,7 +171,7 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
return nil, fmt.Errorf("parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))

View File

@@ -95,7 +95,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
return fmt.Errorf("close remote conn: %w", err)
}
return nil
}

View File

@@ -861,15 +861,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return errors.New("wireguard interface is not initialized")
}
// Cannot update the IP address without restarting the engine because
// the firewall, route manager, and other components cache the old address
if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
err := e.wgInterface.UpdateAddr(conf.Address)
if err != nil {
return err
}
e.config.WgAddr = conf.Address
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
}
if conf.GetSshConfig() != nil {
@@ -880,7 +875,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.config.WgAddr
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn()

View File

@@ -13,7 +13,8 @@ import (
)
const (
defaultProfileName = "default"
DefaultProfileName = "default"
defaultProfileName = DefaultProfileName // Keep for backward compatibility
activeProfileStateFilename = "active_profile.txt"
)

View File

@@ -77,8 +77,8 @@ type ruleParams struct {
func getSetupRules() []ruleParams {
return []ruleParams{
{100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{105, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{105, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}

View File

@@ -4342,6 +4342,94 @@ func (x *GetActiveProfileResponse) GetUsername() string {
return ""
}
type LogoutRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LogoutRequest) Reset() {
*x = LogoutRequest{}
mi := &file_daemon_proto_msgTypes[65]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *LogoutRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*LogoutRequest) ProtoMessage() {}
func (x *LogoutRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[65]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead.
func (*LogoutRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{65}
}
func (x *LogoutRequest) GetProfileName() string {
if x != nil && x.ProfileName != nil {
return *x.ProfileName
}
return ""
}
func (x *LogoutRequest) GetUsername() string {
if x != nil && x.Username != nil {
return *x.Username
}
return ""
}
type LogoutResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LogoutResponse) Reset() {
*x = LogoutResponse{}
mi := &file_daemon_proto_msgTypes[66]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *LogoutResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*LogoutResponse) ProtoMessage() {}
func (x *LogoutResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[66]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead.
func (*LogoutResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{66}
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -4352,7 +4440,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[66]
mi := &file_daemon_proto_msgTypes[68]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4364,7 +4452,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[66]
mi := &file_daemon_proto_msgTypes[68]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4778,7 +4866,13 @@ const file_daemon_proto_rawDesc = "" +
"\x17GetActiveProfileRequest\"X\n" +
"\x18GetActiveProfileResponse\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername*b\n" +
"\busername\x18\x02 \x01(\tR\busername\"t\n" +
"\rLogoutRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\x10\n" +
"\x0eLogoutResponse*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -4787,7 +4881,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\x84\x0f\n" +
"\x05TRACE\x10\a2\xbf\x0f\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -4817,7 +4911,8 @@ const file_daemon_proto_rawDesc = "" +
"AddProfile\x12\x19.daemon.AddProfileRequest\x1a\x1a.daemon.AddProfileResponse\"\x00\x12N\n" +
"\rRemoveProfile\x12\x1c.daemon.RemoveProfileRequest\x1a\x1d.daemon.RemoveProfileResponse\"\x00\x12K\n" +
"\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" +
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -4832,7 +4927,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 68)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 70)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
@@ -4902,18 +4997,20 @@ var file_daemon_proto_goTypes = []any{
(*Profile)(nil), // 65: daemon.Profile
(*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest
(*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse
nil, // 68: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 69: daemon.PortInfo.Range
nil, // 70: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 71: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 72: google.protobuf.Timestamp
(*LogoutRequest)(nil), // 68: daemon.LogoutRequest
(*LogoutResponse)(nil), // 69: daemon.LogoutResponse
nil, // 70: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 71: daemon.PortInfo.Range
nil, // 72: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 73: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 74: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
71, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
73, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
72, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
72, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
71, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
74, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
74, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
73, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -4922,8 +5019,8 @@ var file_daemon_proto_depIdxs = []int32{
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
68, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
69, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
70, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
71, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -4934,10 +5031,10 @@ var file_daemon_proto_depIdxs = []int32{
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
72, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
70, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
74, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
72, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
71, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
73, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -4966,34 +5063,36 @@ var file_daemon_proto_depIdxs = []int32{
61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
5, // 57: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 58: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 59: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 60: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 61: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 62: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 63: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 64: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 65: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 66: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 67: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 68: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 69: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 70: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 71: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 72: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 73: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
50, // 74: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 75: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 76: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 77: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
58, // 78: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
60, // 79: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
62, // 80: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
64, // 81: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
67, // 82: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
57, // [57:83] is the sub-list for method output_type
31, // [31:57] is the sub-list for method input_type
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
5, // 58: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 59: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 60: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 61: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 62: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 63: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 64: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 65: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 66: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 67: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 68: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 69: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 70: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 71: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 72: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 73: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 74: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
50, // 75: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 76: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 77: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 78: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
58, // 79: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
60, // 80: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
62, // 81: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
64, // 82: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
67, // 83: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
69, // 84: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
58, // [58:85] is the sub-list for method output_type
31, // [31:58] is the sub-list for method input_type
31, // [31:31] is the sub-list for extension type_name
31, // [31:31] is the sub-list for extension extendee
0, // [0:31] is the sub-list for field type_name
@@ -5014,13 +5113,14 @@ func file_daemon_proto_init() {
file_daemon_proto_msgTypes[46].OneofWrappers = []any{}
file_daemon_proto_msgTypes[52].OneofWrappers = []any{}
file_daemon_proto_msgTypes[54].OneofWrappers = []any{}
file_daemon_proto_msgTypes[65].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 3,
NumMessages: 68,
NumMessages: 70,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -79,6 +79,9 @@ service DaemonService {
rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {}
rpc GetActiveProfile(GetActiveProfileRequest) returns (GetActiveProfileResponse) {}
// Logout disconnects from the network and deletes the peer from the management server
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
}
@@ -614,4 +617,11 @@ message GetActiveProfileRequest {}
message GetActiveProfileResponse {
string profileName = 1;
string username = 2;
}
}
message LogoutRequest {
optional string profileName = 1;
optional string username = 2;
}
message LogoutResponse {}

View File

@@ -61,6 +61,8 @@ type DaemonServiceClient interface {
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
// Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
}
type daemonServiceClient struct {
@@ -328,6 +330,15 @@ func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiv
return out, nil
}
func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) {
out := new(LogoutResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Logout", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -375,6 +386,8 @@ type DaemonServiceServer interface {
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
// Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -460,6 +473,9 @@ func (UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfi
func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetActiveProfile not implemented")
}
func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -944,6 +960,24 @@ func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Contex
return interceptor(ctx, in, info, handler)
}
func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(LogoutRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).Logout(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/Logout",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Logout(ctx, req.(*LogoutRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1051,6 +1085,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetActiveProfile",
Handler: _DaemonService_GetActiveProfile_Handler,
},
{
MethodName: "Logout",
Handler: _DaemonService_Logout_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -2,6 +2,7 @@ package server
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
@@ -13,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
@@ -24,6 +26,7 @@ import (
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/client/internal"
@@ -47,6 +50,8 @@ const (
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
)
var ErrServiceNotUp = errors.New("service is not up")
// Server for service control.
type Server struct {
rootCtx context.Context
@@ -131,13 +136,7 @@ func (s *Server) Start() error {
return fmt.Errorf("failed to get active profile state: %w", err)
}
cfgPath, err := activeProf.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return fmt.Errorf("failed to get active profile file path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
config, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
@@ -484,13 +483,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
s.mutex.Unlock()
cfgPath, err := activeProf.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
config, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
@@ -701,13 +694,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
cfgPath, err := activeProf.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
config, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
@@ -789,13 +776,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
cfgPath, err := activeProf.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
config, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get default profile config: %v", err)
return nil, fmt.Errorf("failed to get default profile config: %w", err)
@@ -811,26 +792,201 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
s.mutex.Lock()
defer s.mutex.Unlock()
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil {
return nil, fmt.Errorf("service is not up")
}
s.actCancel()
err := s.connectClient.Stop()
if err != nil {
if err := s.cleanupConnection(); err != nil {
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
s.isSessionActive.Store(false)
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil
}
func (s *Server) cleanupConnection() error {
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil {
return ErrServiceNotUp
}
s.actCancel()
if s.connectClient == nil {
return nil
}
if err := s.connectClient.Stop(); err != nil {
return err
}
s.connectClient = nil
s.isSessionActive.Store(false)
log.Infof("service is down")
return &proto.DownResponse{}, nil
return nil
}
func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.ProfileName != nil && *msg.ProfileName != "" {
return s.handleProfileLogout(ctx, msg)
}
return s.handleActiveProfileLogout(ctx)
}
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
return nil, err
}
if msg.Username == nil || *msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
}
username := *msg.Username
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
}
activeProf, _ := s.profileManager.GetActiveProfileState()
if activeProf != nil && activeProf.Name == *msg.ProfileName {
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
}
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
}
return &proto.LogoutResponse{}, nil
}
func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutResponse, error) {
if s.config == nil {
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed to get active profile state: %v", err)
}
config, err := s.getConfig(activeProf)
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "not logged in")
}
s.config = config
}
if err := s.sendLogoutRequest(ctx); err != nil {
log.Errorf("failed to send logout request: %v", err)
return nil, err
}
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
return nil, err
}
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
return &proto.LogoutResponse{}, nil
}
// getConfig loads the config from the active profile
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
if err != nil {
return nil, fmt.Errorf("failed to get config: %w", err)
}
return config, nil
}
func (s *Server) canRemoveProfile(profileName string) error {
if profileName == profilemanager.DefaultProfileName {
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName {
return fmt.Errorf("remove active profile: %s", profileName)
}
return nil
}
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
if s.checkProfilesDisabled() {
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if profileName == "" {
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
if !allowActiveProfile {
if err := s.canRemoveProfile(profileName); err != nil {
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
}
}
return nil
}
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
profileState := &profilemanager.ActiveProfileState{
Name: profileName,
Username: username,
}
profilePath, err := profileState.FilePath()
if err != nil {
return fmt.Errorf("get profile path: %w", err)
}
config, err := profilemanager.GetConfig(profilePath)
if err != nil {
return fmt.Errorf("profile '%s' not found", profileName)
}
return s.sendLogoutRequestWithConfig(ctx, config)
}
func (s *Server) sendLogoutRequest(ctx context.Context) error {
return s.sendLogoutRequestWithConfig(ctx, s.config)
}
func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profilemanager.Config) error {
key, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil {
return fmt.Errorf("parse private key: %w", err)
}
mgmTlsEnabled := config.ManagementURL.Scheme == "https"
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled)
if err != nil {
return fmt.Errorf("connect to management server: %w", err)
}
defer func() {
if err := mgmClient.Close(); err != nil {
log.Errorf("close management client: %v", err)
}
}()
return mgmClient.Logout()
}
// Status returns the daemon status
@@ -1107,12 +1263,12 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if msg.ProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {

View File

@@ -831,6 +831,7 @@ func (s *serviceClient) onTrayReady() {
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false)
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
s.mSettings.AddSeparator()
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
s.loadSettings()

View File

@@ -13,6 +13,7 @@ import (
"fyne.io/systray"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
@@ -231,3 +232,19 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string)
log.Printf("command '%s %s' completed successfully", command, arg)
}
func (h *eventHandler) logout(ctx context.Context) error {
client, err := h.client.getSrvClient(defaultFailTimeout)
if err != nil {
return fmt.Errorf("failed to get service client: %w", err)
}
_, err = client.Logout(ctx, &proto.LogoutRequest{})
if err != nil {
return fmt.Errorf("logout failed: %w", err)
}
h.client.getSrvConfig()
return nil
}

View File

@@ -40,12 +40,13 @@ func (s *serviceClient) showProfilesUI() {
list := widget.NewList(
func() int { return len(profiles) },
func() fyne.CanvasObject {
// Each item: Selected indicator, Name, spacer, Select & Remove buttons
// Each item: Selected indicator, Name, spacer, Select, Logout & Remove buttons
return container.NewHBox(
widget.NewLabel(""), // indicator
widget.NewLabel(""), // profile name
layout.NewSpacer(),
widget.NewButton("Select", nil),
widget.NewButton("Logout", nil),
widget.NewButton("Remove", nil),
)
},
@@ -55,7 +56,8 @@ func (s *serviceClient) showProfilesUI() {
indicator := row.Objects[0].(*widget.Label)
nameLabel := row.Objects[1].(*widget.Label)
selectBtn := row.Objects[3].(*widget.Button)
removeBtn := row.Objects[4].(*widget.Button)
logoutBtn := row.Objects[4].(*widget.Button)
removeBtn := row.Objects[5].(*widget.Button)
profile := profiles[i]
// Show a checkmark if selected
@@ -105,7 +107,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
status, err := conn.Status(context.Background(), &proto.StatusRequest{})
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("failed to get status after switching profile: %v", err)
return
@@ -125,6 +127,12 @@ func (s *serviceClient) showProfilesUI() {
)
}
logoutBtn.Show()
logoutBtn.SetText("Logout")
logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh)
}
// Remove profile
removeBtn.SetText("Remove")
removeBtn.OnTapped = func() {
@@ -135,7 +143,7 @@ func (s *serviceClient) showProfilesUI() {
if !confirm {
return
}
// remove
err = s.removeProfile(profile.Name)
if err != nil {
log.Errorf("failed to remove profile: %v", err)
@@ -230,7 +238,7 @@ func (s *serviceClient) addProfile(profileName string) error {
return fmt.Errorf("get current user: %w", err)
}
_, err = conn.AddProfile(context.Background(), &proto.AddProfileRequest{
_, err = conn.AddProfile(s.ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
@@ -253,7 +261,7 @@ func (s *serviceClient) switchProfile(profileName string) error {
return fmt.Errorf("get current user: %w", err)
}
if _, err := conn.SwitchProfile(context.Background(), &proto.SwitchProfileRequest{
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &currUser.Username,
}); err != nil {
@@ -279,7 +287,7 @@ func (s *serviceClient) removeProfile(profileName string) error {
return fmt.Errorf("get current user: %w", err)
}
_, err = conn.RemoveProfile(context.Background(), &proto.RemoveProfileRequest{
_, err = conn.RemoveProfile(s.ctx, &proto.RemoveProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
@@ -305,7 +313,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
if err != nil {
return nil, fmt.Errorf("get current user: %w", err)
}
profilesResp, err := conn.ListProfiles(context.Background(), &proto.ListProfilesRequest{
profilesResp, err := conn.ListProfiles(s.ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
@@ -324,6 +332,52 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
return profiles, nil
}
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
dialog.ShowConfirm(
"Logout",
fmt.Sprintf("Are you sure you want to logout from '%s'?", profileName),
func(confirm bool) {
if !confirm {
return
}
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("failed to get service client: %v", err)
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
return
}
currUser, err := user.Current()
if err != nil {
log.Errorf("failed to get current user: %v", err)
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
return
}
username := currUser.Username
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profileName,
Username: &username,
})
if err != nil {
log.Errorf("logout failed: %v", err)
dialog.ShowError(fmt.Errorf("logout failed"), s.wProfiles)
return
}
dialog.ShowInformation(
"Logged Out",
fmt.Sprintf("Successfully logged out from '%s'", profileName),
s.wProfiles,
)
refreshCallback()
},
s.wProfiles,
)
}
type subItem struct {
*systray.MenuItem
ctx context.Context
@@ -339,6 +393,7 @@ type profileMenu struct {
emailMenuItem *systray.MenuItem
profileSubItems []*subItem
manageProfilesSubItem *subItem
logoutSubItem *subItem
profilesState []Profile
downClickCallback func() error
upClickCallback func() error
@@ -533,12 +588,11 @@ func (p *profileMenu) refresh() {
for {
select {
case <-ctx.Done():
return // context cancelled
return
case _, ok := <-manageItem.ClickedCh:
if !ok {
return // channel closed
return
}
// Handle manage profiles click
p.eventHandler.runSelfCommand(p.ctx, "profiles", "true")
p.refresh()
p.loadSettingsCallback()
@@ -546,6 +600,30 @@ func (p *profileMenu) refresh() {
}
}()
// Add Logout menu item
ctx2, cancel2 := context.WithCancel(context.Background())
logoutItem := p.profileMenuItem.AddSubMenuItem("Logout", "")
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
go func() {
for {
select {
case <-ctx2.Done():
return
case _, ok := <-logoutItem.ClickedCh:
if !ok {
return
}
if err := p.eventHandler.logout(p.ctx); err != nil {
log.Errorf("logout failed: %v", err)
p.app.SendNotification(fyne.NewNotification("Error", "Failed to logout"))
} else {
p.app.SendNotification(fyne.NewNotification("Success", "Logged out successfully"))
}
}
}
}()
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
p.profileMenuItem.SetTitle(activeProf.ProfileName)
} else {
@@ -556,7 +634,6 @@ func (p *profileMenu) refresh() {
}
func (p *profileMenu) clear(profiles []Profile) {
// Clear existing profile items
for _, item := range p.profileSubItems {
item.Remove()
item.cancel()
@@ -565,11 +642,16 @@ func (p *profileMenu) clear(profiles []Profile) {
p.profilesState = profiles
if p.manageProfilesSubItem != nil {
// Remove the manage profiles item if it exists
p.manageProfilesSubItem.Remove()
p.manageProfilesSubItem.cancel()
p.manageProfilesSubItem = nil
}
if p.logoutSubItem != nil {
p.logoutSubItem.Remove()
p.logoutSubItem.cancel()
p.logoutSubItem = nil
}
}
func (p *profileMenu) updateMenu() {

View File

@@ -1,4 +1,4 @@
FROM ubuntu:24.04
FROM ubuntu:24.10
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]

View File

@@ -22,4 +22,5 @@ type Client interface {
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
Logout() error
}

View File

@@ -497,6 +497,32 @@ func (c *GrpcClient) notifyConnected() {
c.connStateCallback.MarkManagementConnected()
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5)
defer cancel()
message := &proto.Empty{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return fmt.Errorf("encrypt logout message: %w", err)
}
_, err = c.realClient.Logout(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return fmt.Errorf("logout: %w", err)
}
return nil
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil

View File

@@ -19,6 +19,7 @@ type MockClient struct {
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
}
func (m *MockClient) IsHealthy() bool {
@@ -85,3 +86,10 @@ func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
}
return m.SyncMetaFunc(sysInfo)
}
func (m *MockClient) Logout() error {
if m.LogoutFunc == nil {
return nil
}
return m.LogoutFunc()
}

View File

@@ -142,7 +142,7 @@ var (
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("failed to migrate files %v", err)
return fmt.Errorf("migrate files %v", err)
}
if _, err = os.Stat(config.Datadir); os.IsNotExist(err) {
@@ -184,7 +184,7 @@ var (
}
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
return fmt.Errorf("failed to initialize database: %s", err)
return fmt.Errorf("initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key {
@@ -192,7 +192,7 @@ var (
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
return fmt.Errorf("write out store encryption key: %s", err)
}
}
@@ -205,7 +205,7 @@ var (
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
return fmt.Errorf("initialize integrated peer validator: %v", err)
}
permissionsManager := integrations.InitPermissionsManager(store)
@@ -217,7 +217,7 @@ var (
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
return fmt.Errorf("build default manager: %v", err)
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)

View File

@@ -3825,7 +3825,7 @@ var file_management_proto_rawDesc = []byte{
0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54,
0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e,
0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04,
0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67,
0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67,
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05,
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
@@ -3858,8 +3858,12 @@ var file_management_proto_rawDesc = []byte{
0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67,
0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45,
0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -3986,15 +3990,17 @@ var file_management_proto_depIdxs = []int32{
5, // 57: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
5, // 58: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
5, // 59: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
5, // 60: management.ManagementService.Login:output_type -> management.EncryptedMessage
5, // 61: management.ManagementService.Sync:output_type -> management.EncryptedMessage
16, // 62: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
17, // 63: management.ManagementService.isHealthy:output_type -> management.Empty
5, // 64: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
5, // 65: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
17, // 66: management.ManagementService.SyncMeta:output_type -> management.Empty
60, // [60:67] is the sub-list for method output_type
53, // [53:60] is the sub-list for method input_type
5, // 60: management.ManagementService.Logout:input_type -> management.EncryptedMessage
5, // 61: management.ManagementService.Login:output_type -> management.EncryptedMessage
5, // 62: management.ManagementService.Sync:output_type -> management.EncryptedMessage
16, // 63: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
17, // 64: management.ManagementService.isHealthy:output_type -> management.Empty
5, // 65: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
5, // 66: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
17, // 67: management.ManagementService.SyncMeta:output_type -> management.Empty
17, // 68: management.ManagementService.Logout:output_type -> management.Empty
61, // [61:69] is the sub-list for method output_type
53, // [53:61] is the sub-list for method input_type
53, // [53:53] is the sub-list for extension type_name
53, // [53:53] is the sub-list for extension extendee
0, // [0:53] is the sub-list for field type_name

View File

@@ -45,6 +45,9 @@ service ManagementService {
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
// Logout logs out the peer and removes it from the management server
rpc Logout(EncryptedMessage) returns (Empty) {}
}
message EncryptedMessage {

View File

@@ -48,6 +48,8 @@ type ManagementServiceClient interface {
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
}
type managementServiceClient struct {
@@ -144,6 +146,15 @@ func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMes
return out, nil
}
func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/Logout", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
@@ -178,6 +189,8 @@ type ManagementServiceServer interface {
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(context.Context, *EncryptedMessage) (*Empty, error)
mustEmbedUnimplementedManagementServiceServer()
}
@@ -206,6 +219,9 @@ func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Con
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}
func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -348,6 +364,24 @@ func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler)
}
func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).Logout(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/Logout",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).Logout(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -379,6 +413,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SyncMeta",
Handler: _ManagementService_SyncMeta_Handler,
},
{
MethodName: "Logout",
Handler: _ManagementService_Logout_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"reflect"
"regexp"
@@ -324,6 +325,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
updateAccountPeers = true
}
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain {
@@ -362,7 +370,18 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
eventMeta := map[string]any{
"old_dns_domain": oldSettings.DNSDomain,
"new_dns_domain": newSettings.DNSDomain,
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
eventMeta := map[string]any{
"old_network_range": oldSettings.NetworkRange.String(),
"new_network_range": newSettings.NetworkRange.String(),
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
@@ -1368,7 +1387,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
@@ -1382,28 +1401,22 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err)
for _, peer := range peers {
for _, g := range addNewGroups {
if err := transaction.AddPeerToGroup(ctx, userAuth.AccountId, peer.ID, g); err != nil {
return fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, g, err)
}
}
for _, g := range removeOldGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, g); err != nil {
return fmt.Errorf("error removing peer %s from group %s: %w", peer.ID, g, err)
}
}
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
@@ -1971,53 +1984,207 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsToUpdate := make(map[string]*types.Group)
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, fmt.Errorf("error getting account group peers: %w", err)
}
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, fmt.Errorf("error getting account groups: %w", err)
}
for _, group := range accountGroups {
if _, exists := accountGroupPeers[group.ID]; !exists {
accountGroupPeers[group.ID] = make(map[string]struct{})
}
}
updatedGroups := []string{}
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
if err != nil {
return false, false, err
}
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil)
if err != nil {
return false, false, err
}
for _, group := range updatedGroups {
groupsToUpdate[group.ID] = group
groupsMap[group.ID] = group
for _, peer := range userPeers {
for _, groupID := range user.AutoGroups {
if _, exists := accountGroupPeers[groupID]; !exists {
// we do not wanna create the groups here
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
continue
}
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
continue
}
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
}
updatedGroups = append(updatedGroups, groupID)
}
}
}
if len(groupsToUpdate) == 0 {
return false, false, nil
}
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate))
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups)
if err != nil {
return false, false, err
return false, false, fmt.Errorf("error checking if group changes affect peers: %w", err)
}
err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate))
if err != nil {
return false, false, err
}
return true, peersAffected, nil
return len(updatedGroups) > 0, peersAffected, nil
}
// reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes
func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error {
if !newNetworkRange.IsValid() {
return nil
}
newIPNet := net.IPNet{
IP: newNetworkRange.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
}
account, err := transaction.GetAccount(ctx, accountID)
if err != nil {
return err
}
account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return err
}
var takenIPs []net.IP
for _, peer := range peers {
newIP, err := types.AllocatePeerIP(newIPNet, takenIPs)
if err != nil {
return status.Errorf(status.Internal, "allocate IP for peer %s: %v", peer.ID, err)
}
log.WithContext(ctx).Infof("reallocating peer %s IP from %s to %s due to network range change",
peer.ID, peer.IP.String(), newIP.String())
peer.IP = newIP
takenIPs = append(takenIPs, newIP)
}
if err = transaction.SaveAccount(ctx, account); err != nil {
return err
}
for _, peer := range peers {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
}
}
log.WithContext(ctx).Infof("successfully re-allocated IPs for %d peers in account %s to network range %s",
len(peers), accountID, newNetworkRange.String())
return nil
}
func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, peers []*nbpeer.Peer, peerID string, newIP netip.Addr) error {
if !account.Network.Net.Contains(newIP.AsSlice()) {
return status.Errorf(status.InvalidArgument, "IP %s is not within the account network range %s", newIP.String(), account.Network.Net.String())
}
for _, peer := range peers {
if peer.ID != peerID && peer.IP.Equal(newIP.AsSlice()) {
return status.Errorf(status.InvalidArgument, "IP %s is already assigned to peer %s", newIP.String(), peer.ID)
}
}
return nil
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err)
}
if updateNetworkMap {
am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil
}
func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) (bool, error) {
var updateNetworkMap bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
account, err := transaction.GetAccount(ctx, accountID)
if err != nil {
return fmt.Errorf("get account: %w", err)
}
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
if existingPeer.IP.Equal(newIP.AsSlice()) {
return nil
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return fmt.Errorf("get account peers: %w", err)
}
if err := am.validateIPForUpdate(account, peers, peerID, newIP); err != nil {
return err
}
if err := am.savePeerIPUpdate(ctx, transaction, accountID, userID, existingPeer, newIP); err != nil {
return err
}
updateNetworkMap = true
return nil
})
return updateNetworkMap, err
}
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
dnsDomain := am.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String()
peer.IP = newIP.AsSlice()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer)
if err != nil {
return fmt.Errorf("save peer: %w", err)
}
eventMeta["old_ip"] = oldIP
eventMeta["ip"] = newIP.String()
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerIPUpdated, eventMeta)
return nil
}

View File

@@ -51,6 +51,7 @@ type Manager interface {
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
@@ -62,8 +63,10 @@ type Manager interface {
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os"
"reflect"
"strconv"
@@ -1159,7 +1160,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Name: "GroupA",
Peers: []string{},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1194,7 +1195,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}()
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1240,11 +1241,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
AccountID: account.Id,
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1292,7 +1294,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1343,11 +1345,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
require.NoError(t, err, "failed to save group")
@@ -1672,9 +1674,10 @@ func TestAccount_Copy(t *testing.T) {
},
Groups: map[string]*types.Group{
"group1": {
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
},
Policies: []*types.Policy{
@@ -2616,6 +2619,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
}
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -3360,7 +3364,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
@@ -3382,7 +3386,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
@@ -3519,3 +3523,70 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
require.NoError(t, err)
})
}
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
require.NoError(t, err, "unable to add peer2")
t.Run("update peer IP successfully", func(t *testing.T) {
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get account")
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
require.NoError(t, err, "unable to allocate new IP")
newAddr := netip.MustParseAddr(newIP.String())
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
require.NoError(t, err, "unable to update peer IP")
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
require.NoError(t, err, "unable to get updated peer")
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
})
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
currentAddr := netip.MustParseAddr(peer1.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
require.NoError(t, err, "updating with same IP should not error")
})
t.Run("update peer IP with collision should fail", func(t *testing.T) {
peer2Addr := netip.MustParseAddr(peer2.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
require.Error(t, err, "should fail when IP is already assigned")
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
})
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
invalidAddr := netip.MustParseAddr("192.168.1.100")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
require.Error(t, err, "should fail when IP is outside network range")
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
})
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
newAddr := netip.MustParseAddr("100.64.0.101")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
require.Error(t, err, "should fail with invalid peer ID")
})
}

View File

@@ -175,6 +175,9 @@ const (
AccountLazyConnectionEnabled Activity = 85
AccountLazyConnectionDisabled Activity = 86
AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88
AccountDeleted Activity = 99999
)
@@ -277,6 +280,10 @@ var activityMap = map[Activity]Code{
AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"},
AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"},
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
}
// StringCode returns a string code of the activity

View File

@@ -495,7 +495,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
func TestDNSAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -506,7 +506,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Name: "GroupB",
Peers: []string{},
},
}, true)
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -562,11 +562,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
// Creating DNS settings with groups that have peers should update account peers and send peer update
t.Run("creating dns setting with used groups", func(t *testing.T) {
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
done := make(chan struct{})

View File

@@ -43,10 +43,10 @@ func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, us
a.mu.Lock()
defer a.mu.Unlock()
a.deletePeerCalls++
delete(a.store.account.Peers, peerID)
if a.wg != nil {
a.wg.Done()
}
delete(a.store.account.Peers, peerID)
return nil
}

View File

@@ -65,22 +65,144 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error {
// CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
for _, peerID := range newGroup.Peers {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
return nil
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// SaveGroups adds new groups to the account.
// UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
}
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
for _, peerID := range peersToAdd {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
for _, peerID := range peersToRemove {
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
}
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// CreateGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error {
operation := operations.Create
if !create {
operation = operations.Update
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation)
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -116,7 +238,65 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
return err
}
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// UpdateGroups updates groups in the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var groupsToSave []*types.Group
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
})
if err != nil {
return err
@@ -265,20 +445,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -288,7 +458,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID)
})
if err != nil {
return err
@@ -329,7 +499,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
})
if err != nil {
return err
@@ -347,20 +517,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -370,7 +530,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
})
if err != nil {
return err
@@ -411,7 +571,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
})
if err != nil {
return err

View File

@@ -2,14 +2,20 @@ package server
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/groups"
@@ -18,8 +24,10 @@ import (
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -40,7 +48,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
for _, group := range account.Groups {
group.Issued = types.GroupIssuedIntegration
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
group.ID = uuid.New().String()
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration)
}
@@ -48,7 +57,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = types.GroupIssuedJWT
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
group.ID = uuid.New().String()
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", types.GroupIssuedJWT)
}
@@ -56,7 +66,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = types.GroupIssuedAPI
group.ID = ""
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
@@ -162,7 +172,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true)
err = manager.CreateGroups(context.Background(), account.Id, groupAdminUserID, groups)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
@@ -382,13 +392,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
return nil, nil, err
}
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
@@ -400,7 +410,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
func TestGroupAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -426,8 +436,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Name: "GroupE",
Peers: []string{peer2.ID},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
@@ -442,11 +455,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -513,7 +526,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})
// adding a group to policy
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*types.PolicyRule{
{
@@ -535,11 +548,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -604,11 +617,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -645,11 +658,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -672,11 +685,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -719,11 +732,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupE",
Name: "GroupE",
Peers: []string{peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -733,3 +746,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
})
}
func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
acc, err := createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToGroup(context.Background(), accountID, strconv.Itoa(i), acc.GroupsG[0].ID)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
peer := &peer2.Peer{
ID: strconv.Itoa(i),
AccountID: accountID,
DNSLabel: "peer" + strconv.Itoa(i),
IP: uint32ToIP(uint32(i)),
}
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
}
func uint32ToIP(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}
func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
}

View File

@@ -19,7 +19,9 @@ import (
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
@@ -909,6 +911,44 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
return &proto.Empty{}, nil
}
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
empty := &proto.Empty{}
peerKey, err := s.parseRequest(ctx, req, empty)
if err != nil {
return nil, err
}
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String())
if err != nil {
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
// TODO: consider idempotency
return nil, mapError(ctx, err)
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.ID)
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, peer.AccountID)
userID := peer.UserID
if userID == "" {
userID = activity.SystemInitiator
}
if err = s.accountManager.DeletePeer(ctx, peer.AccountID, peer.ID, userID); err != nil {
log.WithContext(ctx).Errorf("failed to logout peer %s: %v", peerKey.String(), err)
return nil, mapError(ctx, err)
}
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
return &proto.Empty{}, nil
}
// toProtocolChecks converts posture checks to protocol checks.
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))

View File

@@ -133,6 +133,11 @@ components:
description: Allows to define a custom dns domain for the account
type: string
example: my-organization.org
network_range:
description: Allows to define a custom network range for the account in CIDR format
type: string
format: cidr
example: 100.64.0.0/16
extra:
$ref: '#/components/schemas/AccountExtraSettings'
lazy_connection_enabled:
@@ -342,6 +347,11 @@ components:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
ip:
description: Peer's IP address
type: string
format: ipv4
example: 100.64.0.15
required:
- name
- ssh_enabled

View File

@@ -303,6 +303,9 @@ type AccountSettings struct {
// LazyConnectionEnabled Enables or disables experimental lazy connection
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
// NetworkRange Allows to define a custom network range for the account in CIDR format
NetworkRange *string `json:"network_range,omitempty"`
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
@@ -1196,11 +1199,14 @@ type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest.
type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
ApprovalRequired *bool `json:"approval_required,omitempty"`
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address
Ip *string `json:"ip,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
}
// PersonalAccessToken defines model for PersonalAccessToken.

View File

@@ -1,8 +1,10 @@
package accounts
import (
"context"
"encoding/json"
"net/http"
"net/netip"
"time"
"github.com/gorilla/mux"
@@ -16,6 +18,17 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
const (
// PeerBufferPercentage is the percentage of peers to add as buffer for network range calculations
PeerBufferPercentage = 0.5
// MinRequiredAddresses is the minimum number of addresses required in a network range
MinRequiredAddresses = 10
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
MinNetworkBitsIPv4 = 28
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
MinNetworkBitsIPv6 = 120
)
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager account.Manager
@@ -37,6 +50,86 @@ func newHandler(accountManager account.Manager, settingsManager settings.Manager
}
}
func validateIPAddress(addr netip.Addr) error {
if addr.IsLoopback() {
return status.Errorf(status.InvalidArgument, "loopback address range not allowed")
}
if addr.IsMulticast() {
return status.Errorf(status.InvalidArgument, "multicast address range not allowed")
}
if addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() {
return status.Errorf(status.InvalidArgument, "link-local address range not allowed")
}
return nil
}
func validateMinimumSize(prefix netip.Prefix) error {
addr := prefix.Addr()
if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 {
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4)
}
if addr.Is6() && prefix.Bits() > MinNetworkBitsIPv6 {
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6)
}
return nil
}
func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error {
if !networkRange.IsValid() {
return nil
}
if err := validateIPAddress(networkRange.Addr()); err != nil {
return err
}
if err := validateMinimumSize(networkRange); err != nil {
return err
}
return h.validateCapacity(ctx, accountID, userID, networkRange)
}
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
if err != nil {
return status.Errorf(status.Internal, "get peer count: %v", err)
}
maxHosts := calculateMaxHosts(prefix)
requiredAddresses := calculateRequiredAddresses(len(peers))
if maxHosts < requiredAddresses {
return status.Errorf(status.InvalidArgument,
"network range too small: need at least %d addresses for %d peers + buffer, but range provides %d",
requiredAddresses, len(peers), maxHosts)
}
return nil
}
func calculateMaxHosts(prefix netip.Prefix) int64 {
availableAddresses := prefix.Addr().BitLen() - prefix.Bits()
maxHosts := int64(1) << availableAddresses
if prefix.Addr().Is4() {
maxHosts -= 2 // network and broadcast addresses
}
return maxHosts
}
func calculateRequiredAddresses(peerCount int) int64 {
requiredAddresses := int64(peerCount) + int64(float64(peerCount)*PeerBufferPercentage)
if requiredAddresses < MinRequiredAddresses {
requiredAddresses = MinRequiredAddresses
}
return requiredAddresses
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
@@ -131,6 +224,18 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.LazyConnectionEnabled != nil {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
}
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return
}
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings.NetworkRange = prefix
}
var onboarding *types.AccountOnboarding
if req.Onboarding != nil {
@@ -208,6 +313,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
DnsDomain: &settings.DNSDomain,
}
if settings.NetworkRange.IsValid() {
networkRangeStr := settings.NetworkRange.String()
apiSettings.NetworkRange = &networkRangeStr
}
apiOnboarding := api.AccountOnboarding{
OnboardingFlowPending: onboarding.OnboardingFlowPending,
SignupFormPending: onboarding.SignupFormPending,

View File

@@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: existingGroup.IntegrationReference,
}
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil {
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
@@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
Issued: types.GroupIssuedAPI,
}
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true)
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
@@ -111,6 +112,19 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
}
}
if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip)
if err != nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return
}
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
util.WriteError(ctx, err, w)
return
}
}
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil {
util.WriteError(ctx, err, w)

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -112,6 +114,15 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
p.Name = update.Name
return p, nil
},
UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
for _, peer := range peers {
if peer.ID == peerID {
peer.IP = net.IP(newIP.AsSlice())
return nil
}
}
return fmt.Errorf("peer not found")
},
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
@@ -450,3 +461,73 @@ func TestGetAccessiblePeers(t *testing.T) {
})
}
}
func TestPeersHandlerUpdatePeerIP(t *testing.T) {
testPeer := &nbpeer.Peer{
ID: testPeerID,
Key: "key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
Name: "test-host@netbird.io",
LoginExpirationEnabled: false,
UserID: regularUser,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host@netbird.io",
Core: "22.04",
},
}
p := initTestMetaData(testPeer)
tt := []struct {
name string
peerID string
requestBody string
callerUserID string
expectedStatus int
expectedIP string
}{
{
name: "update peer IP successfully",
peerID: testPeerID,
requestBody: `{"ip": "100.64.0.100"}`,
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedIP: "100.64.0.100",
},
{
name: "update peer IP with invalid IP",
peerID: testPeerID,
requestBody: `{"ip": "invalid-ip"}`,
callerUserID: adminUser,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
})
rr := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK && tc.expectedIP != "" {
var updatedPeer api.Peer
err := json.Unmarshal(rr.Body.Bytes(), &updatedPeer)
require.NoError(t, err)
assert.Equal(t, tc.expectedIP, updatedPeer.Ip)
}
})
}
}

View File

@@ -39,6 +39,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
return nil
}
if !db.Migrator().HasColumn(&model, fieldName) {
log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
@@ -422,3 +427,62 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
return nil
}
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if !db.Migrator().HasColumn(&model, columnName) {
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" {
continue
}
var data []string
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
return fmt.Errorf("unmarshal json: %w", err)
}
for _, value := range data {
if err := tx.Create(
mapperFunc(row["account_id"].(string), row["id"].(string), value),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
}
}
}
if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
return fmt.Errorf("drop column %s: %w", columnName, err)
}
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
return nil
}

View File

@@ -60,6 +60,7 @@ type MockAccountManager struct {
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
@@ -124,6 +125,34 @@ type MockAccountManager struct {
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(ctx, accountID, userID, group, true)
}
return status.Errorf(codes.Unimplemented, "method CreateGroup is not implemented")
}
func (am *MockAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(ctx, accountID, userID, group, false)
}
return status.Errorf(codes.Unimplemented, "method UpdateGroup is not implemented")
}
func (am *MockAccountManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, true)
}
return status.Errorf(codes.Unimplemented, "method CreateGroups is not implemented")
}
func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, false)
}
return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented")
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
if am.UpdateAccountPeersFunc != nil {
am.UpdateAccountPeersFunc(ctx, accountID)
@@ -455,6 +484,13 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
}
func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
if am.UpdatePeerIPFunc != nil {
return am.UpdatePeerIPFunc(ctx, accountID, userID, peerID, newIP)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented")
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {

View File

@@ -980,18 +980,18 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
}, true)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
})
assert.NoError(t, err)
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -374,12 +374,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return fmt.Errorf("failed to remove peer from groups: %w", err)
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
return err
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
@@ -478,7 +486,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var updateAccountPeers bool
var setupKeyID string
var setupKeyName string
@@ -615,20 +622,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return err
}
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -678,7 +685,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
@@ -1021,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -1518,22 +1525,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
return am.Store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
groupIDs := make([]string, 0, len(groups))
for _, group := range groups {
groupIDs = append(groupIDs, group.ID)
}
return groupIDs, err
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
@@ -1563,17 +1560,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
}
for _, peer := range peers {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peer.ID)
if err != nil {
return nil, fmt.Errorf("failed to get peer groups: %w", err)
}
for _, group := range groups {
group.RemovePeer(peer.ID)
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
if err != nil {
return nil, fmt.Errorf("failed to save group: %w", err)
}
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID)
}
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil {

View File

@@ -310,12 +310,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)
err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true)
err = manager.CreateGroup(context.Background(), account.Id, userID, &group1)
if err != nil {
t.Errorf("expecting group1 to be added, got failure %v", err)
return
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true)
err = manager.CreateGroup(context.Background(), account.Id, userID, &group2)
if err != nil {
t.Errorf("expecting group2 to be added, got failure %v", err)
return
@@ -1475,6 +1475,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
}
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1709,7 +1713,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err)
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1725,8 +1729,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
}, true)
require.NoError(t, err)
}
for _, group := range g {
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
require.NoError(t, err)
}
// create a user with auto groups
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{
@@ -1785,7 +1792,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -2164,7 +2171,6 @@ func Test_IsUniqueConstraintError(t *testing.T) {
}
func Test_AddPeer(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -2176,7 +2182,7 @@ func Test_AddPeer(t *testing.T) {
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
t.Fatalf("error creating account: %v", err)
return
}
@@ -2186,22 +2192,21 @@ func Test_AddPeer(t *testing.T) {
return
}
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
const differentHostnames = 50
const totalPeers = 300
var wg sync.WaitGroup
errs := make(chan error, totalPeers+differentHostnames)
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
hostNameID := i % differentHostnames
go func(i int) {
defer wg.Done()
newPeer := &nbpeer.Peer{
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
AccountID: accountID,
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
}
<-start

View File

@@ -993,7 +993,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
func TestPolicyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1014,8 +1014,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
@@ -1025,6 +1028,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
var policyWithGroupRulesNoPeers *types.Policy
var policyWithDestinationPeersOnly *types.Policy
var policyWithSourceAndDestinationPeers *types.Policy
var err error
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/posture"
@@ -105,10 +105,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
Id: regularUserID,
Role: types.UserRoleUser,
}
peer1 := &peer.Peer{
ID: "peer1",
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
account.Peers["peer1"] = peer1
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
@@ -121,7 +125,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -137,8 +141,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
@@ -156,7 +163,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
postureCheckA, err := manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
require.NoError(t, err)
postureCheckB := &posture.Checks{
@@ -449,14 +456,16 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
AccountID: account.Id,
Peers: []string{"peer1"},
}
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to create groupA")
groupB := &types.Group{
ID: "groupB",
AccountID: account.Id,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupB)
require.NoError(t, err, "failed to create groupB")
postureCheckA := &posture.Checks{
Name: "checkA",
@@ -535,7 +544,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
groupA.Peers = []string{}
err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA)
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)

View File

@@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
Name: "peer1 group",
Peers: []string{peer1ID},
}
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true)
err = am.CreateGroup(context.Background(), account.Id, userID, newGroup)
require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
@@ -1505,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
}
for _, group := range newGroup {
err = am.SaveGroup(context.Background(), accountID, userID, group, true)
err = am.CreateGroup(context.Background(), accountID, userID, group)
if err != nil {
return nil, err
}
@@ -1953,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
account, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1969,8 +1969,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
require.NoError(t, err, "failed to create group %s", group.Name)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() {
@@ -2149,11 +2152,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -2189,11 +2192,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1ID},
}, true)
})
assert.NoError(t, err)
select {

View File

@@ -29,7 +29,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
err = manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
ID: "group_1",
Name: "group_name_1",
@@ -40,7 +40,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
Name: "group_name_2",
Peers: []string{},
},
}, true)
})
if err != nil {
t.Fatal(err)
}
@@ -104,20 +104,20 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
}, true)
})
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
}, true)
})
if err != nil {
t.Fatal(err)
}
@@ -398,11 +398,11 @@ func TestSetupKey_Copy(t *testing.T) {
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
policy := &types.Policy{

View File

@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
for _, group := range account.GroupsG {
group.StoreGroupPeers()
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
@@ -247,7 +251,8 @@ func generateAccountSQLTypes(account *types.Account) {
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
group.AccountID = account.Id
account.GroupsG = append(account.GroupsG, group)
}
for id, route := range account.Routes {
@@ -449,25 +454,56 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
// CreateGroups creates the given list of groups to the database.
func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
return s.db.Transaction(func(tx *gorm.DB) error {
result := tx.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Omit(clause.Associations).
Create(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
return nil
})
}
// UpdateGroups updates the given list of groups to the database.
func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
if len(groups) == 0 {
return nil
}
return nil
return s.db.Transaction(func(tx *gorm.DB) error {
result := tx.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Omit(clause.Associations).
Create(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
return nil
})
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -646,7 +682,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -655,6 +691,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -669,6 +709,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
likePattern := `%"ID":"` + resourceID + `"%`
result := tx.
Preload(clause.Associations).
Where("resources LIKE ?", likePattern).
Find(&groups)
@@ -679,6 +720,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
return nil, result.Error
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -765,6 +810,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
var account types.Account
result := s.db.Model(&account).
Omit("GroupsG").
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, idQueryCondition, accountID)
@@ -814,6 +860,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
}
account.GroupsG = nil
var groupPeers []types.GroupPeer
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
Find(&groupPeers)
for _, groupPeer := range groupPeers {
if group, ok := account.Groups[groupPeer.GroupID]; ok {
group.Peers = append(group.Peers, groupPeer.PeerID)
} else {
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
}
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
@@ -1311,55 +1368,76 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var groupID string
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
Scan(&groupID)
if groupID == "" {
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(&types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
PeerID: peerID,
}).Error
group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
if err != nil {
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
}
return nil
}
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
PeerID: peerID,
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err)
return status.Errorf(status.Internal, "failed to add peer to group")
}
group.Peers = append(group.Peers, peerId)
return nil
}
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to remove peer from group")
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
return status.Errorf(status.Internal, "failed to remove peer from all groups")
}
return nil
@@ -1427,15 +1505,46 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
Where("group_peers.peer_id = ?", peerId).
Preload(clause.Associations).
Find(&groups)
if query.Error != nil {
return nil, query.Error
}
for _, group := range groups {
group.LoadGroupPeers()
}
return groups, nil
}
// GetPeerGroupIDs retrieves all group IDs assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groupIDs []string
query := tx.
Model(&types.GroupPeer{}).
Where("account_id = ? AND peer_id = ?", accountId, peerId).
Pluck("group_id", &groupIDs)
if query.Error != nil {
if errors.Is(query.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no groups found for peer %s in account %s", peerId, accountId)
}
log.WithContext(ctx).Errorf("failed to get group IDs for peer %s in account %s: %v", peerId, accountId, query.Error)
return nil, status.Errorf(status.Internal, "failed to get group IDs for peer from store")
}
return groupIDs, nil
}
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
@@ -1485,7 +1594,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1722,7 +1831,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
@@ -1731,15 +1840,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
group.LoadGroupPeers()
return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group types.Group
@@ -1747,16 +1855,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// we may need to reconsider changing the types.
query := tx.Preload(clause.Associations)
switch s.storeEngine {
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
result := query.
Model(&types.Group{}).
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
Group("groups.id").
Order("COUNT(group_peers.peer_id) DESC").
Limit(1).
First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupName)
@@ -1764,6 +1870,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
}
group.LoadGroupPeers()
return &group, nil
}
@@ -1775,7 +1884,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
}
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
@@ -1783,25 +1892,45 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
groupsMap[group.ID] = group
}
return groupsMap, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
// CreateGroup creates a group in the store.
func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
if err := s.db.Omit(clause.Associations).Create(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
// UpdateGroup updates a group in the store.
func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
@@ -1818,6 +1947,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
@@ -2613,3 +2743,27 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
return count, nil
}
func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []types.GroupPeer
result := tx.Find(&peers, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get account group peers from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account group peers from store")
}
groupPeers := make(map[string]map[string]struct{})
for _, peer := range peers {
if _, exists := groupPeers[peer.GroupID]; !exists {
groupPeers[peer.GroupID] = make(map[string]struct{})
}
groupPeers[peer.GroupID][peer.PeerID] = struct{}{}
}
return groupPeers, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"encoding/binary"
"fmt"
"math/rand"
"net"
@@ -1187,7 +1188,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
err := transaction.CreateGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
@@ -1348,7 +1349,8 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) {
}
}
func TestSqlStore_SaveGroup(t *testing.T) {
func TestSqlStore_CreateGroup(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.MysqlStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -1356,12 +1358,14 @@ func TestSqlStore_SaveGroup(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &types.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
@@ -1369,7 +1373,7 @@ func TestSqlStore_SaveGroup(t *testing.T) {
require.Equal(t, savedGroup, group)
}
func TestSqlStore_SaveGroups(t *testing.T) {
func TestSqlStore_CreateUpdateGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -1378,23 +1382,27 @@ func TestSqlStore_SaveGroups(t *testing.T) {
groups := []*types.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
require.NoError(t, err)
groups[1].Peers = []string{}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
err = store.UpdateGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID)
@@ -2523,7 +2531,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
require.NoError(t, err, "failed to get group")
require.Len(t, group.Peers, 0, "group should have 0 peers")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
err = store.AddPeerToGroup(context.Background(), accountID, peerID, groupID)
require.NoError(t, err, "failed to add peer to group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2554,7 +2562,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
require.NoError(t, err, "failed to add peer to account")
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
require.NoError(t, err, "failed to add peer to all group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2640,7 +2648,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
assert.Len(t, groups, 1)
assert.Equal(t, groups[0].Name, "All")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
err = store.AddPeerToGroup(context.Background(), accountID, peerID, "cfefqs706sqkneg59g4h")
require.NoError(t, err)
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)
@@ -3307,7 +3315,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
})
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
require.NoError(t, err)
accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
@@ -3538,3 +3546,64 @@ func TestSqlStore_GetAnyAccountID(t *testing.T) {
assert.Empty(t, accountID)
})
}
func BenchmarkGetAccountPeers(b *testing.B) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", b.TempDir())
if err != nil {
b.Fatal(err)
}
b.Cleanup(cleanup)
numberOfPeers := 1000
numberOfGroups := 200
numberOfPeersPerGroup := 500
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peers := make([]*nbpeer.Peer, 0, numberOfPeers)
for i := 0; i < numberOfPeers; i++ {
peer := &nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
AccountID: accountID,
DNSLabel: fmt.Sprintf("peer%d.example.com", i),
IP: intToIPv4(uint32(i)),
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
if err != nil {
b.Fatalf("Failed to add peer: %v", err)
}
peers = append(peers, peer)
}
for i := 0; i < numberOfGroups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{
ID: groupID,
AccountID: accountID,
}
err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
b.Fatalf("Failed to create group: %v", err)
}
for j := 0; j < numberOfPeersPerGroup; j++ {
peerIndex := (i*numberOfPeersPerGroup + j) % numberOfPeers
err = store.AddPeerToGroup(context.Background(), accountID, peers[peerIndex].ID, groupID)
if err != nil {
b.Fatalf("Failed to add peer to group: %v", err)
}
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peers[i%numberOfPeers].ID)
if err != nil {
b.Fatal(err)
}
}
}
func intToIPv4(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}

View File

@@ -101,8 +101,10 @@ type Store interface {
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
@@ -120,9 +122,12 @@ type Store interface {
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
@@ -196,6 +201,7 @@ type Store interface {
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
}
const (
@@ -353,6 +359,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(accountID, id, value string) any {
return &types.GroupPeer{
AccountID: accountID,
GroupID: id,
PeerID: value,
}
})
},
}
}

View File

@@ -73,7 +73,7 @@ type Account struct {
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`

View File

@@ -53,7 +53,7 @@ type Config struct {
StoreConfig StoreConfig
ReverseProxy ReverseProxy
// disable default all-to-all policy
DisableDefaultPolicy bool
}

View File

@@ -26,7 +26,8 @@ type Group struct {
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
// Resources contains a list of resources in that group
Resources []Resource `gorm:"serializer:json"`
@@ -34,6 +35,32 @@ type Group struct {
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
type GroupPeer struct {
AccountID string `gorm:"index"`
GroupID string `gorm:"primaryKey"`
PeerID string `gorm:"primaryKey"`
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
g.Peers[i] = peer.PeerID
}
g.GroupPeers = []GroupPeer{}
}
func (g *Group) StoreGroupPeers() {
g.GroupPeers = make([]GroupPeer, len(g.Peers))
for i, peer := range g.Peers {
g.GroupPeers[i] = GroupPeer{
AccountID: g.AccountID,
GroupID: g.ID,
PeerID: peer,
}
}
g.Peers = []string{}
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
@@ -46,13 +73,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
AccountID: g.AccountID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
Resources: make([]Resource, len(g.Resources)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
copy(group.GroupPeers, g.GroupPeers)
copy(group.Resources, g.Resources)
return group
}

View File

@@ -163,7 +163,10 @@ func (n *Network) Copy() *Network {
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
totalIPs := uint32(1 << SubnetSize)
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
totalIPs := uint32(1 << hostBits)
taken := make(map[uint32]struct{}, len(takenIps)+1)
taken[baseIP] = struct{}{} // reserve network IP

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewNetwork(t *testing.T) {
@@ -38,6 +39,107 @@ func TestAllocatePeerIP(t *testing.T) {
}
}
func TestAllocatePeerIPSmallSubnet(t *testing.T) {
// Test /27 network (10.0.0.0/27) - should only have 30 usable IPs (10.0.0.1 to 10.0.0.30)
ipNet := net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.IPMask{255, 255, 255, 224}}
var ips []net.IP
// Allocate all available IPs in the /27 network
for i := 0; i < 30; i++ {
ip, err := AllocatePeerIP(ipNet, ips)
if err != nil {
t.Fatal(err)
}
// Verify IP is within the correct range
if !ipNet.Contains(ip) {
t.Errorf("allocated IP %s is not within network %s", ip.String(), ipNet.String())
}
ips = append(ips, ip)
}
assert.Len(t, ips, 30)
// Verify all IPs are unique
uniq := make(map[string]struct{})
for _, ip := range ips {
if _, ok := uniq[ip.String()]; !ok {
uniq[ip.String()] = struct{}{}
} else {
t.Errorf("found duplicate IP %s", ip.String())
}
}
// Try to allocate one more IP - should fail as network is full
_, err := AllocatePeerIP(ipNet, ips)
if err == nil {
t.Error("expected error when network is full, but got none")
}
}
func TestAllocatePeerIPVariousCIDRs(t *testing.T) {
testCases := []struct {
name string
cidr string
expectedUsable int
}{
{"/30 network", "192.168.1.0/30", 2}, // 4 total - 2 reserved = 2 usable
{"/29 network", "192.168.1.0/29", 6}, // 8 total - 2 reserved = 6 usable
{"/28 network", "192.168.1.0/28", 14}, // 16 total - 2 reserved = 14 usable
{"/27 network", "192.168.1.0/27", 30}, // 32 total - 2 reserved = 30 usable
{"/26 network", "192.168.1.0/26", 62}, // 64 total - 2 reserved = 62 usable
{"/25 network", "192.168.1.0/25", 126}, // 128 total - 2 reserved = 126 usable
{"/16 network", "10.0.0.0/16", 65534}, // 65536 total - 2 reserved = 65534 usable
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, ipNet, err := net.ParseCIDR(tc.cidr)
require.NoError(t, err)
var ips []net.IP
// For larger networks, test only a subset to avoid long test runs
testCount := tc.expectedUsable
if testCount > 1000 {
testCount = 1000
}
// Allocate IPs and verify they're within the correct range
for i := 0; i < testCount; i++ {
ip, err := AllocatePeerIP(*ipNet, ips)
require.NoError(t, err, "failed to allocate IP %d", i)
// Verify IP is within the correct range
assert.True(t, ipNet.Contains(ip), "allocated IP %s is not within network %s", ip.String(), ipNet.String())
// Verify IP is not network or broadcast address
networkIP := ipNet.IP.Mask(ipNet.Mask)
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
broadcastInt := uint32(ipToUint32(networkIP)) + (1 << hostBits) - 1
broadcastIP := uint32ToIP(broadcastInt)
assert.False(t, ip.Equal(networkIP), "allocated network address %s", ip.String())
assert.False(t, ip.Equal(broadcastIP), "allocated broadcast address %s", ip.String())
ips = append(ips, ip)
}
assert.Len(t, ips, testCount)
// Verify all IPs are unique
uniq := make(map[string]struct{})
for _, ip := range ips {
ipStr := ip.String()
assert.NotContains(t, uniq, ipStr, "found duplicate IP %s", ipStr)
uniq[ipStr] = struct{}{}
}
})
}
}
func TestGenerateIPs(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}})

View File

@@ -1,6 +1,7 @@
package types
import (
"net/netip"
"time"
)
@@ -42,6 +43,9 @@ type Settings struct {
// DNSDomain is the custom domain for that account
DNSDomain string
// NetworkRange is the custom network range for that account
NetworkRange netip.Prefix `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
@@ -66,6 +70,7 @@ func (s *Settings) Copy() *Settings {
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -35,7 +35,7 @@ type SetupKey struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
KeySecret string `gorm:"index"`
Name string
Type SetupKeyType
CreatedAt time.Time

View File

@@ -677,13 +677,18 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
if err != nil {
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, updatedGroups); err != nil {
return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err)
addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups)
for _, peer := range userPeers {
for _, groupID := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
}
}
for _, groupID := range addedGroups {
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
}
}
}
}
@@ -1137,93 +1142,6 @@ func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID str
return userInfo, nil
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
if changed := addUserPeersToGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group)
}
}
for _, gid := range groupsToRemove {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group)
}
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
changed := false
for pid := range userPeerIDs {
if _, exists := groupPeers[pid]; !exists {
groupPeers[pid] = struct{}{}
changed = true
}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
if changed {
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
return changed
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
// skip removing peers from group All
if group.Name == "All" {
return false
}
updatedPeers := make([]string, 0, len(group.Peers))
changed := false
for _, pid := range group.Peers {
if _, owned := userPeerIDs[pid]; owned {
changed = true
continue
}
updatedPeers = append(updatedPeers, pid)
}
if changed {
group.Peers = updatedPeers
}
return changed
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {

View File

@@ -1335,11 +1335,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
// account groups propagation is enabled
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
require.NoError(t, err)
policy := &types.Policy{

View File

@@ -23,7 +23,6 @@ func FileExists(path string) bool {
return err == nil
}
/// Bool helpers
// True returns a *bool whose underlying value is true.
@@ -56,4 +55,4 @@ func ReturnBoolWithDefaultTrue(b *bool) bool {
return true
}
}
}

View File

@@ -6,7 +6,7 @@ import (
"time"
)
//Duration is used strictly for JSON requests/responses due to duration marshalling issues
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}