mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-10 18:59:55 +00:00
Compare commits
1 Commits
fix/dex-co
...
dex-nocgo-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e187a11a0 |
@@ -713,10 +713,8 @@ checksum:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
|
||||
@@ -85,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
@@ -98,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
layerTypeAll = 255
|
||||
layerTypeAll = 0
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
@@ -262,7 +262,10 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix := iface.Address().Network
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
rule, err := m.addRouteFiltering(
|
||||
@@ -436,7 +439,19 @@ func (m *Manager) AddPeerFiltering(
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case firewall.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
@@ -481,17 +496,16 @@ func (m *Manager) addRouteFiltering(
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
if destination.IsPrefix() {
|
||||
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||
@@ -931,7 +945,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
|
||||
if blocked {
|
||||
pnum := getProtocolFromPacket(d)
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
@@ -996,22 +1010,20 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return false
|
||||
}
|
||||
|
||||
protoLayer := d.decoded[1]
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: proto,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
@@ -1040,33 +1052,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
return layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
return layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
if ipLayer == layers.LayerTypeIPv6 {
|
||||
return layers.LayerTypeICMPv6
|
||||
}
|
||||
return layers.LayerTypeICMPv4
|
||||
case firewall.ProtocolALL:
|
||||
return layerTypeAll
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return nftypes.TCP
|
||||
return firewall.ProtocolTCP, nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return nftypes.UDP
|
||||
return firewall.ProtocolUDP, nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMP
|
||||
return firewall.ProtocolICMP, nftypes.ICMP
|
||||
default:
|
||||
return nftypes.ProtocolUnknown
|
||||
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1238,30 +1233,19 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
return false
|
||||
}
|
||||
|
||||
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
destMatched := false
|
||||
for _, dst := range rule.destinations {
|
||||
if dst.Contains(dstAddr) {
|
||||
@@ -1280,8 +1264,21 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
return sourceMatched
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
|
||||
@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
for _, tc := range cases {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -17,7 +16,7 @@ type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu atomic.Uint32
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -29,7 +28,7 @@ func (e *endpoint) IsAttached() bool {
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu.Load()
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
@@ -83,22 +82,6 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *endpoint) Close() {
|
||||
// Endpoint cleanup - nothing to do as device is managed externally
|
||||
}
|
||||
|
||||
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
|
||||
// Link address is not used for this endpoint type
|
||||
}
|
||||
|
||||
func (e *endpoint) SetMTU(mtu uint32) {
|
||||
e.mtu.Store(mtu)
|
||||
}
|
||||
|
||||
func (e *endpoint) SetOnCloseAction(func()) {
|
||||
// No action needed on close
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -36,16 +35,14 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -63,8 +60,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
endpoint.mtu.Store(uint32(mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
@@ -106,16 +103,15 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -133,8 +129,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
@@ -204,24 +198,3 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,8 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,95 +14,30 @@ import (
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
// For Echo Requests, send and wait for response
|
||||
if icmpHdr.Type() == header.ICMPv4Echo {
|
||||
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
|
||||
if !f.hasRawICMPAccess {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
|
||||
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
|
||||
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
@@ -113,22 +45,38 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
}
|
||||
}()
|
||||
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu.Load())
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
@@ -137,7 +85,31 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
return 0
|
||||
}
|
||||
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
@@ -180,95 +152,3 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
|
||||
// This is used as a fallback when raw socket access is not available.
|
||||
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
|
||||
icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
|
||||
// Returns the size of the injected packet.
|
||||
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyICMPHdr := header.ICMPv4(replyICMP)
|
||||
replyICMPHdr.SetType(header.ICMPv4EchoReply)
|
||||
replyICMPHdr.SetChecksum(0)
|
||||
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
|
||||
|
||||
return f.injectICMPReply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
|
||||
// Returns the total size of the injected packet, or 0 if injection failed.
|
||||
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -132,10 +131,10 @@ func (f *udpForwarder) cleanup() {
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
@@ -145,7 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
@@ -163,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
@@ -174,10 +173,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(&wq, ep)
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
@@ -200,7 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return true
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
@@ -209,7 +208,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
@@ -350,7 +348,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
|
||||
@@ -168,15 +168,6 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
|
||||
@@ -34,7 +34,7 @@ type RouteRule struct {
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
protoLayer gopacket.LayerType
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
|
||||
@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
protoLayer := d.decoded[1]
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
|
||||
@@ -27,23 +27,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
||||
}
|
||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
buf := bufs[0]
|
||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = size
|
||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
||||
eps[0] = stdEp
|
||||
return 1, nil
|
||||
}
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
|
||||
6
go.mod
6
go.mod
@@ -42,7 +42,7 @@ require (
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||
github.com/dexidp/dex/api/v2 v2.4.0
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
@@ -122,7 +122,7 @@ require (
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/gorm v1.25.12
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -285,7 +285,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
12
go.sum
12
go.sum
@@ -118,8 +118,8 @@ github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmr
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
|
||||
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -407,8 +407,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
@@ -843,5 +843,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package dex
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
// LogrusHandler is an slog.Handler that delegates to logrus.
|
||||
// This allows Dex to use the same log format as the rest of NetBird.
|
||||
type LogrusHandler struct {
|
||||
logger *logrus.Logger
|
||||
attrs []slog.Attr
|
||||
groups []string
|
||||
}
|
||||
|
||||
// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
|
||||
func NewLogrusHandler(level slog.Level) *LogrusHandler {
|
||||
logger := logrus.New()
|
||||
formatter.SetTextFormatter(logger)
|
||||
|
||||
// Map slog level to logrus level
|
||||
switch level {
|
||||
case slog.LevelDebug:
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
case slog.LevelInfo:
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
case slog.LevelWarn:
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
case slog.LevelError:
|
||||
logger.SetLevel(logrus.ErrorLevel)
|
||||
default:
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
}
|
||||
|
||||
return &LogrusHandler{logger: logger}
|
||||
}
|
||||
|
||||
// Enabled reports whether the handler handles records at the given level.
|
||||
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
switch level {
|
||||
case slog.LevelDebug:
|
||||
return h.logger.IsLevelEnabled(logrus.DebugLevel)
|
||||
case slog.LevelInfo:
|
||||
return h.logger.IsLevelEnabled(logrus.InfoLevel)
|
||||
case slog.LevelWarn:
|
||||
return h.logger.IsLevelEnabled(logrus.WarnLevel)
|
||||
case slog.LevelError:
|
||||
return h.logger.IsLevelEnabled(logrus.ErrorLevel)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Handle handles the Record.
|
||||
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
fields := make(logrus.Fields)
|
||||
|
||||
// Add pre-set attributes
|
||||
for _, attr := range h.attrs {
|
||||
fields[attr.Key] = attr.Value.Any()
|
||||
}
|
||||
|
||||
// Add record attributes
|
||||
r.Attrs(func(attr slog.Attr) bool {
|
||||
fields[attr.Key] = attr.Value.Any()
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.WithFields(fields)
|
||||
|
||||
switch r.Level {
|
||||
case slog.LevelDebug:
|
||||
entry.Debug(r.Message)
|
||||
case slog.LevelInfo:
|
||||
entry.Info(r.Message)
|
||||
case slog.LevelWarn:
|
||||
entry.Warn(r.Message)
|
||||
case slog.LevelError:
|
||||
entry.Error(r.Message)
|
||||
default:
|
||||
entry.Info(r.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithAttrs returns a new Handler with the given attributes added.
|
||||
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
|
||||
copy(newAttrs, h.attrs)
|
||||
copy(newAttrs[len(h.attrs):], attrs)
|
||||
return &LogrusHandler{
|
||||
logger: h.logger,
|
||||
attrs: newAttrs,
|
||||
groups: h.groups,
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroup returns a new Handler with the given group appended to the receiver's groups.
|
||||
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
|
||||
newGroups := make([]string, len(h.groups)+1)
|
||||
copy(newGroups, h.groups)
|
||||
newGroups[len(h.groups)] = name
|
||||
return &LogrusHandler{
|
||||
logger: h.logger,
|
||||
attrs: h.attrs,
|
||||
groups: newGroups,
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
// Package dex provides an embedded Dex OIDC identity provider.
|
||||
package dex
|
||||
|
||||
@@ -130,21 +132,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
|
||||
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
|
||||
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
|
||||
logLevel := slog.LevelWarn
|
||||
if yamlConfig.Logger.Level != "" {
|
||||
switch strings.ToLower(yamlConfig.Logger.Level) {
|
||||
case "debug":
|
||||
logLevel = slog.LevelDebug
|
||||
case "info":
|
||||
logLevel = slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
logLevel = slog.LevelWarn
|
||||
case "error":
|
||||
logLevel = slog.LevelError
|
||||
}
|
||||
}
|
||||
logger := slog.New(NewLogrusHandler(logLevel))
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
stor, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
@@ -792,12 +780,11 @@ func (p *Provider) resolveRedirectURI(redirectURI string) string {
|
||||
// buildOIDCConnectorConfig creates config for OIDC-based connectors
|
||||
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
oidcConfig := map[string]interface{}{
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"insecureEnableGroups": true,
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
}
|
||||
switch cfg.Type {
|
||||
case "zitadel":
|
||||
@@ -807,9 +794,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
case "okta":
|
||||
oidcConfig["insecureSkipEmailVerified"] = true
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package dex
|
||||
|
||||
import (
|
||||
|
||||
222
idp/dex/stub.go
Normal file
222
idp/dex/stub.go
Normal file
@@ -0,0 +1,222 @@
|
||||
//go:build !cgo
|
||||
|
||||
// Package dex provides an embedded Dex OIDC identity provider.
|
||||
// This stub exists for non-CGO builds where SQLite is unavailable.
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
var errNoCGO = errors.New("embedded IdP requires CGO (SQLite)")
|
||||
|
||||
// Config for simple provider creation
|
||||
type Config struct {
|
||||
Issuer string
|
||||
Port int
|
||||
DataDir string
|
||||
DevMode bool
|
||||
GRPCAddr string
|
||||
}
|
||||
|
||||
// Provider wraps a Dex server
|
||||
type Provider struct{}
|
||||
|
||||
// NewProvider creates a new Provider
|
||||
func NewProvider(_ context.Context, _ *Config) (*Provider, error) { return nil, errNoCGO }
|
||||
|
||||
// NewProviderFromYAML creates a Provider from YAML config
|
||||
func NewProviderFromYAML(_ context.Context, _ *YAMLConfig) (*Provider, error) { return nil, errNoCGO }
|
||||
|
||||
// Start starts the server
|
||||
func (p *Provider) Start(_ context.Context) error { return errNoCGO }
|
||||
|
||||
// Stop stops the server
|
||||
func (p *Provider) Stop(_ context.Context) error { return errNoCGO }
|
||||
|
||||
// EnsureDefaultClients ensures default clients exist
|
||||
func (p *Provider) EnsureDefaultClients(_ context.Context, _, _ []string) error { return errNoCGO }
|
||||
|
||||
// Storage returns the storage
|
||||
func (p *Provider) Storage() storage.Storage { return nil }
|
||||
|
||||
// Handler returns the HTTP handler
|
||||
func (p *Provider) Handler() http.Handler { return nil }
|
||||
|
||||
// CreateUser creates a user
|
||||
func (p *Provider) CreateUser(_ context.Context, _, _, _ string) (string, error) {
|
||||
return "", errNoCGO
|
||||
}
|
||||
|
||||
// GetUser gets a user
|
||||
func (p *Provider) GetUser(_ context.Context, _ string) (storage.Password, error) {
|
||||
return storage.Password{}, errNoCGO
|
||||
}
|
||||
|
||||
// GetUserByID gets a user by ID
|
||||
func (p *Provider) GetUserByID(_ context.Context, _ string) (storage.Password, error) {
|
||||
return storage.Password{}, errNoCGO
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (p *Provider) DeleteUser(_ context.Context, _ string) error { return errNoCGO }
|
||||
|
||||
// ListUsers lists users
|
||||
func (p *Provider) ListUsers(_ context.Context) ([]storage.Password, error) { return nil, errNoCGO }
|
||||
|
||||
// GetRedirectURI returns the redirect URI
|
||||
func (p *Provider) GetRedirectURI() string { return "" }
|
||||
|
||||
// GetIssuer returns the issuer
|
||||
func (p *Provider) GetIssuer() string { return "" }
|
||||
|
||||
// GetTokenEndpoint returns the token endpoint
|
||||
func (p *Provider) GetTokenEndpoint() string { return "" }
|
||||
|
||||
// GetDeviceAuthEndpoint returns the device auth endpoint
|
||||
func (p *Provider) GetDeviceAuthEndpoint() string { return "" }
|
||||
|
||||
// GetAuthorizationEndpoint returns the auth endpoint
|
||||
func (p *Provider) GetAuthorizationEndpoint() string { return "" }
|
||||
|
||||
// GetKeysLocation returns the keys location
|
||||
func (p *Provider) GetKeysLocation() string { return "" }
|
||||
|
||||
// ConnectorConfig for identity provider connectors
|
||||
type ConnectorConfig struct {
|
||||
ID, Name, Type, Issuer, ClientID, ClientSecret string
|
||||
Scopes []string
|
||||
UserIDKey, UserNameKey, EmailKey string
|
||||
InsecureSkipVerify bool
|
||||
AuthorizationURL, TokenURL, UserInfoURL string
|
||||
IdentityProviderType string
|
||||
}
|
||||
|
||||
// CreateConnector creates a connector
|
||||
func (p *Provider) CreateConnector(_ context.Context, _ *ConnectorConfig) (*ConnectorConfig, error) {
|
||||
return nil, errNoCGO
|
||||
}
|
||||
|
||||
// GetConnector gets a connector
|
||||
func (p *Provider) GetConnector(_ context.Context, _ string) (*ConnectorConfig, error) {
|
||||
return nil, errNoCGO
|
||||
}
|
||||
|
||||
// ListConnectors lists connectors
|
||||
func (p *Provider) ListConnectors(_ context.Context) ([]*ConnectorConfig, error) { return nil, errNoCGO }
|
||||
|
||||
// UpdateConnector updates a connector
|
||||
func (p *Provider) UpdateConnector(_ context.Context, _ *ConnectorConfig) error { return errNoCGO }
|
||||
|
||||
// DeleteConnector deletes a connector
|
||||
func (p *Provider) DeleteConnector(_ context.Context, _ string) error { return errNoCGO }
|
||||
|
||||
// EncodeDexUserID encodes a user ID
|
||||
func EncodeDexUserID(_, _ string) string { return "" }
|
||||
|
||||
// DecodeDexUserID decodes a user ID
|
||||
func DecodeDexUserID(_ string) (string, string, error) { return "", "", errNoCGO }
|
||||
|
||||
// YAMLConfig for YAML-based configuration
|
||||
type YAMLConfig struct {
|
||||
Issuer string `yaml:"issuer" json:"issuer"`
|
||||
Storage Storage `yaml:"storage" json:"storage"`
|
||||
Web Web `yaml:"web" json:"web"`
|
||||
GRPC GRPC `yaml:"grpc" json:"grpc"`
|
||||
OAuth2 OAuth2 `yaml:"oauth2" json:"oauth2"`
|
||||
Expiry Expiry `yaml:"expiry" json:"expiry"`
|
||||
Logger Logger `yaml:"logger" json:"logger"`
|
||||
Frontend Frontend `yaml:"frontend" json:"frontend"`
|
||||
StaticConnectors []Connector `yaml:"connectors" json:"connectors"`
|
||||
StaticClients []storage.Client `yaml:"staticClients" json:"staticClients"`
|
||||
EnablePasswordDB bool `yaml:"enablePasswordDB" json:"enablePasswordDB"`
|
||||
StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"`
|
||||
}
|
||||
|
||||
// Validate validates config
|
||||
func (c *YAMLConfig) Validate() error { return errNoCGO }
|
||||
|
||||
// ToServerConfig converts to server config
|
||||
func (c *YAMLConfig) ToServerConfig(_ storage.Storage, _ *slog.Logger) server.Config {
|
||||
return server.Config{}
|
||||
}
|
||||
|
||||
// GetRefreshTokenPolicy gets refresh policy
|
||||
func (c *YAMLConfig) GetRefreshTokenPolicy(_ *slog.Logger) (*server.RefreshTokenPolicy, error) {
|
||||
return nil, errNoCGO
|
||||
}
|
||||
|
||||
// LoadConfig loads config from file
|
||||
func LoadConfig(_ string) (*YAMLConfig, error) { return nil, errNoCGO }
|
||||
|
||||
// Web config
|
||||
type Web struct {
|
||||
HTTP, HTTPS string
|
||||
AllowedOrigins []string
|
||||
AllowedHeaders []string
|
||||
}
|
||||
|
||||
// GRPC config
|
||||
type GRPC struct{ Addr, TLSCert, TLSKey, TLSClientCA string }
|
||||
|
||||
// OAuth2 config
|
||||
type OAuth2 struct {
|
||||
SkipApprovalScreen, AlwaysShowLoginScreen bool
|
||||
PasswordConnector string
|
||||
ResponseTypes, GrantTypes []string
|
||||
}
|
||||
|
||||
// Expiry config
|
||||
type Expiry struct {
|
||||
SigningKeys, IDTokens, AuthRequests, DeviceRequests string
|
||||
RefreshTokens RefreshTokensExpiry
|
||||
}
|
||||
|
||||
// RefreshTokensExpiry config
|
||||
type RefreshTokensExpiry struct {
|
||||
ReuseInterval, ValidIfNotUsedFor, AbsoluteLifetime string
|
||||
DisableRotation bool
|
||||
}
|
||||
|
||||
// Logger config
|
||||
type Logger struct{ Level, Format string }
|
||||
|
||||
// Frontend config
|
||||
type Frontend struct {
|
||||
Dir, Theme, Issuer, LogoURL string
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
// Storage config
|
||||
type Storage struct {
|
||||
Type string
|
||||
Config map[string]interface{}
|
||||
}
|
||||
|
||||
// OpenStorage opens storage
|
||||
func (s *Storage) OpenStorage(_ *slog.Logger) (storage.Storage, error) { return nil, errNoCGO }
|
||||
|
||||
// Password type
|
||||
type Password storage.Password
|
||||
|
||||
// Connector config
|
||||
type Connector struct {
|
||||
Type, Name, ID string
|
||||
Config map[string]interface{}
|
||||
}
|
||||
|
||||
// ToStorageConnector converts to storage connector
|
||||
func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
||||
return storage.Connector{}, errNoCGO
|
||||
}
|
||||
|
||||
// StorageConfig interface
|
||||
type StorageConfig interface {
|
||||
Open(logger *slog.Logger) (storage.Storage, error)
|
||||
}
|
||||
@@ -169,7 +169,8 @@ init_environment() {
|
||||
|
||||
render_caddyfile() {
|
||||
cat <<EOF
|
||||
{
|
||||
{
|
||||
debug
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
@@ -270,7 +271,7 @@ AUTH_CLIENT_ID=netbird-dashboard
|
||||
AUTH_CLIENT_SECRET=
|
||||
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||
USE_AUTH0=false
|
||||
AUTH_SUPPORTED_SCOPES=openid profile email groups
|
||||
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
|
||||
AUTH_REDIRECT_URI=/nb-auth
|
||||
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
# SSL
|
||||
|
||||
@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
|
||||
applyCommandLineOverrides(loadedConfig)
|
||||
|
||||
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
||||
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
|
||||
err := applyEmbeddedIdPConfig(loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
|
||||
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
||||
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
@@ -190,8 +190,10 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
|
||||
userDeleteFromIDPEnabled = true
|
||||
|
||||
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
|
||||
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
|
||||
// Ensure HttpConfig exists
|
||||
if cfg.HttpConfig == nil {
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
}
|
||||
|
||||
// Set storage defaults based on Datadir
|
||||
if cfg.EmbeddedIdP.Storage.Type == "" {
|
||||
@@ -203,22 +205,35 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
if cfg.HttpConfig != nil {
|
||||
log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
|
||||
"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
|
||||
} else {
|
||||
// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
// Set AuthIssuer from EmbeddedIdP issuer
|
||||
if cfg.HttpConfig.AuthIssuer == "" {
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
}
|
||||
|
||||
// Set HttpConfig values from EmbeddedIdP
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
// Set AuthAudience to the dashboard client ID
|
||||
if cfg.HttpConfig.AuthAudience == "" {
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
}
|
||||
|
||||
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
|
||||
if cfg.HttpConfig.AuthUserIDClaim == "" {
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
}
|
||||
|
||||
// Set AuthKeysLocation to the JWKS endpoint
|
||||
if cfg.HttpConfig.AuthKeysLocation == "" {
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
}
|
||||
|
||||
// Set OIDCConfigEndpoint to the discovery endpoint
|
||||
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
}
|
||||
|
||||
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
|
||||
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -226,12 +241,7 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
|
||||
// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
|
||||
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
account = c.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -414,7 +414,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
account = c.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -475,7 +475,7 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
customZone nbdns.CustomZone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
account := c.getAccountFromHolderOrInit(accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
@@ -547,12 +547,12 @@ func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -102,9 +102,6 @@ type HttpServerConfig struct {
|
||||
CertKey string
|
||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||
AuthAudience string
|
||||
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
|
||||
// Used only in conjunction with EmbeddedIdP
|
||||
CLIAuthAudience string
|
||||
// AuthIssuer identifies principal that issued the JWT
|
||||
AuthIssuer string
|
||||
// AuthUserIDClaim is the name of the claim that used as user ID
|
||||
|
||||
@@ -68,8 +68,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
if len(audiences) > 0 {
|
||||
audience = audiences[0] // Use the first client ID as the primary audience
|
||||
}
|
||||
// Use localhost keys location for internal validation (management has embedded Dex)
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
keysLocation = oauthProvider.GetKeysLocation()
|
||||
signingKeyRefreshEnabled = true
|
||||
issuer = oauthProvider.GetIssuer()
|
||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||
|
||||
@@ -129,11 +129,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" {
|
||||
idpManager = s.Config.IdpManagerConfig.ManagerType
|
||||
}
|
||||
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpManager = metrics.EmbeddedType
|
||||
}
|
||||
|
||||
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
|
||||
go metricsWorker.Run(srvCtx)
|
||||
}
|
||||
|
||||
@@ -428,13 +428,9 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
|
||||
keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json"
|
||||
}
|
||||
|
||||
audience := config.AuthAudience
|
||||
if config.CLIAuthAudience != "" {
|
||||
audience = config.CLIAuthAudience
|
||||
}
|
||||
return &proto.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Audience: config.AuthAudience,
|
||||
KeysLocation: keysLocation,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3465,11 +3465,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain})
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -893,7 +893,6 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
peer := &peer2.Peer{
|
||||
ID: strconv.Itoa(i),
|
||||
AccountID: accountID,
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
DNSLabel: "peer" + strconv.Itoa(i),
|
||||
IP: uint32ToIP(uint32(i)),
|
||||
}
|
||||
|
||||
@@ -2,13 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -23,69 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// oidcProviderJSON represents the OpenID Connect discovery document
|
||||
type oidcProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
}
|
||||
|
||||
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
|
||||
// and verifying that the returned issuer matches the configured one.
|
||||
func validateOIDCIssuer(ctx context.Context, issuer string) error {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
|
||||
}
|
||||
|
||||
var p oidcProviderJSON
|
||||
if err := json.Unmarshal(body, &p); err != nil {
|
||||
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if p.Issuer != issuer {
|
||||
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateIdentityProviderConfig validates the identity provider configuration including
|
||||
// basic validation and OIDC issuer verification.
|
||||
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
// Validate the issuer by calling the OIDC discovery endpoint
|
||||
if idpConfig.Issuer != "" {
|
||||
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIdentityProviders returns all identity providers for an account
|
||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
@@ -151,8 +82,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
@@ -188,8 +119,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
|
||||
@@ -2,10 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -204,109 +200,3 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "name is required")
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func() *httptest.Server
|
||||
expectedErr error
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "issuer mismatch",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := oidcProviderJSON{Issuer: "https://different-issuer.com"}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerMismatch,
|
||||
expectedErrMsg: "does not match",
|
||||
},
|
||||
{
|
||||
name: "server returns non-200 status",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = w.Write([]byte("not found"))
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||
expectedErrMsg: "404",
|
||||
},
|
||||
{
|
||||
name: "server returns invalid JSON",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("invalid json"))
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||
expectedErrMsg: "failed to decode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err)
|
||||
if tt.expectedErrMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.expectedErrMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_Success(t *testing.T) {
|
||||
// Create a server that returns its own URL as the issuer
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
resp := oidcProviderJSON{Issuer: server.URL}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
|
||||
// Use a URL that will definitely fail to connect
|
||||
err := validateOIDCIssuer(context.Background(), "http://localhost:59999")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable))
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
|
||||
// Test that trailing slashes are handled correctly
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
// Return issuer without trailing slash
|
||||
resp := oidcProviderJSON{Issuer: server.URL}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pass issuer with trailing slash
|
||||
err := validateOIDCIssuer(context.Background(), server.URL+"/")
|
||||
// This should fail because the issuer returned doesn't have trailing slash
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/google/uuid"
|
||||
@@ -20,7 +19,7 @@ const (
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email"
|
||||
defaultScopes = "openid profile email offline_access"
|
||||
defaultUserIDClaim = "sub"
|
||||
)
|
||||
|
||||
@@ -28,11 +27,8 @@ const (
|
||||
type EmbeddedIdPConfig struct {
|
||||
// Enabled indicates whether the embedded IDP is enabled
|
||||
Enabled bool
|
||||
// Issuer is the OIDC issuer URL (e.g., "https://management.netbird.io/oauth2")
|
||||
// Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2")
|
||||
Issuer string
|
||||
// LocalAddress is the management server's local listen address (e.g., ":8080" or "localhost:8080")
|
||||
// Used for internal JWT validation to avoid external network calls
|
||||
LocalAddress string
|
||||
// Storage configuration for the IdP database
|
||||
Storage EmbeddedStorageConfig
|
||||
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
|
||||
@@ -150,12 +146,7 @@ var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil)
|
||||
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
|
||||
type OAuthConfigProvider interface {
|
||||
GetIssuer() string
|
||||
// GetKeysLocation returns the public JWKS endpoint URL (uses external issuer URL)
|
||||
GetKeysLocation() string
|
||||
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal use.
|
||||
// Management server has embedded Dex and can validate tokens via localhost,
|
||||
// avoiding external network calls and DNS resolution issues during startup.
|
||||
GetLocalKeysLocation() string
|
||||
GetClientIDs() []string
|
||||
GetUserIDClaim() string
|
||||
GetTokenEndpoint() string
|
||||
@@ -509,22 +500,6 @@ func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
||||
return m.provider.GetKeysLocation()
|
||||
}
|
||||
|
||||
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal token validation.
|
||||
// Uses the LocalAddress from config (management server's listen address) since embedded Dex
|
||||
// is served by the management HTTP server, not a standalone Dex server.
|
||||
func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
|
||||
addr := m.config.LocalAddress
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
// Construct localhost URL from listen address
|
||||
// addr is in format ":port" or "host:port" or "localhost:port"
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
return fmt.Sprintf("http://localhost%s/oauth2/keys", addr)
|
||||
}
|
||||
return fmt.Sprintf("http://%s/oauth2/keys", addr)
|
||||
}
|
||||
|
||||
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
|
||||
func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||
return []string{staticClientDashboard, staticClientCLI}
|
||||
|
||||
@@ -247,61 +247,3 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
|
||||
t.Logf(" Raw UUID: %s", rawUserID)
|
||||
t.Logf(" Connector: %s", connectorID)
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localAddress string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "localhost with port",
|
||||
localAddress: "localhost:8080",
|
||||
expected: "http://localhost:8080/oauth2/keys",
|
||||
},
|
||||
{
|
||||
name: "localhost with https port",
|
||||
localAddress: "localhost:443",
|
||||
expected: "http://localhost:443/oauth2/keys",
|
||||
},
|
||||
{
|
||||
name: "port only format",
|
||||
localAddress: ":8080",
|
||||
expected: "http://localhost:8080/oauth2/keys",
|
||||
},
|
||||
{
|
||||
name: "empty address",
|
||||
localAddress: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
LocalAddress: tt.localAddress,
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex-"+tt.name+".db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
result := manager.GetLocalKeysLocation()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -29,7 +28,6 @@ const (
|
||||
defaultPushInterval = 12 * time.Hour
|
||||
// requestTimeout http request timeout
|
||||
requestTimeout = 45 * time.Second
|
||||
EmbeddedType = "embedded"
|
||||
)
|
||||
|
||||
type getTokenResponse struct {
|
||||
@@ -208,8 +206,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
peerActiveVersions []string
|
||||
osUIClients map[string]int
|
||||
rosenpassEnabled int
|
||||
localUsers int
|
||||
idpUsers int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -270,16 +266,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
serviceUsers++
|
||||
} else {
|
||||
users++
|
||||
if w.idpManager == EmbeddedType {
|
||||
_, idpID, err := dex.DecodeDexUserID(user.Id)
|
||||
if err == nil {
|
||||
if idpID == "local" {
|
||||
localUsers++
|
||||
} else {
|
||||
idpUsers++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pats += len(user.PATs)
|
||||
}
|
||||
@@ -367,8 +353,6 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["idp_manager"] = w.idpManager
|
||||
metricsProperties["store_engine"] = w.dataSource.GetStoreEngine()
|
||||
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
||||
metricsProperties["local_users_count"] = localUsers
|
||||
metricsProperties["idp_users_count"] = idpUsers
|
||||
|
||||
for protocol, count := range rulesProtocol {
|
||||
metricsProperties["rules_protocol_"+protocol] = count
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -26,8 +25,6 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
|
||||
|
||||
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
localUserID := dex.EncodeDexUserID("10", "local")
|
||||
idpUserID := dex.EncodeDexUserID("20", "zitadel")
|
||||
return []*types.Account{
|
||||
{
|
||||
Id: "1",
|
||||
@@ -101,14 +98,12 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"1": {
|
||||
Id: "1",
|
||||
IsServiceUser: true,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
localUserID: {
|
||||
Id: localUserID,
|
||||
"2": {
|
||||
IsServiceUser: false,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
@@ -167,14 +162,12 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"1": {
|
||||
Id: "1",
|
||||
IsServiceUser: true,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
idpUserID: {
|
||||
Id: idpUserID,
|
||||
"2": {
|
||||
IsServiceUser: false,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
@@ -221,7 +214,6 @@ func TestGenerateProperties(t *testing.T) {
|
||||
worker := Worker{
|
||||
dataSource: ds,
|
||||
connManager: ds,
|
||||
idpManager: EmbeddedType,
|
||||
}
|
||||
|
||||
properties := worker.generateProperties(context.Background())
|
||||
@@ -335,10 +327,4 @@ func TestGenerateProperties(t *testing.T) {
|
||||
t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"])
|
||||
}
|
||||
|
||||
if properties["local_users_count"] != 1 {
|
||||
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
|
||||
}
|
||||
if properties["idp_users_count"] != 1 {
|
||||
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,11 +404,10 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
|
||||
if dialect == "mysql" {
|
||||
var withLength []string
|
||||
for _, col := range columns {
|
||||
quotedCol := fmt.Sprintf("`%s`", col)
|
||||
if col == "ip" || col == "dns_label" || col == "key" {
|
||||
withLength = append(withLength, fmt.Sprintf("%s(64)", quotedCol))
|
||||
if col == "ip" || col == "dns_label" {
|
||||
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
|
||||
} else {
|
||||
withLength = append(withLength, quotedCol)
|
||||
withLength = append(withLength, col)
|
||||
}
|
||||
}
|
||||
columnClause = strings.Join(withLength, ", ")
|
||||
@@ -488,54 +487,3 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
|
||||
if !db.Migrator().HasTable("peers") {
|
||||
log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")
|
||||
return nil
|
||||
}
|
||||
|
||||
keyColumn := GetColumnName(db, "key")
|
||||
|
||||
var duplicates []struct {
|
||||
Key string
|
||||
Count int64
|
||||
}
|
||||
|
||||
if err := db.Table("peers").
|
||||
Select(keyColumn + ", COUNT(*) as count").
|
||||
Group(keyColumn).
|
||||
Having("COUNT(*) > 1").
|
||||
Find(&duplicates).Error; err != nil {
|
||||
return fmt.Errorf("find duplicate keys: %w", err)
|
||||
}
|
||||
|
||||
if len(duplicates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Warnf("Found %d duplicate peer keys, cleaning up", len(duplicates))
|
||||
|
||||
for _, dup := range duplicates {
|
||||
var peerIDs []string
|
||||
if err := db.Table("peers").
|
||||
Select("id").
|
||||
Where(keyColumn+" = ?", dup.Key).
|
||||
Order("peer_status_last_seen DESC").
|
||||
Pluck("id", &peerIDs).Error; err != nil {
|
||||
return fmt.Errorf("get peers for key: %w", err)
|
||||
}
|
||||
|
||||
if len(peerIDs) <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
idsToDelete := peerIDs[1:]
|
||||
|
||||
if err := db.Table("peers").Where("id IN ?", idsToDelete).Delete(nil).Error; err != nil {
|
||||
return fmt.Errorf("delete duplicate peers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -340,104 +340,3 @@ func TestCreateIndexIfExists(t *testing.T) {
|
||||
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||
assert.True(t, exist, "Should have the index")
|
||||
}
|
||||
|
||||
type testPeer struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Key string `gorm:"index"`
|
||||
PeerStatusLastSeen time.Time
|
||||
PeerStatusConnected bool
|
||||
}
|
||||
|
||||
func (testPeer) TableName() string {
|
||||
return "peers"
|
||||
}
|
||||
|
||||
func setupPeerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db := setupDatabase(t)
|
||||
_ = db.Migrator().DropTable(&testPeer{})
|
||||
err := db.AutoMigrate(&testPeer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
return db
|
||||
}
|
||||
|
||||
func TestRemoveDuplicatePeerKeys_NoDuplicates(t *testing.T) {
|
||||
db := setupPeerTestDB(t)
|
||||
|
||||
now := time.Now()
|
||||
peers := []testPeer{
|
||||
{ID: "peer1", Key: "key1", PeerStatusLastSeen: now},
|
||||
{ID: "peer2", Key: "key2", PeerStatusLastSeen: now},
|
||||
{ID: "peer3", Key: "key3", PeerStatusLastSeen: now},
|
||||
}
|
||||
|
||||
for _, p := range peers {
|
||||
err := db.Create(&p).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&testPeer{}).Count(&count)
|
||||
assert.Equal(t, int64(len(peers)), count, "All peers should remain when no duplicates")
|
||||
}
|
||||
|
||||
func TestRemoveDuplicatePeerKeys_WithDuplicates(t *testing.T) {
|
||||
db := setupPeerTestDB(t)
|
||||
|
||||
now := time.Now()
|
||||
peers := []testPeer{
|
||||
{ID: "peer1", Key: "key1", PeerStatusLastSeen: now.Add(-2 * time.Hour)},
|
||||
{ID: "peer2", Key: "key1", PeerStatusLastSeen: now.Add(-1 * time.Hour)},
|
||||
{ID: "peer3", Key: "key1", PeerStatusLastSeen: now},
|
||||
{ID: "peer4", Key: "key2", PeerStatusLastSeen: now},
|
||||
{ID: "peer5", Key: "key3", PeerStatusLastSeen: now.Add(-1 * time.Hour)},
|
||||
{ID: "peer6", Key: "key3", PeerStatusLastSeen: now},
|
||||
}
|
||||
|
||||
for _, p := range peers {
|
||||
err := db.Create(&p).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&testPeer{}).Count(&count)
|
||||
assert.Equal(t, int64(3), count, "Should have 3 peers after removing duplicates")
|
||||
|
||||
var remainingPeers []testPeer
|
||||
err = db.Find(&remainingPeers).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
remainingIDs := make(map[string]bool)
|
||||
for _, p := range remainingPeers {
|
||||
remainingIDs[p.ID] = true
|
||||
}
|
||||
|
||||
assert.True(t, remainingIDs["peer3"], "peer3 should remain (most recent for key1)")
|
||||
assert.True(t, remainingIDs["peer4"], "peer4 should remain (only peer for key2)")
|
||||
assert.True(t, remainingIDs["peer6"], "peer6 should remain (most recent for key3)")
|
||||
|
||||
assert.False(t, remainingIDs["peer1"], "peer1 should be deleted (older duplicate)")
|
||||
assert.False(t, remainingIDs["peer2"], "peer2 should be deleted (older duplicate)")
|
||||
assert.False(t, remainingIDs["peer5"], "peer5 should be deleted (older duplicate)")
|
||||
}
|
||||
|
||||
func TestRemoveDuplicatePeerKeys_EmptyTable(t *testing.T) {
|
||||
db := setupPeerTestDB(t)
|
||||
|
||||
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
|
||||
require.NoError(t, err, "Should not fail on empty table")
|
||||
}
|
||||
|
||||
func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
_ = db.Migrator().DropTable(&testPeer{})
|
||||
|
||||
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
|
||||
require.NoError(t, err, "Should not fail when table does not exist")
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Peer struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// WireGuard public key
|
||||
Key string // uniqueness index (check migrations)
|
||||
Key string `gorm:"index"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
||||
// Meta is a Peer system meta data
|
||||
|
||||
@@ -2129,14 +2129,12 @@ func Test_DeletePeer(t *testing.T) {
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
AccountID: accountID,
|
||||
Key: "key1",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
DNSLabel: "peer1.test",
|
||||
},
|
||||
"peer2": {
|
||||
ID: "peer2",
|
||||
AccountID: accountID,
|
||||
Key: "key2",
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
DNSLabel: "peer2.test",
|
||||
},
|
||||
|
||||
@@ -3029,9 +3029,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
return &SqlStore{
|
||||
db: tx,
|
||||
storeEngine: s.storeEngine,
|
||||
fieldEncrypt: s.fieldEncrypt,
|
||||
db: tx,
|
||||
storeEngine: s.storeEngine,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -968,7 +968,6 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key1",
|
||||
DNSLabel: "peer1",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
@@ -983,7 +982,6 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key2",
|
||||
DNSLabel: "peer1-1",
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
}
|
||||
@@ -1011,7 +1009,6 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key1",
|
||||
DNSLabel: "peer1",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
@@ -1025,7 +1022,6 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key2",
|
||||
DNSLabel: "peer1-1",
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
}
|
||||
@@ -1052,7 +1048,6 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key1",
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
@@ -1061,7 +1056,6 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) {
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key2",
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
@@ -1079,7 +1073,6 @@ func Test_AddPeerWithSameIP(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key1",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
@@ -1088,7 +1081,6 @@ func Test_AddPeerWithSameIP(t *testing.T) {
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
Key: "key2",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
@@ -3704,7 +3696,6 @@ func BenchmarkGetAccountPeers(b *testing.B) {
|
||||
peer := &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer-%d", i),
|
||||
AccountID: accountID,
|
||||
Key: fmt.Sprintf("key-%d", i),
|
||||
DNSLabel: fmt.Sprintf("peer%d.example.com", i),
|
||||
IP: intToIPv4(uint32(i)),
|
||||
}
|
||||
|
||||
@@ -350,13 +350,8 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[types.User](ctx, db, "email", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// migratePostAuto migrates the SQLite database to the latest schema
|
||||
} // migratePostAuto migrates the SQLite database to the latest schema
|
||||
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
||||
migrations := getMigrationsPostAuto(ctx)
|
||||
|
||||
@@ -386,12 +381,6 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
}
|
||||
})
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[nbpeer.Peer](ctx, db, "idx_peers_key")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
|
||||
4
management/server/testdata/store.sql
vendored
4
management/server/testdata/store.sql
vendored
@@ -18,7 +18,7 @@ CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,
|
||||
CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_peers_account_id_ip` ON `peers`(`account_id`,`ip`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
@@ -54,4 +54,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
|
||||
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
|
||||
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
|
||||
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
|
||||
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Do=','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
||||
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
||||
|
||||
@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
|
||||
@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
@@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
|
||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
2
management/server/testdata/storev1.sql
vendored
2
management/server/testdata/storev1.sql
vendored
@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
|
||||
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
|
||||
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
|
||||
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
|
||||
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
|
||||
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
|
||||
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
|
||||
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
|
||||
|
||||
@@ -32,13 +32,13 @@ func (h *Holder) AddAccount(account *Account) {
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(ctx, id)
|
||||
account, err := accGetter(context.Background(), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -7,14 +7,12 @@ import (
|
||||
|
||||
// Identity provider validation errors
|
||||
var (
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderIssuerUnreachable = errors.New("identity provider issuer is unreachable")
|
||||
ErrIdentityProviderIssuerMismatch = errors.New("identity provider issuer does not match the issuer returned by the provider")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
)
|
||||
|
||||
// IdentityProviderType is the type of identity provider
|
||||
|
||||
Reference in New Issue
Block a user